1 Commits
loom ... test

Author SHA1 Message Date
Evan Jarrett
dd5d2aab55 test tag push 2025-11-03 13:45:59 -06:00
30 changed files with 879 additions and 743 deletions

View File

@@ -1,23 +0,0 @@
when:
- event: ["push"]
branch: ["*"]
- event: ["pull_request"]
branch: ["main"]
engine: kubernetes
image: golang:1.24-bookworm
architecture: amd64
steps:
- name: Download and Generate
environment:
CGO_ENABLED: 1
command: |
go mod download
go generate ./...
- name: Run Tests
environment:
CGO_ENABLED: 1
command: |
go test -cover ./...

View File

@@ -1,23 +0,0 @@
when:
- event: ["push"]
branch: ["*"]
- event: ["pull_request"]
branch: ["main"]
engine: kubernetes
image: golang:1.24-bookworm
architecture: arm64
steps:
- name: Download and Generate
environment:
CGO_ENABLED: 1
command: |
go mod download
go generate ./...
- name: Run Tests
environment:
CGO_ENABLED: 1
command: |
go test -cover ./...

View File

@@ -1,13 +1,14 @@
# Tangled Workflow: Release Credential Helper # Tangled Workflow: Release Credential Helper to Tangled.org
# #
# This workflow builds cross-platform binaries for the credential helper. # This workflow builds the docker-credential-atcr binary and publishes it
# Creates tarballs for curl/bash installation and provides instructions # to Tangled.org for distribution via Homebrew.
# for updating the Homebrew formula.
# #
# Triggers on version tags (v*) pushed to the repository. # Current limitation: Tangled doesn't support triggering on tags yet,
# so this triggers on push to main. Manually verify you've tagged the
# release before pushing.
when: when:
- event: ["manual"] - event: ["push"]
tag: ["v*"] tag: ["v*"]
engine: "nixery" engine: "nixery"
@@ -15,36 +16,20 @@ engine: "nixery"
dependencies: dependencies:
nixpkgs: nixpkgs:
- go_1_24 # Go 1.24+ for building - go_1_24 # Go 1.24+ for building
- git # For finding tags
- goreleaser # For building multi-platform binaries - goreleaser # For building multi-platform binaries
- curl # Required by go generate for downloading vendor assets # - goat # TODO: Add goat CLI for uploading to Tangled (if available in nixpkgs)
- gnugrep # Required for tag detection
- gnutar # Required for creating tarballs
- gzip # Required for compressing tarballs
- coreutils # Required for sha256sum
environment: environment:
CGO_ENABLED: "0" # Build static binaries CGO_ENABLED: "0" # Build static binaries
steps: steps:
- name: Get tag for current commit - name: Find latest git tag
command: | command: |
# Fetch tags (shallow clone doesn't include them by default) # Get the most recent version tag
git fetch --tags LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.1")
echo "Latest tag: $LATEST_TAG"
# Find the tag that points to the current commit echo "$LATEST_TAG" > .version
TAG=$(git tag --points-at HEAD | grep -E '^v[0-9]' | head -n1)
if [ -z "$TAG" ]; then
echo "Error: No version tag found for current commit"
echo "Available tags:"
git tag
echo "Current commit:"
git rev-parse HEAD
exit 1
fi
echo "Building version: $TAG"
echo "$TAG" > .version
# Also get the commit hash for reference # Also get the commit hash for reference
COMMIT_HASH=$(git rev-parse HEAD) COMMIT_HASH=$(git rev-parse HEAD)
@@ -52,20 +37,17 @@ steps:
- name: Build binaries with GoReleaser - name: Build binaries with GoReleaser
command: | command: |
# Read version from previous step
VERSION=$(cat .version) VERSION=$(cat .version)
export VERSION export VERSION
# Build for all platforms using GoReleaser # Build for all platforms using GoReleaser
# This creates artifacts in dist/ directory
goreleaser build --clean --snapshot --config .goreleaser.yaml goreleaser build --clean --snapshot --config .goreleaser.yaml
# List what was built # List what was built
echo "Built artifacts:" echo "Built artifacts:"
if [ -d "dist" ]; then ls -lh dist/
ls -lh dist/
else
echo "Error: dist/ directory was not created by GoReleaser"
exit 1
fi
- name: Package artifacts - name: Package artifacts
command: | command: |
@@ -74,82 +56,82 @@ steps:
cd dist cd dist
# Create tarballs for each platform # Create tarballs for each platform (GoReleaser might already do this)
# GoReleaser creates directories like: credential-helper_{os}_{arch}_v{goversion}
# Darwin x86_64 # Darwin x86_64
if [ -d "credential-helper_darwin_amd64_v1" ]; then if [ -d "docker-credential-atcr_darwin_amd64_v1" ]; then
tar czf "docker-credential-atcr_${VERSION_NO_V}_Darwin_x86_64.tar.gz" \ tar czf "docker-credential-atcr_${VERSION_NO_V}_Darwin_x86_64.tar.gz" \
-C credential-helper_darwin_amd64_v1 docker-credential-atcr -C docker-credential-atcr_darwin_amd64_v1 docker-credential-atcr
echo "Created: docker-credential-atcr_${VERSION_NO_V}_Darwin_x86_64.tar.gz"
fi fi
# Darwin arm64 # Darwin arm64
for dir in credential-helper_darwin_arm64*; do if [ -d "docker-credential-atcr_darwin_arm64" ]; then
if [ -d "$dir" ]; then tar czf "docker-credential-atcr_${VERSION_NO_V}_Darwin_arm64.tar.gz" \
tar czf "docker-credential-atcr_${VERSION_NO_V}_Darwin_arm64.tar.gz" \ -C docker-credential-atcr_darwin_arm64 docker-credential-atcr
-C "$dir" docker-credential-atcr fi
echo "Created: docker-credential-atcr_${VERSION_NO_V}_Darwin_arm64.tar.gz"
break
fi
done
# Linux x86_64 # Linux x86_64
if [ -d "credential-helper_linux_amd64_v1" ]; then if [ -d "docker-credential-atcr_linux_amd64_v1" ]; then
tar czf "docker-credential-atcr_${VERSION_NO_V}_Linux_x86_64.tar.gz" \ tar czf "docker-credential-atcr_${VERSION_NO_V}_Linux_x86_64.tar.gz" \
-C credential-helper_linux_amd64_v1 docker-credential-atcr -C docker-credential-atcr_linux_amd64_v1 docker-credential-atcr
echo "Created: docker-credential-atcr_${VERSION_NO_V}_Linux_x86_64.tar.gz"
fi fi
# Linux arm64 # Linux arm64
for dir in credential-helper_linux_arm64*; do if [ -d "docker-credential-atcr_linux_arm64" ]; then
if [ -d "$dir" ]; then tar czf "docker-credential-atcr_${VERSION_NO_V}_Linux_arm64.tar.gz" \
tar czf "docker-credential-atcr_${VERSION_NO_V}_Linux_arm64.tar.gz" \ -C docker-credential-atcr_linux_arm64 docker-credential-atcr
-C "$dir" docker-credential-atcr fi
echo "Created: docker-credential-atcr_${VERSION_NO_V}_Linux_arm64.tar.gz"
break
fi
done
echo "Created tarballs:"
ls -lh *.tar.gz
- name: Upload to Tangled.org
command: |
VERSION=$(cat .version)
VERSION_NO_V=${VERSION#v}
# TODO: Authenticate with goat CLI
# You'll need to set up credentials/tokens for goat
# Example (adjust based on goat's actual auth mechanism):
# goat login --pds https://your-pds.example.com --handle your.handle
# TODO: Upload each artifact to Tangled.org
# This creates sh.tangled.repo.artifact records in your ATProto PDS
# Adjust these commands based on scripts/publish-artifact.sh pattern
# Example structure (you'll need to fill in actual goat commands):
# for artifact in dist/*.tar.gz; do
# echo "Uploading $artifact..."
# goat upload \
# --repo "at-container-registry" \
# --tag "$VERSION" \
# --file "$artifact"
# done
echo "TODO: Implement goat upload commands"
echo "See scripts/publish-artifact.sh for reference"
echo "" echo ""
echo "Tarballs ready:" echo "After uploading, you'll receive a TAG_HASH from Tangled."
ls -lh *.tar.gz 2>/dev/null || echo "Warning: No tarballs created" echo "Update Formula/docker-credential-atcr.rb with:"
echo " VERSION = \"$VERSION_NO_V\""
echo " TAG_HASH = \"<hash-from-tangled>\""
echo ""
echo "Then run: scripts/update-homebrew-formula.sh $VERSION_NO_V <tag-hash>"
- name: Generate checksums - name: Generate checksums for verification
command: | command: |
VERSION=$(cat .version) VERSION=$(cat .version)
VERSION_NO_V=${VERSION#v} VERSION_NO_V=${VERSION#v}
cd dist cd dist
echo "" echo "SHA256 checksums for Homebrew formula:"
echo "==========================================" echo "======================================="
echo "SHA256 Checksums"
echo "=========================================="
echo ""
# Generate checksums file for file in docker-credential-atcr_${VERSION_NO_V}_*.tar.gz; do
sha256sum docker-credential-atcr_${VERSION_NO_V}_*.tar.gz 2>/dev/null | tee checksums.txt || echo "No checksums generated" if [ -f "$file" ]; then
sha256sum "$file"
- name: Next steps fi
command: | done
VERSION=$(cat .version)
echo "" echo ""
echo "==========================================" echo "Copy these checksums to Formula/docker-credential-atcr.rb"
echo "Release $VERSION is ready!"
echo "=========================================="
echo ""
echo "Distribution tarballs are in: dist/"
echo ""
echo "Next steps:"
echo ""
echo "1. Upload tarballs to your hosting/CDN (or GitHub releases)"
echo ""
echo "2. For Homebrew users, update the formula:"
echo " ./scripts/update-homebrew-formula.sh $VERSION"
echo " # Then update Formula/docker-credential-atcr.rb and push to homebrew-tap"
echo ""
echo "3. For curl/bash installation, users can download directly:"
echo " curl -L <your-cdn>/docker-credential-atcr_<version>_<os>_<arch>.tar.gz | tar xz"
echo " sudo mv docker-credential-atcr /usr/local/bin/"

View File

@@ -2,55 +2,38 @@
# Triggers on version tags and builds cross-platform binaries using buildah # Triggers on version tags and builds cross-platform binaries using buildah
when: when:
- event: ["push"] - event: ["manual"]
tag: ["v*"] # TODO: Trigger only on version tags (v1.0.0, v2.1.3, etc.)
branch: ["main"]
engine: "buildah" engine: "nixery"
dependencies:
nixpkgs:
- buildah
- chroot
environment: environment:
IMAGE_REGISTRY: atcr.io IMAGE_REGISTRY: atcr.io
IMAGE_USER: evan.jarrett.net IMAGE_USER: evan.jarrett.net
steps: steps:
- name: Get tag for current commit - name: Setup build environment
command: | command: |
#test if ! grep -q "^root:" /etc/passwd 2>/dev/null; then
# Fetch tags (shallow clone doesn't include them by default) echo "root:x:0:0:root:/root:/bin/sh" >> /etc/passwd
git fetch --tags
# Find the tag that points to the current commit
TAG=$(git tag --points-at HEAD | grep -E '^v[0-9]' | head -n1)
if [ -z "$TAG" ]; then
echo "Error: No version tag found for current commit"
echo "Available tags:"
git tag
echo "Current commit:"
git rev-parse HEAD
exit 1
fi fi
echo "Building version: $TAG" - name: Login to registry
echo "$TAG" > .version
- name: Setup registry credentials
command: | command: |
mkdir -p ~/.docker echo "${APP_PASSWORD}" | buildah login \
cat > ~/.docker/config.json <<EOF --storage-driver vfs \
{ -u "${IMAGE_USER}" \
"auths": { --password-stdin \
"${IMAGE_REGISTRY}": { ${IMAGE_REGISTRY}
"auth": "$(echo -n "${IMAGE_USER}:${APP_PASSWORD}" | base64)"
}
}
}
EOF
chmod 600 ~/.docker/config.json
- name: Build and push AppView image - name: Build and push AppView image
command: | command: |
TAG=$(cat .version)
buildah bud \ buildah bud \
--storage-driver vfs \ --storage-driver vfs \
--tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:${TAG} \ --tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:${TAG} \
@@ -58,18 +41,12 @@ steps:
--file ./Dockerfile.appview \ --file ./Dockerfile.appview \
. .
buildah push \
--storage-driver vfs \
${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:${TAG}
buildah push \ buildah push \
--storage-driver vfs \ --storage-driver vfs \
${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:latest ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:latest
- name: Build and push Hold image - name: Build and push Hold image
command: | command: |
TAG=$(cat .version)
buildah bud \ buildah bud \
--storage-driver vfs \ --storage-driver vfs \
--tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:${TAG} \ --tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:${TAG} \
@@ -77,10 +54,6 @@ steps:
--file ./Dockerfile.hold \ --file ./Dockerfile.hold \
. .
buildah push \
--storage-driver vfs \
${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:${TAG}
buildah push \ buildah push \
--storage-driver vfs \ --storage-driver vfs \
${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:latest ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:latest

View File

@@ -1,6 +1,8 @@
when: when:
- event: ["push"] - event: ["push"]
branch: ["main", "test"] branch: ["main"]
- event: ["pull_request"]
branch: ["main"]
engine: "nixery" engine: "nixery"

110
CLAUDE.md
View File

@@ -206,62 +206,9 @@ ATCR uses middleware and routing to handle requests:
- Implements `distribution.Repository` - Implements `distribution.Repository`
- Returns custom `Manifests()` and `Blobs()` implementations - Returns custom `Manifests()` and `Blobs()` implementations
- Routes manifests to ATProto, blobs to S3 or BYOS - Routes manifests to ATProto, blobs to S3 or BYOS
- **IMPORTANT**: RoutingRepository is created fresh on EVERY request (no caching)
- Each Docker layer upload is a separate HTTP request (possibly different process)
- OAuth sessions can be refreshed/invalidated between requests
- The OAuth refresher already caches sessions efficiently (in-memory + DB)
- Previous caching of repositories with stale ATProtoClient caused "invalid refresh token" errors
### Authentication Architecture ### Authentication Architecture
#### Token Types and Flows
ATCR uses three distinct token types in its authentication flow:
**1. OAuth Tokens (Access + Refresh)**
- **Issued by:** User's PDS via OAuth flow
- **Stored in:** AppView database (`oauth_sessions` table)
- **Cached in:** Refresher's in-memory map (per-DID)
- **Used for:** AppView → User's PDS communication (write manifests, read profiles)
- **Managed by:** Indigo library with DPoP (automatic refresh)
- **Lifetime:** Access ~2 hours, Refresh ~90 days (PDS controlled)
**2. Registry JWTs**
- **Issued by:** AppView after OAuth login
- **Stored in:** Docker credential helper (`~/.atcr/credential-helper-token.json`)
- **Used for:** Docker client → AppView authentication
- **Lifetime:** 15 minutes (configurable via `ATCR_TOKEN_EXPIRATION`)
- **Format:** JWT with DID claim
**3. Service Tokens**
- **Issued by:** User's PDS via `com.atproto.server.getServiceAuth`
- **Stored in:** AppView memory (in-memory cache with ~50s TTL)
- **Used for:** AppView → Hold service authentication (acting on behalf of user)
- **Lifetime:** 60 seconds (PDS controlled), cached for 50s
- **Required:** OAuth session to obtain (catch-22 solved by Refresher)
**Token Flow Diagram:**
```
┌─────────────┐ ┌──────────────┐
│ Docker │ ─── Registry JWT ──────────────→ │ AppView │
│ Client │ │ │
└─────────────┘ └──────┬───────┘
│ OAuth tokens
│ (access + refresh)
┌──────────────┐
│ User's PDS │
└──────┬───────┘
│ Service token
│ (via getServiceAuth)
┌──────────────┐
│ Hold Service │
└──────────────┘
```
#### ATProto OAuth with DPoP #### ATProto OAuth with DPoP
ATCR implements the full ATProto OAuth specification with mandatory security features: ATCR implements the full ATProto OAuth specification with mandatory security features:
@@ -273,22 +220,13 @@ ATCR implements the full ATProto OAuth specification with mandatory security fea
**Key Components** (`pkg/auth/oauth/`): **Key Components** (`pkg/auth/oauth/`):
1. **Client** (`client.go`) - OAuth client configuration and session management 1. **Client** (`client.go`) - Core OAuth client with encapsulated configuration
- **ClientApp setup:** - Uses indigo's `NewLocalhostConfig()` for localhost (public client)
- `NewClientApp()` - Creates configured `*oauth.ClientApp` (uses indigo directly, no wrapper) - Uses `NewPublicConfig()` for production base (upgraded to confidential if key provided)
- Uses `NewLocalhostConfig()` for localhost (public client) - `RedirectURI()` - returns `baseURL + "/auth/oauth/callback"`
- Uses `NewPublicConfig()` for production (upgraded to confidential with P-256 key) - `GetDefaultScopes()` - returns ATCR registry scopes
- `GetDefaultScopes()` - Returns ATCR-specific OAuth scopes - `GetConfigRef()` - returns mutable config for `SetClientSecret()` calls
- `ScopesMatch()` - Compares scope lists (order-independent) - All OAuth flows (authorization, token exchange, refresh) in one place
- **Session management (Refresher):**
- `NewRefresher()` - Creates session cache manager for AppView
- **Purpose:** In-memory cache for `*oauth.ClientSession` objects (performance optimization)
- **Why needed:** Saves 1-2 DB queries per request (~2ms) with minimal code complexity
- Per-DID locking prevents concurrent database loads
- Calls `ClientApp.ResumeSession()` on cache miss
- Indigo handles token refresh automatically (transparent to ATCR)
- **Performance:** Essential for high-traffic deployments, negligible for low-traffic
- **Architecture:** Single file containing both ClientApp helpers and Refresher (combined from previous two-file structure)
2. **Keys** (`keys.go`) - P-256 key management for confidential clients 2. **Keys** (`keys.go`) - P-256 key management for confidential clients
- `GenerateOrLoadClientKey()` - generates or loads P-256 key from disk - `GenerateOrLoadClientKey()` - generates or loads P-256 key from disk
@@ -297,17 +235,21 @@ ATCR implements the full ATProto OAuth specification with mandatory security fea
- `PrivateKeyToMultibase()` - converts key for `SetClientSecret()` API - `PrivateKeyToMultibase()` - converts key for `SetClientSecret()` API
- **Key type:** P-256 (ES256) for OAuth standard compatibility (not K-256 like PDS keys) - **Key type:** P-256 (ES256) for OAuth standard compatibility (not K-256 like PDS keys)
3. **Storage** - Persists OAuth sessions 3. **Token Storage** (`store.go`) - Persists OAuth sessions for AppView
- `db/oauth_store.go` - SQLite-backed storage for AppView (in UI database) - SQLite-backed storage in UI database (not file-based)
- `store.go` - File-based storage for CLI tools (`~/.atcr/oauth-sessions.json`) - Client uses `~/.atcr/oauth-token.json` (credential helper)
- Implements indigo's `ClientAuthStore` interface
4. **Server** (`server.go`) - OAuth authorization endpoints for AppView 4. **Refresher** (`refresher.go`) - Token refresh manager for AppView
- Caches OAuth sessions with automatic token refresh (handled by indigo library)
- Per-DID locking prevents concurrent refresh races
- Uses Client methods for consistency
5. **Server** (`server.go`) - OAuth authorization endpoints for AppView
- `GET /auth/oauth/authorize` - starts OAuth flow - `GET /auth/oauth/authorize` - starts OAuth flow
- `GET /auth/oauth/callback` - handles OAuth callback - `GET /auth/oauth/callback` - handles OAuth callback
- Uses `ClientApp` methods directly (no wrapper) - Uses Client methods for authorization and token exchange
5. **Interactive Flow** (`interactive.go`) - Reusable OAuth flow for CLI tools 6. **Interactive Flow** (`interactive.go`) - Reusable OAuth flow for CLI tools
- Used by credential helper and hold service registration - Used by credential helper and hold service registration
- Two-phase callback setup ensures PAR metadata availability - Two-phase callback setup ensures PAR metadata availability
@@ -407,13 +349,12 @@ Later (subsequent docker push):
- Implements `distribution.Repository` interface - Implements `distribution.Repository` interface
- Uses RegistryContext to pass DID, PDS endpoint, hold DID, OAuth refresher, etc. - Uses RegistryContext to pass DID, PDS endpoint, hold DID, OAuth refresher, etc.
**Database-based hold DID lookups**: **hold_cache.go**: In-memory hold DID cache
- Queries SQLite `manifests` table for hold DID (indexed, fast) - Caches `(DID, repository) → holdDid` for pull operations
- No in-memory caching needed - database IS the cache - TTL: 10 minutes (covers typical pull operations)
- Persistent across restarts, multi-instance safe - Cleanup: Background goroutine runs every 5 minutes
- Pull operations use hold DID from latest manifest (historical reference) - **NOTE:** Simple in-memory cache for MVP. For production: use Redis or similar
- Push operations use fresh discovery from profile/default - Prevents expensive PDS manifest lookups on every blob request during pull
- Function: `db.GetLatestHoldDIDForRepo(did, repository)` in `pkg/appview/db/queries.go`
**proxy_blob_store.go**: External storage proxy (routes to hold via XRPC) **proxy_blob_store.go**: External storage proxy (routes to hold via XRPC)
- Resolves hold DID → HTTP URL for XRPC requests (did:web resolution) - Resolves hold DID → HTTP URL for XRPC requests (did:web resolution)
@@ -663,8 +604,7 @@ See `.env.hold.example` for all available options. Key environment variables:
**General:** **General:**
- Middleware is in `pkg/appview/middleware/` (auth.go, registry.go) - Middleware is in `pkg/appview/middleware/` (auth.go, registry.go)
- Storage routing is in `pkg/appview/storage/` (routing_repository.go, proxy_blob_store.go) - Storage routing is in `pkg/appview/storage/` (routing_repository.go, proxy_blob_store.go, hold_cache.go)
- Hold DID lookups use database queries (no in-memory caching)
- Storage drivers imported as `_ "github.com/distribution/distribution/v3/registry/storage/driver/s3-aws"` - Storage drivers imported as `_ "github.com/distribution/distribution/v3/registry/storage/driver/s3-aws"`
- Hold service reuses distribution's driver factory for multi-backend support - Hold service reuses distribution's driver factory for multi-backend support

View File

@@ -119,11 +119,10 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
slog.Info("TEST_MODE enabled - will use HTTP for local DID resolution and transition:generic scope") slog.Info("TEST_MODE enabled - will use HTTP for local DID resolution and transition:generic scope")
} }
// Create OAuth client app (automatically configures confidential client for production) // Create OAuth app (automatically configures confidential client for production)
desiredScopes := oauth.GetDefaultScopes(defaultHoldDID) oauthApp, err := oauth.NewApp(baseURL, oauthStore, defaultHoldDID, cfg.Server.OAuthKeyPath, cfg.Server.ClientName)
oauthClientApp, err := oauth.NewClientApp(baseURL, oauthStore, desiredScopes, cfg.Server.OAuthKeyPath, cfg.Server.ClientName)
if err != nil { if err != nil {
return fmt.Errorf("failed to create OAuth client app: %w", err) return fmt.Errorf("failed to create OAuth app: %w", err)
} }
if testMode { if testMode {
slog.Info("Using OAuth scopes with transition:generic (test mode)") slog.Info("Using OAuth scopes with transition:generic (test mode)")
@@ -133,6 +132,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
// Invalidate sessions with mismatched scopes on startup // Invalidate sessions with mismatched scopes on startup
// This ensures all users have the latest required scopes after deployment // This ensures all users have the latest required scopes after deployment
desiredScopes := oauth.GetDefaultScopes(defaultHoldDID)
invalidatedCount, err := oauthStore.InvalidateSessionsWithMismatchedScopes(context.Background(), desiredScopes) invalidatedCount, err := oauthStore.InvalidateSessionsWithMismatchedScopes(context.Background(), desiredScopes)
if err != nil { if err != nil {
slog.Warn("Failed to invalidate sessions with mismatched scopes", "error", err) slog.Warn("Failed to invalidate sessions with mismatched scopes", "error", err)
@@ -141,7 +141,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
} }
// Create oauth token refresher // Create oauth token refresher
refresher := oauth.NewRefresher(oauthClientApp) refresher := oauth.NewRefresher(oauthApp)
// Wire up UI session store to refresher so it can invalidate UI sessions on OAuth failures // Wire up UI session store to refresher so it can invalidate UI sessions on OAuth failures
if uiSessionStore != nil { if uiSessionStore != nil {
@@ -189,7 +189,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
Database: uiDatabase, Database: uiDatabase,
ReadOnlyDB: uiReadOnlyDB, ReadOnlyDB: uiReadOnlyDB,
SessionStore: uiSessionStore, SessionStore: uiSessionStore,
OAuthClientApp: oauthClientApp, OAuthApp: oauthApp,
OAuthStore: oauthStore, OAuthStore: oauthStore,
Refresher: refresher, Refresher: refresher,
BaseURL: baseURL, BaseURL: baseURL,
@@ -202,7 +202,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
} }
// Create OAuth server // Create OAuth server
oauthServer := oauth.NewServer(oauthClientApp) oauthServer := oauth.NewServer(oauthApp)
// Connect server to refresher for cache invalidation // Connect server to refresher for cache invalidation
oauthServer.SetRefresher(refresher) oauthServer.SetRefresher(refresher)
// Connect UI session store for web login // Connect UI session store for web login
@@ -223,7 +223,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
} }
// Resume OAuth session to get authenticated client // Resume OAuth session to get authenticated client
session, err := oauthClientApp.ResumeSession(ctx, didParsed, sessionID) session, err := oauthApp.ResumeSession(ctx, didParsed, sessionID)
if err != nil { if err != nil {
slog.Warn("Failed to resume session", "component", "appview/callback", "did", did, "error", err) slog.Warn("Failed to resume session", "component", "appview/callback", "did", did, "error", err)
// Fallback: update user without avatar // Fallback: update user without avatar
@@ -320,12 +320,10 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
} }
// Register crew regardless of migration (outside the migration block) // Register crew regardless of migration (outside the migration block)
// Run in background to avoid blocking OAuth callback if hold is offline // Run in background to avoid blocking OAuth callback if hold is offline
// Use background context - don't inherit request context which gets canceled on response
slog.Debug("Attempting crew registration", "component", "appview/callback", "did", did, "hold_did", holdDID) slog.Debug("Attempting crew registration", "component", "appview/callback", "did", did, "hold_did", holdDID)
go func(client *atproto.Client, refresher *oauth.Refresher, holdDID string) { go func(ctx context.Context, client *atproto.Client, refresher *oauth.Refresher, holdDID string) {
ctx := context.Background()
storage.EnsureCrewMembership(ctx, client, refresher, holdDID) storage.EnsureCrewMembership(ctx, client, refresher, holdDID)
}(client, refresher, holdDID) }(ctx, client, refresher, holdDID)
} }
@@ -385,7 +383,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
// OAuth client metadata endpoint // OAuth client metadata endpoint
mainRouter.Get("/client-metadata.json", func(w http.ResponseWriter, r *http.Request) { mainRouter.Get("/client-metadata.json", func(w http.ResponseWriter, r *http.Request) {
config := oauthClientApp.Config config := oauthApp.GetConfig()
metadata := config.ClientMetadata() metadata := config.ClientMetadata()
// For confidential clients, ensure JWKS is included // For confidential clients, ensure JWKS is included

View File

@@ -211,7 +211,7 @@ These components are essential to registry operation and still need coverage.
OAuth implementation has test files but many functions remain untested. OAuth implementation has test files but many functions remain untested.
#### client.go - Session Management (Refresher) (Partial coverage) #### refresher.go (Partial coverage)
**Well-covered:** **Well-covered:**
- `NewRefresher()` - 100% ✅ - `NewRefresher()` - 100% ✅
@@ -227,8 +227,6 @@ OAuth implementation has test files but many functions remain untested.
- Session retrieval and caching - Session retrieval and caching
- Token refresh flow - Token refresh flow
- Concurrent refresh handling (per-DID locking) - Concurrent refresh handling (per-DID locking)
**Note:** Refresher functionality merged into client.go (previously separate refresher.go file)
- Cache expiration - Cache expiration
- Error handling for failed refreshes - Error handling for failed refreshes
@@ -511,9 +509,8 @@ UI initialization and setup. Low priority.
**In Progress:** **In Progress:**
9. 🔴 `pkg/appview/db/*` - Database layer (41.2%, needs improvement) 9. 🔴 `pkg/appview/db/*` - Database layer (41.2%, needs improvement)
- queries.go, session_store.go, device_store.go - queries.go, session_store.go, device_store.go
10. 🔴 `pkg/auth/oauth/client.go` - Session management (Refresher) (Partial → 70%+) 10. 🔴 `pkg/auth/oauth/refresher.go` - Token refresh (Partial → 70%+)
- `GetSession()`, `resumeSession()` (currently 0%) - `GetSession()`, `resumeSession()` (currently 0%)
- Note: Refresher merged into client.go
11. 🔴 `pkg/auth/oauth/server.go` - OAuth endpoints (50.7%, continue improvements) 11. 🔴 `pkg/auth/oauth/server.go` - OAuth endpoints (50.7%, continue improvements)
- `ServeCallback()` at 16.3% needs major improvement - `ServeCallback()` at 16.3% needs major improvement
12. 🔴 `pkg/appview/storage/crew.go` - Crew validation (11.1% → 80%+) 12. 🔴 `pkg/appview/storage/crew.go` - Crew validation (11.1% → 80%+)

2
go.mod
View File

@@ -4,7 +4,7 @@ go 1.24.7
require ( require (
github.com/aws/aws-sdk-go v1.55.5 github.com/aws/aws-sdk-go v1.55.5
github.com/bluesky-social/indigo v0.0.0-20251031012455-0b4bd2478a61 github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb
github.com/distribution/distribution/v3 v3.0.0 github.com/distribution/distribution/v3 v3.0.0
github.com/distribution/reference v0.6.0 github.com/distribution/reference v0.6.0
github.com/earthboundkid/versioninfo/v2 v2.24.1 github.com/earthboundkid/versioninfo/v2 v2.24.1

4
go.sum
View File

@@ -20,8 +20,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bluesky-social/indigo v0.0.0-20251031012455-0b4bd2478a61 h1:lU2NnyuvevVWtE35sb4xWBp1AQxa1Sv4XhexiWlrWng= github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb h1:zzyqB1W/itfdIA5cnOZ7IFCJ6QtqwOsXltmLunL4sHw=
github.com/bluesky-social/indigo v0.0.0-20251031012455-0b4bd2478a61/go.mod h1:GuGAU33qKulpZCZNPcUeIQ4RW6KzNvOy7s8MSUXbAng= github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb/go.mod h1:GuGAU33qKulpZCZNPcUeIQ4RW6KzNvOy7s8MSUXbAng=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/bshuster-repo/logrus-logstash-hook v1.0.0 h1:e+C0SB5R1pu//O4MQ3f9cFuPGoOVeF2fE4Og9otCc70= github.com/bshuster-repo/logrus-logstash-hook v1.0.0 h1:e+C0SB5R1pu//O4MQ3f9cFuPGoOVeF2fE4Og9otCc70=

View File

@@ -112,25 +112,6 @@ func (s *OAuthStore) DeleteSessionsForDID(ctx context.Context, did string) error
return nil return nil
} }
// DeleteOldSessionsForDID removes all sessions for a DID except the specified session to keep
// This is used during OAuth callback to clean up stale sessions with expired refresh tokens
func (s *OAuthStore) DeleteOldSessionsForDID(ctx context.Context, did string, keepSessionID string) error {
result, err := s.db.ExecContext(ctx, `
DELETE FROM oauth_sessions WHERE account_did = ? AND session_id != ?
`, did, keepSessionID)
if err != nil {
return fmt.Errorf("failed to delete old sessions for DID: %w", err)
}
deleted, _ := result.RowsAffected()
if deleted > 0 {
slog.Info("Deleted old OAuth sessions for DID", "count", deleted, "did", did, "kept", keepSessionID)
}
return nil
}
// GetAuthRequestInfo retrieves authentication request data by state // GetAuthRequestInfo retrieves authentication request data by state
func (s *OAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) { func (s *OAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
var requestDataJSON string var requestDataJSON string

View File

@@ -724,30 +724,6 @@ func GetNewestManifestForRepo(db *sql.DB, did, repository string) (*Manifest, er
return &m, nil return &m, nil
} }
// GetLatestHoldDIDForRepo returns the hold DID from the most recent manifest for a repository
// Returns empty string if no manifests exist (e.g., first push)
// This is used instead of the in-memory cache to determine which hold to use for blob operations
func GetLatestHoldDIDForRepo(db *sql.DB, did, repository string) (string, error) {
var holdDID string
err := db.QueryRow(`
SELECT hold_endpoint
FROM manifests
WHERE did = ? AND repository = ?
ORDER BY created_at DESC
LIMIT 1
`, did, repository).Scan(&holdDID)
if err == sql.ErrNoRows {
// No manifests yet - return empty string (first push case)
return "", nil
}
if err != nil {
return "", err
}
return holdDID, nil
}
// GetRepositoriesForDID returns all unique repository names for a DID // GetRepositoriesForDID returns all unique repository names for a DID
// Used by backfill to reconcile annotations for all repositories // Used by backfill to reconcile annotations for all repositories
func GetRepositoriesForDID(db *sql.DB, did string) ([]string, error) { func GetRepositoriesForDID(db *sql.DB, did string) ([]string, error) {
@@ -1600,11 +1576,6 @@ func (m *MetricsDB) IncrementPushCount(did, repository string) error {
return IncrementPushCount(m.db, did, repository) return IncrementPushCount(m.db, did, repository)
} }
// GetLatestHoldDIDForRepo returns the hold DID from the most recent manifest for a repository
func (m *MetricsDB) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
return GetLatestHoldDIDForRepo(m.db, did, repository)
}
// GetFeaturedRepositories fetches top repositories sorted by stars and pulls // GetFeaturedRepositories fetches top repositories sorted by stars and pulls
func GetFeaturedRepositories(db *sql.DB, limit int, currentUserDID string) ([]FeaturedRepository, error) { func GetFeaturedRepositories(db *sql.DB, limit int, currentUserDID string) ([]FeaturedRepository, error) {
query := ` query := `

View File

@@ -6,16 +6,15 @@ import (
"atcr.io/pkg/appview/db" "atcr.io/pkg/appview/db"
"atcr.io/pkg/auth/oauth" "atcr.io/pkg/auth/oauth"
indigooauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
"github.com/bluesky-social/indigo/atproto/syntax" "github.com/bluesky-social/indigo/atproto/syntax"
) )
// LogoutHandler handles user logout with proper OAuth token revocation // LogoutHandler handles user logout with proper OAuth token revocation
type LogoutHandler struct { type LogoutHandler struct {
OAuthClientApp *indigooauth.ClientApp OAuthApp *oauth.App
Refresher *oauth.Refresher Refresher *oauth.Refresher
SessionStore *db.SessionStore SessionStore *db.SessionStore
OAuthStore *db.OAuthStore OAuthStore *db.OAuthStore
} }
func (h *LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -38,13 +37,17 @@ func (h *LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Attempt to revoke OAuth tokens on PDS side // Attempt to revoke OAuth tokens on PDS side
if uiSession.OAuthSessionID != "" { if uiSession.OAuthSessionID != "" {
// Call indigo's Logout to revoke tokens on PDS // Call indigo's Logout to revoke tokens on PDS
if err := h.OAuthClientApp.Logout(r.Context(), did, uiSession.OAuthSessionID); err != nil { if err := h.OAuthApp.GetClientApp().Logout(r.Context(), did, uiSession.OAuthSessionID); err != nil {
// Log error but don't block logout - best effort revocation // Log error but don't block logout - best effort revocation
slog.Warn("Failed to revoke OAuth tokens on PDS", "component", "logout", "did", uiSession.DID, "error", err) slog.Warn("Failed to revoke OAuth tokens on PDS", "component", "logout", "did", uiSession.DID, "error", err)
} else { } else {
slog.Info("Successfully revoked OAuth tokens on PDS", "component", "logout", "did", uiSession.DID) slog.Info("Successfully revoked OAuth tokens on PDS", "component", "logout", "did", uiSession.DID)
} }
// Invalidate refresher cache to clear local access tokens
h.Refresher.InvalidateSession(uiSession.DID)
slog.Info("Invalidated local OAuth cache", "component", "logout", "did", uiSession.DID)
// Delete OAuth session from database (cleanup, might already be done by Logout) // Delete OAuth session from database (cleanup, might already be done by Logout)
if err := h.OAuthStore.DeleteSession(r.Context(), did, uiSession.OAuthSessionID); err != nil { if err := h.OAuthStore.DeleteSession(r.Context(), did, uiSession.OAuthSessionID); err != nil {
slog.Warn("Failed to delete OAuth session from database", "component", "logout", "error", err) slog.Warn("Failed to delete OAuth session from database", "component", "logout", "error", err)

View File

@@ -107,15 +107,6 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
// Detect manifest type // Detect manifest type
isManifestList := len(manifestRecord.Manifests) > 0 isManifestList := len(manifestRecord.Manifests) > 0
// Extract hold DID from manifest (with fallback for legacy manifests)
// New manifests use holdDid field (DID format)
// Old manifests use holdEndpoint field (URL format) - convert to DID
holdDID := manifestRecord.HoldDID
if holdDID == "" && manifestRecord.HoldEndpoint != "" {
// Legacy manifest - convert URL to DID
holdDID = atproto.ResolveHoldDIDFromURL(manifestRecord.HoldEndpoint)
}
// Prepare manifest for insertion (WITHOUT annotation fields) // Prepare manifest for insertion (WITHOUT annotation fields)
manifest := &db.Manifest{ manifest := &db.Manifest{
DID: did, DID: did,
@@ -123,7 +114,7 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
Digest: manifestRecord.Digest, Digest: manifestRecord.Digest,
MediaType: manifestRecord.MediaType, MediaType: manifestRecord.MediaType,
SchemaVersion: manifestRecord.SchemaVersion, SchemaVersion: manifestRecord.SchemaVersion,
HoldEndpoint: holdDID, HoldEndpoint: manifestRecord.HoldEndpoint,
CreatedAt: manifestRecord.CreatedAt, CreatedAt: manifestRecord.CreatedAt,
// Annotations removed - stored separately in repository_annotations table // Annotations removed - stored separately in repository_annotations table
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"strings" "strings"
"sync"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/registry/api/errcode" "github.com/distribution/distribution/v3/registry/api/errcode"
@@ -68,6 +69,7 @@ type NamespaceResolver struct {
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io") defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
baseURL string // Base URL for error messages (e.g., "https://atcr.io") baseURL string // Base URL for error messages (e.g., "https://atcr.io")
testMode bool // If true, fallback to default hold when user's hold is unreachable testMode bool // If true, fallback to default hold when user's hold is unreachable
repositories sync.Map // Cache of RoutingRepository instances by key (did:reponame)
refresher *oauth.Refresher // OAuth session manager (copied from global on init) refresher *oauth.Refresher // OAuth session manager (copied from global on init)
database storage.DatabaseMetrics // Metrics database (copied from global on init) database storage.DatabaseMetrics // Metrics database (copied from global on init)
authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init) authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init)
@@ -222,15 +224,20 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
// Example: "evan.jarrett.net/debian" -> store as "debian" // Example: "evan.jarrett.net/debian" -> store as "debian"
repositoryName := imageName repositoryName := imageName
// Cache key is DID + repository name
cacheKey := did + ":" + repositoryName
// Check cache first and update service token
if cached, ok := nr.repositories.Load(cacheKey); ok {
cachedRepo := cached.(*storage.RoutingRepository)
// Always update the service token even for cached repos (token may have been renewed)
cachedRepo.Ctx.ServiceToken = serviceToken
return cachedRepo, nil
}
// Create routing repository - routes manifests to ATProto, blobs to hold service // Create routing repository - routes manifests to ATProto, blobs to hold service
// The registry is stateless - no local storage is used // The registry is stateless - no local storage is used
// Bundle all context into a single RegistryContext struct // Bundle all context into a single RegistryContext struct
//
// NOTE: We create a fresh RoutingRepository on every request (no caching) because:
// 1. Each layer upload is a separate HTTP request (possibly different process)
// 2. OAuth sessions can be refreshed/invalidated between requests
// 3. The refresher already caches sessions efficiently (in-memory + DB)
// 4. Caching the repository with a stale ATProtoClient causes refresh token errors
registryCtx := &storage.RegistryContext{ registryCtx := &storage.RegistryContext{
DID: did, DID: did,
Handle: handle, Handle: handle,
@@ -244,8 +251,12 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
Refresher: nr.refresher, Refresher: nr.refresher,
ReadmeCache: nr.readmeCache, ReadmeCache: nr.readmeCache,
} }
routingRepo := storage.NewRoutingRepository(repo, registryCtx)
return storage.NewRoutingRepository(repo, registryCtx), nil // Cache the repository
nr.repositories.Store(cacheKey, routingRepo)
return routingRepo, nil
} }
// Repositories delegates to underlying namespace // Repositories delegates to underlying namespace

View File

@@ -13,22 +13,21 @@ import (
"atcr.io/pkg/appview/readme" "atcr.io/pkg/appview/readme"
"atcr.io/pkg/auth/oauth" "atcr.io/pkg/auth/oauth"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
indigooauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
) )
// UIDependencies contains all dependencies needed for UI route registration // UIDependencies contains all dependencies needed for UI route registration
type UIDependencies struct { type UIDependencies struct {
Database *sql.DB Database *sql.DB
ReadOnlyDB *sql.DB ReadOnlyDB *sql.DB
SessionStore *db.SessionStore SessionStore *db.SessionStore
OAuthClientApp *indigooauth.ClientApp OAuthApp *oauth.App
OAuthStore *db.OAuthStore OAuthStore *db.OAuthStore
Refresher *oauth.Refresher Refresher *oauth.Refresher
BaseURL string BaseURL string
DeviceStore *db.DeviceStore DeviceStore *db.DeviceStore
HealthChecker *holdhealth.Checker HealthChecker *holdhealth.Checker
ReadmeCache *readme.Cache ReadmeCache *readme.Cache
Templates *template.Template Templates *template.Template
} }
// RegisterUIRoutes registers all web UI and API routes on the provided router // RegisterUIRoutes registers all web UI and API routes on the provided router
@@ -91,7 +90,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
router.Get("/api/stats/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)( router.Get("/api/stats/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
&uihandlers.GetStatsHandler{ &uihandlers.GetStatsHandler{
DB: deps.ReadOnlyDB, DB: deps.ReadOnlyDB,
Directory: deps.OAuthClientApp.Dir, Directory: deps.OAuthApp.Directory(),
}, },
).ServeHTTP) ).ServeHTTP)
@@ -99,7 +98,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
router.Post("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)( router.Post("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
&uihandlers.StarRepositoryHandler{ &uihandlers.StarRepositoryHandler{
DB: deps.Database, // Needs write access DB: deps.Database, // Needs write access
Directory: deps.OAuthClientApp.Dir, Directory: deps.OAuthApp.Directory(),
Refresher: deps.Refresher, Refresher: deps.Refresher,
}, },
).ServeHTTP) ).ServeHTTP)
@@ -107,7 +106,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
router.Delete("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)( router.Delete("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
&uihandlers.UnstarRepositoryHandler{ &uihandlers.UnstarRepositoryHandler{
DB: deps.Database, // Needs write access DB: deps.Database, // Needs write access
Directory: deps.OAuthClientApp.Dir, Directory: deps.OAuthApp.Directory(),
Refresher: deps.Refresher, Refresher: deps.Refresher,
}, },
).ServeHTTP) ).ServeHTTP)
@@ -115,7 +114,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
router.Get("/api/stars/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)( router.Get("/api/stars/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
&uihandlers.CheckStarHandler{ &uihandlers.CheckStarHandler{
DB: deps.ReadOnlyDB, // Read-only check DB: deps.ReadOnlyDB, // Read-only check
Directory: deps.OAuthClientApp.Dir, Directory: deps.OAuthApp.Directory(),
Refresher: deps.Refresher, Refresher: deps.Refresher,
}, },
).ServeHTTP) ).ServeHTTP)
@@ -124,7 +123,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
router.Get("/api/manifests/{handle}/{repository}/{digest}", middleware.OptionalAuth(deps.SessionStore, deps.Database)( router.Get("/api/manifests/{handle}/{repository}/{digest}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
&uihandlers.ManifestDetailHandler{ &uihandlers.ManifestDetailHandler{
DB: deps.ReadOnlyDB, DB: deps.ReadOnlyDB,
Directory: deps.OAuthClientApp.Dir, Directory: deps.OAuthApp.Directory(),
}, },
).ServeHTTP) ).ServeHTTP)
@@ -146,7 +145,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
DB: deps.ReadOnlyDB, DB: deps.ReadOnlyDB,
Templates: deps.Templates, Templates: deps.Templates,
RegistryURL: registryURL, RegistryURL: registryURL,
Directory: deps.OAuthClientApp.Dir, Directory: deps.OAuthApp.Directory(),
Refresher: deps.Refresher, Refresher: deps.Refresher,
HealthChecker: deps.HealthChecker, HealthChecker: deps.HealthChecker,
ReadmeCache: deps.ReadmeCache, ReadmeCache: deps.ReadmeCache,
@@ -203,10 +202,10 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
// Logout endpoint (supports both GET and POST) // Logout endpoint (supports both GET and POST)
// Properly revokes OAuth tokens on PDS side before clearing local session // Properly revokes OAuth tokens on PDS side before clearing local session
logoutHandler := &uihandlers.LogoutHandler{ logoutHandler := &uihandlers.LogoutHandler{
OAuthClientApp: deps.OAuthClientApp, OAuthApp: deps.OAuthApp,
Refresher: deps.Refresher, Refresher: deps.Refresher,
SessionStore: deps.SessionStore, SessionStore: deps.SessionStore,
OAuthStore: deps.OAuthStore, OAuthStore: deps.OAuthStore,
} }
router.Get("/auth/logout", logoutHandler.ServeHTTP) router.Get("/auth/logout", logoutHandler.ServeHTTP)
router.Post("/auth/logout", logoutHandler.ServeHTTP) router.Post("/auth/logout", logoutHandler.ServeHTTP)

View File

@@ -8,11 +8,10 @@ import (
"atcr.io/pkg/auth/oauth" "atcr.io/pkg/auth/oauth"
) )
// DatabaseMetrics interface for tracking pull/push counts and querying hold DIDs // DatabaseMetrics interface for tracking pull/push counts
type DatabaseMetrics interface { type DatabaseMetrics interface {
IncrementPullCount(did, repository string) error IncrementPullCount(did, repository string) error
IncrementPushCount(did, repository string) error IncrementPushCount(did, repository string) error
GetLatestHoldDIDForRepo(did, repository string) (string, error)
} }
// ReadmeCache interface for README content caching // ReadmeCache interface for README content caching

View File

@@ -29,11 +29,6 @@ func (m *mockDatabaseMetrics) IncrementPushCount(did, repository string) error {
return nil return nil
} }
func (m *mockDatabaseMetrics) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
// Return empty string for mock - tests can override if needed
return "", nil
}
func (m *mockDatabaseMetrics) getPullCount() int { func (m *mockDatabaseMetrics) getPullCount() int {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()

View File

@@ -0,0 +1,98 @@
package storage
import (
"sync"
"time"
)
// HoldCache caches hold DIDs for (DID, repository) pairs
// This avoids expensive ATProto lookups on every blob request during pulls
//
// NOTE: This is a simple in-memory cache for MVP. For production deployments:
// - Use Redis or similar for distributed caching
// - Consider implementing cache size limits
// - Monitor memory usage under high load
type HoldCache struct {
mu sync.RWMutex
cache map[string]*holdCacheEntry
}
type holdCacheEntry struct {
holdDID string
expiresAt time.Time
}
var globalHoldCache = &HoldCache{
cache: make(map[string]*holdCacheEntry),
}
func init() {
// Start background cleanup goroutine
go func() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
globalHoldCache.Cleanup()
}
}()
}
// GetGlobalHoldCache returns the global hold cache instance
func GetGlobalHoldCache() *HoldCache {
return globalHoldCache
}
// Set stores a hold DID for a (DID, repository) pair with a TTL
func (c *HoldCache) Set(did, repository, holdDID string, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
key := did + ":" + repository
c.cache[key] = &holdCacheEntry{
holdDID: holdDID,
expiresAt: time.Now().Add(ttl),
}
}
// Get retrieves a hold DID for a (DID, repository) pair
// Returns empty string and false if not found or expired
func (c *HoldCache) Get(did, repository string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
key := did + ":" + repository
entry, ok := c.cache[key]
if !ok {
return "", false
}
// Check if expired
if time.Now().After(entry.expiresAt) {
// Don't delete here (would need write lock), let cleanup handle it
return "", false
}
return entry.holdDID, true
}
// Cleanup removes expired entries (called automatically every 5 minutes)
func (c *HoldCache) Cleanup() {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
removed := 0
for key, entry := range c.cache {
if now.After(entry.expiresAt) {
delete(c.cache, key)
removed++
}
}
// Log cleanup stats for monitoring
if removed > 0 || len(c.cache) > 100 {
// Log if we removed entries OR if cache is growing large
// This helps identify if cache size is becoming a concern
println("Hold cache cleanup: removed", removed, "entries, remaining", len(c.cache))
}
}

View File

@@ -0,0 +1,150 @@
package storage
import (
"testing"
"time"
)
func TestHoldCache_SetAndGet(t *testing.T) {
cache := &HoldCache{
cache: make(map[string]*holdCacheEntry),
}
did := "did:plc:test123"
repo := "myapp"
holdDID := "did:web:hold01.atcr.io"
ttl := 10 * time.Minute
// Set a value
cache.Set(did, repo, holdDID, ttl)
// Get the value - should succeed
gotHoldDID, ok := cache.Get(did, repo)
if !ok {
t.Fatal("Expected Get to return true, got false")
}
if gotHoldDID != holdDID {
t.Errorf("Expected hold DID %q, got %q", holdDID, gotHoldDID)
}
}
func TestHoldCache_GetNonExistent(t *testing.T) {
cache := &HoldCache{
cache: make(map[string]*holdCacheEntry),
}
// Get non-existent value
_, ok := cache.Get("did:plc:nonexistent", "repo")
if ok {
t.Error("Expected Get to return false for non-existent key")
}
}
func TestHoldCache_ExpiredEntry(t *testing.T) {
cache := &HoldCache{
cache: make(map[string]*holdCacheEntry),
}
did := "did:plc:test123"
repo := "myapp"
holdDID := "did:web:hold01.atcr.io"
// Set with very short TTL
cache.Set(did, repo, holdDID, 10*time.Millisecond)
// Wait for expiration
time.Sleep(20 * time.Millisecond)
// Get should return false
_, ok := cache.Get(did, repo)
if ok {
t.Error("Expected Get to return false for expired entry")
}
}
func TestHoldCache_Cleanup(t *testing.T) {
cache := &HoldCache{
cache: make(map[string]*holdCacheEntry),
}
// Add multiple entries with different TTLs
cache.Set("did:plc:1", "repo1", "hold1", 10*time.Millisecond)
cache.Set("did:plc:2", "repo2", "hold2", 1*time.Hour)
cache.Set("did:plc:3", "repo3", "hold3", 10*time.Millisecond)
// Wait for some to expire
time.Sleep(20 * time.Millisecond)
// Run cleanup
cache.Cleanup()
// Verify expired entries are removed
if _, ok := cache.Get("did:plc:1", "repo1"); ok {
t.Error("Expected expired entry 1 to be removed")
}
if _, ok := cache.Get("did:plc:3", "repo3"); ok {
t.Error("Expected expired entry 3 to be removed")
}
// Verify non-expired entry remains
if _, ok := cache.Get("did:plc:2", "repo2"); !ok {
t.Error("Expected non-expired entry to remain")
}
}
func TestHoldCache_ConcurrentAccess(t *testing.T) {
cache := &HoldCache{
cache: make(map[string]*holdCacheEntry),
}
done := make(chan bool)
// Concurrent writes
for i := 0; i < 10; i++ {
go func(id int) {
did := "did:plc:concurrent"
repo := "repo" + string(rune(id))
holdDID := "hold" + string(rune(id))
cache.Set(did, repo, holdDID, 1*time.Minute)
done <- true
}(i)
}
// Concurrent reads
for i := 0; i < 10; i++ {
go func(id int) {
repo := "repo" + string(rune(id))
cache.Get("did:plc:concurrent", repo)
done <- true
}(i)
}
// Wait for all goroutines
for i := 0; i < 20; i++ {
<-done
}
}
func TestHoldCache_KeyFormat(t *testing.T) {
cache := &HoldCache{
cache: make(map[string]*holdCacheEntry),
}
did := "did:plc:test"
repo := "myrepo"
holdDID := "did:web:hold"
cache.Set(did, repo, holdDID, 1*time.Minute)
// Verify the key is stored correctly (did:repo)
expectedKey := did + ":" + repo
if _, exists := cache.cache[expectedKey]; !exists {
t.Errorf("Expected key %q to exist in cache", expectedKey)
}
}
// TODO: Add more comprehensive tests:
// - Test GetGlobalHoldCache()
// - Test cache size monitoring
// - Benchmark cache performance under load
// - Test cleanup goroutine timing

View File

@@ -1,6 +1,6 @@
// Package storage implements the storage routing layer for AppView. // Package storage implements the storage routing layer for AppView.
// It routes manifests to ATProto PDS (as io.atcr.manifest records) and // It routes manifests to ATProto PDS (as io.atcr.manifest records) and
// blobs to hold services via XRPC, with database-based hold DID lookups. // blobs to hold services via XRPC, with hold DID caching for efficient pulls.
// All storage operations are proxied - AppView stores nothing locally. // All storage operations are proxied - AppView stores nothing locally.
package storage package storage
@@ -8,6 +8,7 @@ import (
"context" "context"
"log/slog" "log/slog"
"sync" "sync"
"time"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
) )
@@ -49,6 +50,17 @@ func (r *RoutingRepository) Manifests(ctx context.Context, options ...distributi
manifestStore := r.manifestStore manifestStore := r.manifestStore
r.mu.Unlock() r.mu.Unlock()
// After any manifest operation, cache the hold DID for blob fetches
// We use a goroutine to avoid blocking, and check after a short delay to allow the operation to complete
go func() {
time.Sleep(100 * time.Millisecond) // Brief delay to let manifest fetch complete
if holdDID := manifestStore.GetLastFetchedHoldDID(); holdDID != "" {
// Cache for 10 minutes - should cover typical pull operations
GetGlobalHoldCache().Set(r.Ctx.DID, r.Ctx.Repository, holdDID, 10*time.Minute)
slog.Debug("Cached hold DID", "component", "storage/routing", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", holdDID)
}
}()
return manifestStore, nil return manifestStore, nil
} }
@@ -64,23 +76,17 @@ func (r *RoutingRepository) Blobs(ctx context.Context) distribution.BlobStore {
return blobStore return blobStore
} }
// For pull operations, check database for hold DID from the most recent manifest // For pull operations, check if we have a cached hold DID from a recent manifest fetch
// This ensures blobs are fetched from the hold recorded in the manifest, not re-discovered // This ensures blobs are fetched from the hold recorded in the manifest, not re-discovered
holdDID := r.Ctx.HoldDID // Default to discovery-based DID holdDID := r.Ctx.HoldDID // Default to discovery-based DID
holdSource := "discovery"
if r.Ctx.Database != nil { if cachedHoldDID, ok := GetGlobalHoldCache().Get(r.Ctx.DID, r.Ctx.Repository); ok {
// Query database for the latest manifest's hold DID // Use cached hold DID from manifest
if dbHoldDID, err := r.Ctx.Database.GetLatestHoldDIDForRepo(r.Ctx.DID, r.Ctx.Repository); err == nil && dbHoldDID != "" { holdDID = cachedHoldDID
// Use hold DID from database (pull case - use historical reference) slog.Debug("Using cached hold from manifest", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", cachedHoldDID)
holdDID = dbHoldDID } else {
holdSource = "database" // No cached hold, use discovery-based DID (for push or first pull)
slog.Debug("Using hold from database manifest", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", dbHoldDID) slog.Debug("Using discovery-based hold", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", holdDID)
} else if err != nil {
// Log error but don't fail - fall back to discovery-based DID
slog.Warn("Failed to query database for hold DID", "component", "storage/blobs", "error", err)
}
// If dbHoldDID is empty (no manifests yet), fall through to use discovery-based DID
} }
if holdDID == "" { if holdDID == "" {
@@ -88,9 +94,7 @@ func (r *RoutingRepository) Blobs(ctx context.Context) distribution.BlobStore {
panic("hold DID not set in RegistryContext - ensure default_hold_did is configured in middleware") panic("hold DID not set in RegistryContext - ensure default_hold_did is configured in middleware")
} }
slog.Debug("Using hold DID for blobs", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", holdDID, "source", holdSource) // Update context with the correct hold DID (may be cached or discovered)
// Update context with the correct hold DID (may be from database or discovered)
r.Ctx.HoldDID = holdDID r.Ctx.HoldDID = holdDID
// Create and cache proxy blob store // Create and cache proxy blob store

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"sync" "sync"
"testing" "testing"
"time"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -12,27 +13,6 @@ import (
"atcr.io/pkg/atproto" "atcr.io/pkg/atproto"
) )
// mockDatabase is a simple mock for testing
type mockDatabase struct {
holdDID string
err error
}
func (m *mockDatabase) IncrementPullCount(did, repository string) error {
return nil
}
func (m *mockDatabase) IncrementPushCount(did, repository string) error {
return nil
}
func (m *mockDatabase) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
if m.err != nil {
return "", m.err
}
return m.holdDID, nil
}
func TestNewRoutingRepository(t *testing.T) { func TestNewRoutingRepository(t *testing.T) {
ctx := &RegistryContext{ ctx := &RegistryContext{
DID: "did:plc:test123", DID: "did:plc:test123",
@@ -109,36 +89,38 @@ func TestRoutingRepository_ManifestStoreCaching(t *testing.T) {
assert.NotNil(t, repo.manifestStore) assert.NotNil(t, repo.manifestStore)
} }
// TestRoutingRepository_Blobs_WithDatabase tests blob store with database hold DID // TestRoutingRepository_Blobs_WithCache tests blob store with cached hold DID
func TestRoutingRepository_Blobs_WithDatabase(t *testing.T) { func TestRoutingRepository_Blobs_WithCache(t *testing.T) {
dbHoldDID := "did:web:database.hold.io" // Pre-populate the hold cache
cache := GetGlobalHoldCache()
cachedHoldDID := "did:web:cached.hold.io"
cache.Set("did:plc:test123", "myapp", cachedHoldDID, 10*time.Minute)
ctx := &RegistryContext{ ctx := &RegistryContext{
DID: "did:plc:test123", DID: "did:plc:test123",
Repository: "myapp", Repository: "myapp",
HoldDID: "did:web:default.hold.io", // Discovery-based hold (should be overridden) HoldDID: "did:web:default.hold.io", // Discovery-based hold (should be overridden)
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""), ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
Database: &mockDatabase{holdDID: dbHoldDID},
} }
repo := NewRoutingRepository(nil, ctx) repo := NewRoutingRepository(nil, ctx)
blobStore := repo.Blobs(context.Background()) blobStore := repo.Blobs(context.Background())
assert.NotNil(t, blobStore) assert.NotNil(t, blobStore)
// Verify the hold DID was updated to use the database value // Verify the hold DID was updated to use the cached value
assert.Equal(t, dbHoldDID, repo.Ctx.HoldDID, "should use database hold DID") assert.Equal(t, cachedHoldDID, repo.Ctx.HoldDID, "should use cached hold DID")
} }
// TestRoutingRepository_Blobs_WithoutDatabase tests blob store with discovery-based hold // TestRoutingRepository_Blobs_WithoutCache tests blob store with discovery-based hold
func TestRoutingRepository_Blobs_WithoutDatabase(t *testing.T) { func TestRoutingRepository_Blobs_WithoutCache(t *testing.T) {
discoveryHoldDID := "did:web:discovery.hold.io" discoveryHoldDID := "did:web:discovery.hold.io"
// Use a different DID/repo to avoid cache contamination from other tests
ctx := &RegistryContext{ ctx := &RegistryContext{
DID: "did:plc:nocache456", DID: "did:plc:nocache456",
Repository: "uncached-app", Repository: "uncached-app",
HoldDID: discoveryHoldDID, HoldDID: discoveryHoldDID,
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:nocache456", ""), ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:nocache456", ""),
Database: nil, // No database
} }
repo := NewRoutingRepository(nil, ctx) repo := NewRoutingRepository(nil, ctx)
@@ -149,26 +131,6 @@ func TestRoutingRepository_Blobs_WithoutDatabase(t *testing.T) {
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should use discovery-based hold DID") assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should use discovery-based hold DID")
} }
// TestRoutingRepository_Blobs_DatabaseEmptyFallback tests fallback when database returns empty hold DID
func TestRoutingRepository_Blobs_DatabaseEmptyFallback(t *testing.T) {
discoveryHoldDID := "did:web:discovery.hold.io"
ctx := &RegistryContext{
DID: "did:plc:test123",
Repository: "newapp",
HoldDID: discoveryHoldDID,
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
Database: &mockDatabase{holdDID: ""}, // Empty string (no manifests yet)
}
repo := NewRoutingRepository(nil, ctx)
blobStore := repo.Blobs(context.Background())
assert.NotNil(t, blobStore)
// Verify the hold DID falls back to discovery-based
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should fall back to discovery-based hold DID when database returns empty")
}
// TestRoutingRepository_BlobStoreCaching tests that blob store is cached // TestRoutingRepository_BlobStoreCaching tests that blob store is cached
func TestRoutingRepository_BlobStoreCaching(t *testing.T) { func TestRoutingRepository_BlobStoreCaching(t *testing.T) {
ctx := &RegistryContext{ ctx := &RegistryContext{
@@ -292,23 +254,26 @@ func TestRoutingRepository_ConcurrentAccess(t *testing.T) {
assert.NotNil(t, cachedBlobStore) assert.NotNil(t, cachedBlobStore)
} }
// TestRoutingRepository_Blobs_Priority tests that database hold DID takes priority over discovery // TestRoutingRepository_HoldCachePopulation tests that hold DID cache is populated after manifest fetch
func TestRoutingRepository_Blobs_Priority(t *testing.T) { // Note: This test verifies the goroutine behavior with a delay
dbHoldDID := "did:web:database.hold.io" func TestRoutingRepository_HoldCachePopulation(t *testing.T) {
discoveryHoldDID := "did:web:discovery.hold.io"
ctx := &RegistryContext{ ctx := &RegistryContext{
DID: "did:plc:test123", DID: "did:plc:test123",
Repository: "myapp", Repository: "myapp",
HoldDID: discoveryHoldDID, // Discovery-based hold HoldDID: "did:web:hold01.atcr.io",
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""), ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
Database: &mockDatabase{holdDID: dbHoldDID}, // Database has a different hold DID
} }
repo := NewRoutingRepository(nil, ctx) repo := NewRoutingRepository(nil, ctx)
blobStore := repo.Blobs(context.Background())
assert.NotNil(t, blobStore) // Create manifest store (which triggers the cache population goroutine)
// Database hold DID should take priority over discovery _, err := repo.Manifests(context.Background())
assert.Equal(t, dbHoldDID, repo.Ctx.HoldDID, "database hold DID should take priority over discovery") require.NoError(t, err)
// Wait for goroutine to complete (it has a 100ms sleep)
time.Sleep(200 * time.Millisecond)
// Note: We can't easily verify the cache was populated without a real manifest fetch
// The actual caching happens in GetLastFetchedHoldDID() which requires manifest operations
// This test primarily verifies the Manifests() call doesn't panic with the goroutine
} }

View File

@@ -1,6 +1,6 @@
// Package oauth provides OAuth client configuration and helper functions for ATCR. // Package oauth provides OAuth client and flow implementation for ATCR.
// It provides helpers for setting up indigo's OAuth library with ATCR-specific // It wraps indigo's OAuth library with ATCR-specific configuration,
// configuration, including default scopes, confidential client setup, and // including default scopes, client metadata, token refreshing, and
// interactive browser-based authentication flows. // interactive browser-based authentication flows.
package oauth package oauth
@@ -8,19 +8,31 @@ import (
"context" "context"
"fmt" "fmt"
"log/slog" "log/slog"
"net/url"
"strings" "strings"
"time"
"atcr.io/pkg/atproto" "atcr.io/pkg/atproto"
"github.com/bluesky-social/indigo/atproto/auth/oauth" "github.com/bluesky-social/indigo/atproto/auth/oauth"
"github.com/bluesky-social/indigo/atproto/identity"
"github.com/bluesky-social/indigo/atproto/syntax" "github.com/bluesky-social/indigo/atproto/syntax"
) )
// NewClientApp creates an indigo OAuth ClientApp with ATCR-specific configuration // App wraps indigo's ClientApp with ATCR-specific configuration
type App struct {
clientApp *oauth.ClientApp
baseURL string
}
// NewApp creates a new OAuth app for ATCR with default scopes
func NewApp(baseURL string, store oauth.ClientAuthStore, holdDid string, keyPath string, clientName string) (*App, error) {
return NewAppWithScopes(baseURL, store, GetDefaultScopes(holdDid), keyPath, clientName)
}
// NewAppWithScopes creates a new OAuth app for ATCR with custom scopes
// Automatically configures confidential client for production deployments // Automatically configures confidential client for production deployments
// keyPath specifies where to store/load the OAuth client P-256 key (ignored for localhost) // keyPath specifies where to store/load the OAuth client P-256 key (ignored for localhost)
// clientName is added to OAuth client metadata (currently unused, reserved for future) // clientName is added to OAuth client metadata
func NewClientApp(baseURL string, store oauth.ClientAuthStore, scopes []string, keyPath string, clientName string) (*oauth.ClientApp, error) { func NewAppWithScopes(baseURL string, store oauth.ClientAuthStore, scopes []string, keyPath string, clientName string) (*App, error) {
var config oauth.ClientConfig var config oauth.ClientConfig
redirectURI := RedirectURI(baseURL) redirectURI := RedirectURI(baseURL)
@@ -56,7 +68,60 @@ func NewClientApp(baseURL string, store oauth.ClientAuthStore, scopes []string,
clientApp := oauth.NewClientApp(&config, store) clientApp := oauth.NewClientApp(&config, store)
clientApp.Dir = atproto.GetDirectory() clientApp.Dir = atproto.GetDirectory()
return clientApp, nil return &App{
clientApp: clientApp,
baseURL: baseURL,
}, nil
}
func (a *App) GetConfig() *oauth.ClientConfig {
return a.clientApp.Config
}
// StartAuthFlow initiates an OAuth authorization flow for a given handle
// Returns the authorization URL (state is stored in the auth store)
func (a *App) StartAuthFlow(ctx context.Context, handle string) (authURL string, err error) {
// Start auth flow with handle as identifier
// Indigo will resolve the handle internally
authURL, err = a.clientApp.StartAuthFlow(ctx, handle)
if err != nil {
return "", fmt.Errorf("failed to start auth flow: %w", err)
}
return authURL, nil
}
// ProcessCallback processes an OAuth callback with authorization code and state
// Returns ClientSessionData which contains the session information
func (a *App) ProcessCallback(ctx context.Context, params url.Values) (*oauth.ClientSessionData, error) {
sessionData, err := a.clientApp.ProcessCallback(ctx, params)
if err != nil {
return nil, fmt.Errorf("failed to process OAuth callback: %w", err)
}
return sessionData, nil
}
// ResumeSession resumes an existing OAuth session
// Returns a ClientSession that can be used to make authenticated requests
func (a *App) ResumeSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSession, error) {
session, err := a.clientApp.ResumeSession(ctx, did, sessionID)
if err != nil {
return nil, fmt.Errorf("failed to resume session: %w", err)
}
return session, nil
}
// GetClientApp returns the underlying indigo ClientApp
// This is useful for advanced use cases that need direct access
func (a *App) GetClientApp() *oauth.ClientApp {
return a.clientApp
}
// Directory returns the identity directory used by the OAuth app
func (a *App) Directory() identity.Directory {
return a.clientApp.Dir
} }
// RedirectURI returns the OAuth redirect URI for ATCR // RedirectURI returns the OAuth redirect URI for ATCR
@@ -123,111 +188,3 @@ func ScopesMatch(stored, desired []string) bool {
func isLocalhost(baseURL string) bool { func isLocalhost(baseURL string) bool {
return strings.Contains(baseURL, "127.0.0.1") || strings.Contains(baseURL, "localhost") return strings.Contains(baseURL, "127.0.0.1") || strings.Contains(baseURL, "localhost")
} }
// ----------------------------------------------------------------------------
// Session Management
// ----------------------------------------------------------------------------
// SessionCache represents a cached OAuth session
type SessionCache struct {
Session *oauth.ClientSession
SessionID string
}
// UISessionStore interface for managing UI sessions
// Shared between refresher and server
type UISessionStore interface {
Create(did, handle, pdsEndpoint string, duration time.Duration) (string, error)
DeleteByDID(did string)
}
// Refresher manages OAuth sessions and token refresh for AppView
// Sessions are loaded fresh from database on every request (database is source of truth)
type Refresher struct {
clientApp *oauth.ClientApp
uiSessionStore UISessionStore // For invalidating UI sessions on OAuth failures
}
// NewRefresher creates a new session refresher
func NewRefresher(clientApp *oauth.ClientApp) *Refresher {
return &Refresher{
clientApp: clientApp,
}
}
// SetUISessionStore sets the UI session store for invalidating sessions on OAuth failures
func (r *Refresher) SetUISessionStore(store UISessionStore) {
r.uiSessionStore = store
}
// GetSession gets a fresh OAuth session for a DID
// Loads session from database on every request (database is source of truth)
func (r *Refresher) GetSession(ctx context.Context, did string) (*oauth.ClientSession, error) {
return r.resumeSession(ctx, did)
}
// resumeSession loads a session from storage
func (r *Refresher) resumeSession(ctx context.Context, did string) (*oauth.ClientSession, error) {
// Parse DID
accountDID, err := syntax.ParseDID(did)
if err != nil {
return nil, fmt.Errorf("failed to parse DID: %w", err)
}
// Get the latest session for this DID from SQLite store
// The store must implement GetLatestSessionForDID (returns newest by updated_at)
type sessionGetter interface {
GetLatestSessionForDID(ctx context.Context, did string) (*oauth.ClientSessionData, string, error)
}
getter, ok := r.clientApp.Store.(sessionGetter)
if !ok {
return nil, fmt.Errorf("store must implement GetLatestSessionForDID (SQLite store required)")
}
sessionData, sessionID, err := getter.GetLatestSessionForDID(ctx, did)
if err != nil {
return nil, fmt.Errorf("no session found for DID: %s", did)
}
// Validate that session scopes match current desired scopes
desiredScopes := r.clientApp.Config.Scopes
if !ScopesMatch(sessionData.Scopes, desiredScopes) {
slog.Debug("Scope mismatch, deleting session",
"did", did,
"storedScopes", sessionData.Scopes,
"desiredScopes", desiredScopes)
// Delete the session from database since scopes have changed
if err := r.clientApp.Store.DeleteSession(ctx, accountDID, sessionID); err != nil {
slog.Warn("Failed to delete session with mismatched scopes", "error", err, "did", did)
}
return nil, fmt.Errorf("OAuth scopes changed, re-authentication required")
}
// Resume session
session, err := r.clientApp.ResumeSession(ctx, accountDID, sessionID)
if err != nil {
return nil, fmt.Errorf("failed to resume session: %w", err)
}
// Set up callback to persist token updates to SQLite
// This ensures that when indigo automatically refreshes tokens,
// the new tokens are saved to the database immediately
session.PersistSessionCallback = func(callbackCtx context.Context, updatedData *oauth.ClientSessionData) {
if err := r.clientApp.Store.SaveSession(callbackCtx, *updatedData); err != nil {
slog.Error("Failed to persist OAuth session update",
"component", "oauth/refresher",
"did", did,
"sessionID", sessionID,
"error", err)
} else {
slog.Debug("Persisted OAuth token refresh to database",
"component", "oauth/refresher",
"did", did,
"sessionID", sessionID)
}
}
return session, nil
}

View File

@@ -4,7 +4,7 @@ import (
"testing" "testing"
) )
func TestNewClientApp(t *testing.T) { func TestNewApp(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json" storePath := tmpDir + "/oauth-test.json"
keyPath := tmpDir + "/oauth-key.bin" keyPath := tmpDir + "/oauth-key.bin"
@@ -15,23 +15,23 @@ func TestNewClientApp(t *testing.T) {
} }
baseURL := "http://localhost:5000" baseURL := "http://localhost:5000"
scopes := GetDefaultScopes("*") holdDID := "did:web:hold.example.com"
clientApp, err := NewClientApp(baseURL, store, scopes, keyPath, "AT Container Registry") app, err := NewApp(baseURL, store, holdDID, keyPath, "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
if clientApp == nil { if app == nil {
t.Fatal("Expected non-nil clientApp") t.Fatal("Expected non-nil app")
} }
if clientApp.Dir == nil { if app.baseURL != baseURL {
t.Error("Expected directory to be set") t.Errorf("Expected baseURL %q, got %q", baseURL, app.baseURL)
} }
} }
func TestNewClientAppWithCustomScopes(t *testing.T) { func TestNewAppWithScopes(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json" storePath := tmpDir + "/oauth-test.json"
keyPath := tmpDir + "/oauth-key.bin" keyPath := tmpDir + "/oauth-key.bin"
@@ -44,20 +44,19 @@ func TestNewClientAppWithCustomScopes(t *testing.T) {
baseURL := "http://localhost:5000" baseURL := "http://localhost:5000"
scopes := []string{"atproto", "custom:scope"} scopes := []string{"atproto", "custom:scope"}
clientApp, err := NewClientApp(baseURL, store, scopes, keyPath, "AT Container Registry") app, err := NewAppWithScopes(baseURL, store, scopes, keyPath, "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewAppWithScopes() error = %v", err)
} }
if clientApp == nil { if app == nil {
t.Fatal("Expected non-nil clientApp") t.Fatal("Expected non-nil app")
} }
// Verify clientApp was created successfully // Verify scopes are set in config
// (Note: indigo's oauth.ClientApp doesn't expose scopes directly, config := app.GetConfig()
// but we can verify it was created without error) if len(config.Scopes) != len(scopes) {
if clientApp.Dir == nil { t.Errorf("Expected %d scopes, got %d", len(scopes), len(config.Scopes))
t.Error("Expected directory to be set")
} }
} }
@@ -122,59 +121,3 @@ func TestScopesMatch(t *testing.T) {
}) })
} }
} }
// ----------------------------------------------------------------------------
// Session Management (Refresher) Tests
// ----------------------------------------------------------------------------
func TestNewRefresher(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
refresher := NewRefresher(clientApp)
if refresher == nil {
t.Fatal("Expected non-nil refresher")
}
if refresher.clientApp == nil {
t.Error("Expected clientApp to be set")
}
}
func TestRefresher_SetUISessionStore(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
refresher := NewRefresher(clientApp)
// Test that SetUISessionStore doesn't panic with nil
// Full mock implementation requires implementing the interface
refresher.SetUISessionStore(nil)
// Verify nil is accepted
if refresher.uiSessionStore != nil {
t.Error("Expected UI session store to be nil after setting nil")
}
}

View File

@@ -13,7 +13,7 @@ import (
type InteractiveResult struct { type InteractiveResult struct {
SessionData *oauth.ClientSessionData SessionData *oauth.ClientSessionData
Session *oauth.ClientSession Session *oauth.ClientSession
ClientApp *oauth.ClientApp App *App
} }
// InteractiveFlowWithCallback runs an interactive OAuth flow with explicit callback handling // InteractiveFlowWithCallback runs an interactive OAuth flow with explicit callback handling
@@ -32,16 +32,19 @@ func InteractiveFlowWithCallback(
return nil, fmt.Errorf("failed to create OAuth store: %w", err) return nil, fmt.Errorf("failed to create OAuth store: %w", err)
} }
// Create OAuth client app with custom scopes (or defaults if nil) // Create OAuth app with custom scopes (or defaults if nil)
// Interactive flows are typically for production use (credential helper, etc.) // Interactive flows are typically for production use (credential helper, etc.)
// so we default to testMode=false
// For CLI tools, we use an empty keyPath since they're typically localhost (public client) // For CLI tools, we use an empty keyPath since they're typically localhost (public client)
// or ephemeral sessions // or ephemeral sessions
if scopes == nil { var app *App
scopes = GetDefaultScopes("*") if scopes != nil {
app, err = NewAppWithScopes(baseURL, store, scopes, "", "AT Container Registry")
} else {
app, err = NewApp(baseURL, store, "*", "", "AT Container Registry")
} }
clientApp, err := NewClientApp(baseURL, store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create OAuth client app: %w", err) return nil, fmt.Errorf("failed to create OAuth app: %w", err)
} }
// Channel to receive callback result // Channel to receive callback result
@@ -51,7 +54,7 @@ func InteractiveFlowWithCallback(
// Create callback handler // Create callback handler
callbackHandler := func(w http.ResponseWriter, r *http.Request) { callbackHandler := func(w http.ResponseWriter, r *http.Request) {
// Process callback // Process callback
sessionData, err := clientApp.ProcessCallback(r.Context(), r.URL.Query()) sessionData, err := app.ProcessCallback(r.Context(), r.URL.Query())
if err != nil { if err != nil {
errorChan <- fmt.Errorf("failed to process callback: %w", err) errorChan <- fmt.Errorf("failed to process callback: %w", err)
http.Error(w, "OAuth callback failed", http.StatusInternalServerError) http.Error(w, "OAuth callback failed", http.StatusInternalServerError)
@@ -59,7 +62,7 @@ func InteractiveFlowWithCallback(
} }
// Resume session // Resume session
session, err := clientApp.ResumeSession(r.Context(), sessionData.AccountDID, sessionData.SessionID) session, err := app.ResumeSession(r.Context(), sessionData.AccountDID, sessionData.SessionID)
if err != nil { if err != nil {
errorChan <- fmt.Errorf("failed to resume session: %w", err) errorChan <- fmt.Errorf("failed to resume session: %w", err)
http.Error(w, "Failed to resume session", http.StatusInternalServerError) http.Error(w, "Failed to resume session", http.StatusInternalServerError)
@@ -70,7 +73,7 @@ func InteractiveFlowWithCallback(
resultChan <- &InteractiveResult{ resultChan <- &InteractiveResult{
SessionData: sessionData, SessionData: sessionData,
Session: session, Session: session,
ClientApp: clientApp, App: app,
} }
// Return success to browser // Return success to browser
@@ -84,7 +87,7 @@ func InteractiveFlowWithCallback(
} }
// Start auth flow // Start auth flow
authURL, err := clientApp.StartAuthFlow(ctx, handle) authURL, err := app.StartAuthFlow(ctx, handle)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to start auth flow: %w", err) return nil, fmt.Errorf("failed to start auth flow: %w", err)
} }

174
pkg/auth/oauth/refresher.go Normal file
View File

@@ -0,0 +1,174 @@
package oauth
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
"github.com/bluesky-social/indigo/atproto/auth/oauth"
"github.com/bluesky-social/indigo/atproto/syntax"
)
// SessionCache represents a cached OAuth session
type SessionCache struct {
Session *oauth.ClientSession
SessionID string
}
// UISessionStore interface for managing UI sessions
// Shared between refresher and server
type UISessionStore interface {
Create(did, handle, pdsEndpoint string, duration time.Duration) (string, error)
DeleteByDID(did string)
}
// Refresher manages OAuth sessions and token refresh for AppView
type Refresher struct {
app *App
sessions map[string]*SessionCache // Key: DID string
mu sync.RWMutex
refreshLocks map[string]*sync.Mutex // Per-DID locks for refresh operations
refreshLockMu sync.Mutex // Protects refreshLocks map
uiSessionStore UISessionStore // For invalidating UI sessions on OAuth failures
}
// NewRefresher creates a new session refresher
func NewRefresher(app *App) *Refresher {
return &Refresher{
app: app,
sessions: make(map[string]*SessionCache),
refreshLocks: make(map[string]*sync.Mutex),
}
}
// SetUISessionStore sets the UI session store for invalidating sessions on OAuth failures
func (r *Refresher) SetUISessionStore(store UISessionStore) {
r.uiSessionStore = store
}
// GetSession gets a fresh OAuth session for a DID
// Returns cached session if still valid, otherwise resumes from store
func (r *Refresher) GetSession(ctx context.Context, did string) (*oauth.ClientSession, error) {
// Check cache first (fast path)
r.mu.RLock()
cached, ok := r.sessions[did]
r.mu.RUnlock()
if ok && cached.Session != nil {
// Session cached, tokens will auto-refresh if needed
return cached.Session, nil
}
// Session not cached, need to resume from store
// Get or create per-DID lock to prevent concurrent resume operations
r.refreshLockMu.Lock()
didLock, ok := r.refreshLocks[did]
if !ok {
didLock = &sync.Mutex{}
r.refreshLocks[did] = didLock
}
r.refreshLockMu.Unlock()
// Acquire DID-specific lock
didLock.Lock()
defer didLock.Unlock()
// Double-check cache after acquiring lock (another goroutine might have loaded it)
r.mu.RLock()
cached, ok = r.sessions[did]
r.mu.RUnlock()
if ok && cached.Session != nil {
return cached.Session, nil
}
// Actually resume the session
return r.resumeSession(ctx, did)
}
// resumeSession loads a session from storage and caches it
func (r *Refresher) resumeSession(ctx context.Context, did string) (*oauth.ClientSession, error) {
// Parse DID
accountDID, err := syntax.ParseDID(did)
if err != nil {
return nil, fmt.Errorf("failed to parse DID: %w", err)
}
// Get the latest session for this DID from SQLite store
// The store must implement GetLatestSessionForDID (returns newest by updated_at)
type sessionGetter interface {
GetLatestSessionForDID(ctx context.Context, did string) (*oauth.ClientSessionData, string, error)
}
getter, ok := r.app.clientApp.Store.(sessionGetter)
if !ok {
return nil, fmt.Errorf("store must implement GetLatestSessionForDID (SQLite store required)")
}
sessionData, sessionID, err := getter.GetLatestSessionForDID(ctx, did)
if err != nil {
return nil, fmt.Errorf("no session found for DID: %s", did)
}
// Validate that session scopes match current desired scopes
desiredScopes := r.app.GetConfig().Scopes
if !ScopesMatch(sessionData.Scopes, desiredScopes) {
slog.Debug("Scope mismatch, deleting session",
"did", did,
"storedScopes", sessionData.Scopes,
"desiredScopes", desiredScopes)
// Delete the session from database since scopes have changed
if err := r.app.clientApp.Store.DeleteSession(ctx, accountDID, sessionID); err != nil {
slog.Warn("Failed to delete session with mismatched scopes", "error", err, "did", did)
}
return nil, fmt.Errorf("OAuth scopes changed, re-authentication required")
}
// Resume session
session, err := r.app.ResumeSession(ctx, accountDID, sessionID)
if err != nil {
return nil, fmt.Errorf("failed to resume session: %w", err)
}
// Cache the session
r.mu.Lock()
r.sessions[did] = &SessionCache{
Session: session,
SessionID: sessionID,
}
r.mu.Unlock()
return session, nil
}
// InvalidateSession removes a cached session for a DID
// This is useful when a new OAuth flow creates a fresh session or when OAuth refresh fails
// Also invalidates any UI sessions for this DID to force re-authentication
func (r *Refresher) InvalidateSession(did string) {
r.mu.Lock()
delete(r.sessions, did)
r.mu.Unlock()
// Also delete UI sessions to force user to re-authenticate
if r.uiSessionStore != nil {
r.uiSessionStore.DeleteByDID(did)
}
}
// GetSessionID returns the sessionID for a cached session
// Returns empty string if session not cached
func (r *Refresher) GetSessionID(did string) string {
r.mu.RLock()
defer r.mu.RUnlock()
cached, ok := r.sessions[did]
if !ok || cached == nil {
return ""
}
return cached.SessionID
}

View File

@@ -0,0 +1,66 @@
package oauth
import (
"testing"
)
func TestNewRefresher(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
if err != nil {
t.Fatalf("NewApp() error = %v", err)
}
refresher := NewRefresher(app)
if refresher == nil {
t.Fatal("Expected non-nil refresher")
}
if refresher.app == nil {
t.Error("Expected app to be set")
}
if refresher.sessions == nil {
t.Error("Expected sessions map to be initialized")
}
if refresher.refreshLocks == nil {
t.Error("Expected refreshLocks map to be initialized")
}
}
func TestRefresher_SetUISessionStore(t *testing.T) {
tmpDir := t.TempDir()
storePath := tmpDir + "/oauth-test.json"
store, err := NewFileStore(storePath)
if err != nil {
t.Fatalf("NewFileStore() error = %v", err)
}
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
if err != nil {
t.Fatalf("NewApp() error = %v", err)
}
refresher := NewRefresher(app)
// Test that SetUISessionStore doesn't panic with nil
// Full mock implementation requires implementing the interface
refresher.SetUISessionStore(nil)
// Verify nil is accepted
if refresher.uiSessionStore != nil {
t.Error("Expected UI session store to be nil after setting nil")
}
}
// Note: Full session management tests will be added in comprehensive implementation
// Those tests will require mocking OAuth sessions and testing cache behavior

View File

@@ -10,11 +10,10 @@ import (
"time" "time"
"atcr.io/pkg/atproto" "atcr.io/pkg/atproto"
"github.com/bluesky-social/indigo/atproto/auth/oauth"
) )
// UISessionStore is the interface for UI session management // UISessionStore is the interface for UI session management
// UISessionStore is defined in client.go (session management section) // UISessionStore is defined in refresher.go to avoid duplication
// UserStore is the interface for user management // UserStore is the interface for user management
type UserStore interface { type UserStore interface {
@@ -29,16 +28,16 @@ type PostAuthCallback func(ctx context.Context, did, handle, pdsEndpoint, sessio
// Server handles OAuth authorization for the AppView // Server handles OAuth authorization for the AppView
type Server struct { type Server struct {
clientApp *oauth.ClientApp app *App
refresher *Refresher refresher *Refresher
uiSessionStore UISessionStore uiSessionStore UISessionStore
postAuthCallback PostAuthCallback postAuthCallback PostAuthCallback
} }
// NewServer creates a new OAuth server // NewServer creates a new OAuth server
func NewServer(clientApp *oauth.ClientApp) *Server { func NewServer(app *App) *Server {
return &Server{ return &Server{
clientApp: clientApp, app: app,
} }
} }
@@ -75,7 +74,7 @@ func (s *Server) ServeAuthorize(w http.ResponseWriter, r *http.Request) {
slog.Debug("Starting OAuth flow", "handle", handle) slog.Debug("Starting OAuth flow", "handle", handle)
// Start auth flow via indigo // Start auth flow via indigo
authURL, err := s.clientApp.StartAuthFlow(r.Context(), handle) authURL, err := s.app.StartAuthFlow(r.Context(), handle)
if err != nil { if err != nil {
slog.Error("Failed to start auth flow", "error", err, "handle", handle) slog.Error("Failed to start auth flow", "error", err, "handle", handle)
@@ -112,7 +111,7 @@ func (s *Server) ServeCallback(w http.ResponseWriter, r *http.Request) {
} }
// Process OAuth callback via indigo (handles state validation internally) // Process OAuth callback via indigo (handles state validation internally)
sessionData, err := s.clientApp.ProcessCallback(r.Context(), r.URL.Query()) sessionData, err := s.app.ProcessCallback(r.Context(), r.URL.Query())
if err != nil { if err != nil {
s.renderError(w, fmt.Sprintf("Failed to process OAuth callback: %v", err)) s.renderError(w, fmt.Sprintf("Failed to process OAuth callback: %v", err))
return return
@@ -123,20 +122,10 @@ func (s *Server) ServeCallback(w http.ResponseWriter, r *http.Request) {
slog.Debug("OAuth callback successful", "did", did, "sessionID", sessionID) slog.Debug("OAuth callback successful", "did", did, "sessionID", sessionID)
// Clean up old OAuth sessions for this DID BEFORE invalidating cache // Invalidate cached session (if any) since we have a new session with new tokens
// This prevents accumulation of stale sessions with expired refresh tokens if s.refresher != nil {
// Order matters: delete from DB first, then invalidate cache, so when cache reloads s.refresher.InvalidateSession(did)
// it will only find the new session slog.Debug("Invalidated cached session after creating new session", "did", did)
type sessionCleaner interface {
DeleteOldSessionsForDID(ctx context.Context, did string, keepSessionID string) error
}
if cleaner, ok := s.clientApp.Store.(sessionCleaner); ok {
if err := cleaner.DeleteOldSessionsForDID(r.Context(), did, sessionID); err != nil {
slog.Warn("Failed to clean up old OAuth sessions", "did", did, "error", err)
// Non-fatal - log and continue
} else {
slog.Debug("Cleaned up old OAuth sessions", "did", did, "kept", sessionID)
}
} }
// Look up identity (resolve DID to handle) // Look up identity (resolve DID to handle)

View File

@@ -19,19 +19,18 @@ func TestNewServer(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
if server == nil { if server == nil {
t.Fatal("Expected non-nil server") t.Fatal("Expected non-nil server")
} }
if server.clientApp == nil { if server.app == nil {
t.Error("Expected clientApp to be set") t.Error("Expected app to be set")
} }
} }
@@ -44,14 +43,13 @@ func TestServer_SetRefresher(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
refresher := NewRefresher(clientApp) refresher := NewRefresher(app)
server.SetRefresher(refresher) server.SetRefresher(refresher)
if server.refresher == nil { if server.refresher == nil {
@@ -68,13 +66,12 @@ func TestServer_SetPostAuthCallback(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
// Set callback with correct signature // Set callback with correct signature
server.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, sessionID string) error { server.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, sessionID string) error {
@@ -95,13 +92,12 @@ func TestServer_SetUISessionStore(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
mockStore := &mockUISessionStore{} mockStore := &mockUISessionStore{}
server.SetUISessionStore(mockStore) server.SetUISessionStore(mockStore)
@@ -159,13 +155,12 @@ func TestServer_ServeAuthorize_MissingHandle(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
req := httptest.NewRequest(http.MethodGet, "/auth/oauth/authorize", nil) req := httptest.NewRequest(http.MethodGet, "/auth/oauth/authorize", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -187,13 +182,12 @@ func TestServer_ServeAuthorize_InvalidMethod(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
req := httptest.NewRequest(http.MethodPost, "/auth/oauth/authorize?handle=alice.bsky.social", nil) req := httptest.NewRequest(http.MethodPost, "/auth/oauth/authorize?handle=alice.bsky.social", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -217,13 +211,12 @@ func TestServer_ServeCallback_InvalidMethod(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
req := httptest.NewRequest(http.MethodPost, "/auth/oauth/callback", nil) req := httptest.NewRequest(http.MethodPost, "/auth/oauth/callback", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -245,13 +238,12 @@ func TestServer_ServeCallback_OAuthError(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
req := httptest.NewRequest(http.MethodGet, "/auth/oauth/callback?error=access_denied&error_description=User+denied+access", nil) req := httptest.NewRequest(http.MethodGet, "/auth/oauth/callback?error=access_denied&error_description=User+denied+access", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -278,13 +270,12 @@ func TestServer_ServeCallback_WithPostAuthCallback(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
callbackInvoked := false callbackInvoked := false
server.SetPostAuthCallback(func(ctx context.Context, d, h, pds, sessionID string) error { server.SetPostAuthCallback(func(ctx context.Context, d, h, pds, sessionID string) error {
@@ -323,13 +314,12 @@ func TestServer_ServeCallback_UIFlow_SessionCreationLogic(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
server.SetUISessionStore(uiStore) server.SetUISessionStore(uiStore)
// Verify UI session store is set // Verify UI session store is set
@@ -353,13 +343,12 @@ func TestServer_RenderError(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
w := httptest.NewRecorder() w := httptest.NewRecorder()
server.renderError(w, "Test error message") server.renderError(w, "Test error message")
@@ -388,13 +377,12 @@ func TestServer_RenderRedirectToSettings(t *testing.T) {
t.Fatalf("NewFileStore() error = %v", err) t.Fatalf("NewFileStore() error = %v", err)
} }
scopes := GetDefaultScopes("*") app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil { if err != nil {
t.Fatalf("NewClientApp() error = %v", err) t.Fatalf("NewApp() error = %v", err)
} }
server := NewServer(clientApp) server := NewServer(app)
w := httptest.NewRecorder() w := httptest.NewRecorder()
server.renderRedirectToSettings(w, "alice.bsky.social") server.renderRedirectToSettings(w, "alice.bsky.social")

View File

@@ -46,7 +46,8 @@ func GetOrFetchServiceToken(
session, err := refresher.GetSession(ctx, did) session, err := refresher.GetSession(ctx, did)
if err != nil { if err != nil {
// OAuth session unavailable - fail // OAuth session unavailable - invalidate and fail
refresher.InvalidateSession(did)
InvalidateServiceToken(did, holdDID) InvalidateServiceToken(did, holdDID)
return "", fmt.Errorf("failed to get OAuth session: %w", err) return "", fmt.Errorf("failed to get OAuth session: %w", err)
} }
@@ -72,15 +73,17 @@ func GetOrFetchServiceToken(
// Use OAuth session to authenticate to PDS (with DPoP) // Use OAuth session to authenticate to PDS (with DPoP)
resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth") resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth")
if err != nil { if err != nil {
// Auth error - may indicate expired tokens or corrupted session // Invalidate session on auth errors (may indicate corrupted session or expired tokens)
refresher.InvalidateSession(did)
InvalidateServiceToken(did, holdDID) InvalidateServiceToken(did, holdDID)
return "", fmt.Errorf("OAuth validation failed: %w", err) return "", fmt.Errorf("OAuth validation failed: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
// Service auth failed // Invalidate session on auth failures
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
refresher.InvalidateSession(did)
InvalidateServiceToken(did, holdDID) InvalidateServiceToken(did, holdDID)
return "", fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes)) return "", fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
} }