mirror of
https://tangled.org/evan.jarrett.net/at-container-registry
synced 2026-04-23 09:50:33 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dd5d2aab55 |
@@ -1,23 +0,0 @@
|
|||||||
when:
|
|
||||||
- event: ["push"]
|
|
||||||
branch: ["*"]
|
|
||||||
- event: ["pull_request"]
|
|
||||||
branch: ["main"]
|
|
||||||
|
|
||||||
engine: kubernetes
|
|
||||||
image: golang:1.24-bookworm
|
|
||||||
architecture: amd64
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Download and Generate
|
|
||||||
environment:
|
|
||||||
CGO_ENABLED: 1
|
|
||||||
command: |
|
|
||||||
go mod download
|
|
||||||
go generate ./...
|
|
||||||
|
|
||||||
- name: Run Tests
|
|
||||||
environment:
|
|
||||||
CGO_ENABLED: 1
|
|
||||||
command: |
|
|
||||||
go test -cover ./...
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
when:
|
|
||||||
- event: ["push"]
|
|
||||||
branch: ["*"]
|
|
||||||
- event: ["pull_request"]
|
|
||||||
branch: ["main"]
|
|
||||||
|
|
||||||
engine: kubernetes
|
|
||||||
image: golang:1.24-bookworm
|
|
||||||
architecture: arm64
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Download and Generate
|
|
||||||
environment:
|
|
||||||
CGO_ENABLED: 1
|
|
||||||
command: |
|
|
||||||
go mod download
|
|
||||||
go generate ./...
|
|
||||||
|
|
||||||
- name: Run Tests
|
|
||||||
environment:
|
|
||||||
CGO_ENABLED: 1
|
|
||||||
command: |
|
|
||||||
go test -cover ./...
|
|
||||||
@@ -1,13 +1,14 @@
|
|||||||
# Tangled Workflow: Release Credential Helper
|
# Tangled Workflow: Release Credential Helper to Tangled.org
|
||||||
#
|
#
|
||||||
# This workflow builds cross-platform binaries for the credential helper.
|
# This workflow builds the docker-credential-atcr binary and publishes it
|
||||||
# Creates tarballs for curl/bash installation and provides instructions
|
# to Tangled.org for distribution via Homebrew.
|
||||||
# for updating the Homebrew formula.
|
|
||||||
#
|
#
|
||||||
# Triggers on version tags (v*) pushed to the repository.
|
# Current limitation: Tangled doesn't support triggering on tags yet,
|
||||||
|
# so this triggers on push to main. Manually verify you've tagged the
|
||||||
|
# release before pushing.
|
||||||
|
|
||||||
when:
|
when:
|
||||||
- event: ["manual"]
|
- event: ["push"]
|
||||||
tag: ["v*"]
|
tag: ["v*"]
|
||||||
|
|
||||||
engine: "nixery"
|
engine: "nixery"
|
||||||
@@ -15,36 +16,20 @@ engine: "nixery"
|
|||||||
dependencies:
|
dependencies:
|
||||||
nixpkgs:
|
nixpkgs:
|
||||||
- go_1_24 # Go 1.24+ for building
|
- go_1_24 # Go 1.24+ for building
|
||||||
|
- git # For finding tags
|
||||||
- goreleaser # For building multi-platform binaries
|
- goreleaser # For building multi-platform binaries
|
||||||
- curl # Required by go generate for downloading vendor assets
|
# - goat # TODO: Add goat CLI for uploading to Tangled (if available in nixpkgs)
|
||||||
- gnugrep # Required for tag detection
|
|
||||||
- gnutar # Required for creating tarballs
|
|
||||||
- gzip # Required for compressing tarballs
|
|
||||||
- coreutils # Required for sha256sum
|
|
||||||
|
|
||||||
environment:
|
environment:
|
||||||
CGO_ENABLED: "0" # Build static binaries
|
CGO_ENABLED: "0" # Build static binaries
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Get tag for current commit
|
- name: Find latest git tag
|
||||||
command: |
|
command: |
|
||||||
# Fetch tags (shallow clone doesn't include them by default)
|
# Get the most recent version tag
|
||||||
git fetch --tags
|
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.1")
|
||||||
|
echo "Latest tag: $LATEST_TAG"
|
||||||
# Find the tag that points to the current commit
|
echo "$LATEST_TAG" > .version
|
||||||
TAG=$(git tag --points-at HEAD | grep -E '^v[0-9]' | head -n1)
|
|
||||||
|
|
||||||
if [ -z "$TAG" ]; then
|
|
||||||
echo "Error: No version tag found for current commit"
|
|
||||||
echo "Available tags:"
|
|
||||||
git tag
|
|
||||||
echo "Current commit:"
|
|
||||||
git rev-parse HEAD
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "Building version: $TAG"
|
|
||||||
echo "$TAG" > .version
|
|
||||||
|
|
||||||
# Also get the commit hash for reference
|
# Also get the commit hash for reference
|
||||||
COMMIT_HASH=$(git rev-parse HEAD)
|
COMMIT_HASH=$(git rev-parse HEAD)
|
||||||
@@ -52,20 +37,17 @@ steps:
|
|||||||
|
|
||||||
- name: Build binaries with GoReleaser
|
- name: Build binaries with GoReleaser
|
||||||
command: |
|
command: |
|
||||||
|
# Read version from previous step
|
||||||
VERSION=$(cat .version)
|
VERSION=$(cat .version)
|
||||||
export VERSION
|
export VERSION
|
||||||
|
|
||||||
# Build for all platforms using GoReleaser
|
# Build for all platforms using GoReleaser
|
||||||
|
# This creates artifacts in dist/ directory
|
||||||
goreleaser build --clean --snapshot --config .goreleaser.yaml
|
goreleaser build --clean --snapshot --config .goreleaser.yaml
|
||||||
|
|
||||||
# List what was built
|
# List what was built
|
||||||
echo "Built artifacts:"
|
echo "Built artifacts:"
|
||||||
if [ -d "dist" ]; then
|
ls -lh dist/
|
||||||
ls -lh dist/
|
|
||||||
else
|
|
||||||
echo "Error: dist/ directory was not created by GoReleaser"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Package artifacts
|
- name: Package artifacts
|
||||||
command: |
|
command: |
|
||||||
@@ -74,82 +56,82 @@ steps:
|
|||||||
|
|
||||||
cd dist
|
cd dist
|
||||||
|
|
||||||
# Create tarballs for each platform
|
# Create tarballs for each platform (GoReleaser might already do this)
|
||||||
# GoReleaser creates directories like: credential-helper_{os}_{arch}_v{goversion}
|
|
||||||
|
|
||||||
# Darwin x86_64
|
# Darwin x86_64
|
||||||
if [ -d "credential-helper_darwin_amd64_v1" ]; then
|
if [ -d "docker-credential-atcr_darwin_amd64_v1" ]; then
|
||||||
tar czf "docker-credential-atcr_${VERSION_NO_V}_Darwin_x86_64.tar.gz" \
|
tar czf "docker-credential-atcr_${VERSION_NO_V}_Darwin_x86_64.tar.gz" \
|
||||||
-C credential-helper_darwin_amd64_v1 docker-credential-atcr
|
-C docker-credential-atcr_darwin_amd64_v1 docker-credential-atcr
|
||||||
echo "Created: docker-credential-atcr_${VERSION_NO_V}_Darwin_x86_64.tar.gz"
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Darwin arm64
|
# Darwin arm64
|
||||||
for dir in credential-helper_darwin_arm64*; do
|
if [ -d "docker-credential-atcr_darwin_arm64" ]; then
|
||||||
if [ -d "$dir" ]; then
|
tar czf "docker-credential-atcr_${VERSION_NO_V}_Darwin_arm64.tar.gz" \
|
||||||
tar czf "docker-credential-atcr_${VERSION_NO_V}_Darwin_arm64.tar.gz" \
|
-C docker-credential-atcr_darwin_arm64 docker-credential-atcr
|
||||||
-C "$dir" docker-credential-atcr
|
fi
|
||||||
echo "Created: docker-credential-atcr_${VERSION_NO_V}_Darwin_arm64.tar.gz"
|
|
||||||
break
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
# Linux x86_64
|
# Linux x86_64
|
||||||
if [ -d "credential-helper_linux_amd64_v1" ]; then
|
if [ -d "docker-credential-atcr_linux_amd64_v1" ]; then
|
||||||
tar czf "docker-credential-atcr_${VERSION_NO_V}_Linux_x86_64.tar.gz" \
|
tar czf "docker-credential-atcr_${VERSION_NO_V}_Linux_x86_64.tar.gz" \
|
||||||
-C credential-helper_linux_amd64_v1 docker-credential-atcr
|
-C docker-credential-atcr_linux_amd64_v1 docker-credential-atcr
|
||||||
echo "Created: docker-credential-atcr_${VERSION_NO_V}_Linux_x86_64.tar.gz"
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Linux arm64
|
# Linux arm64
|
||||||
for dir in credential-helper_linux_arm64*; do
|
if [ -d "docker-credential-atcr_linux_arm64" ]; then
|
||||||
if [ -d "$dir" ]; then
|
tar czf "docker-credential-atcr_${VERSION_NO_V}_Linux_arm64.tar.gz" \
|
||||||
tar czf "docker-credential-atcr_${VERSION_NO_V}_Linux_arm64.tar.gz" \
|
-C docker-credential-atcr_linux_arm64 docker-credential-atcr
|
||||||
-C "$dir" docker-credential-atcr
|
fi
|
||||||
echo "Created: docker-credential-atcr_${VERSION_NO_V}_Linux_arm64.tar.gz"
|
|
||||||
break
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
|
echo "Created tarballs:"
|
||||||
|
ls -lh *.tar.gz
|
||||||
|
|
||||||
|
- name: Upload to Tangled.org
|
||||||
|
command: |
|
||||||
|
VERSION=$(cat .version)
|
||||||
|
VERSION_NO_V=${VERSION#v}
|
||||||
|
|
||||||
|
# TODO: Authenticate with goat CLI
|
||||||
|
# You'll need to set up credentials/tokens for goat
|
||||||
|
# Example (adjust based on goat's actual auth mechanism):
|
||||||
|
# goat login --pds https://your-pds.example.com --handle your.handle
|
||||||
|
|
||||||
|
# TODO: Upload each artifact to Tangled.org
|
||||||
|
# This creates sh.tangled.repo.artifact records in your ATProto PDS
|
||||||
|
# Adjust these commands based on scripts/publish-artifact.sh pattern
|
||||||
|
|
||||||
|
# Example structure (you'll need to fill in actual goat commands):
|
||||||
|
# for artifact in dist/*.tar.gz; do
|
||||||
|
# echo "Uploading $artifact..."
|
||||||
|
# goat upload \
|
||||||
|
# --repo "at-container-registry" \
|
||||||
|
# --tag "$VERSION" \
|
||||||
|
# --file "$artifact"
|
||||||
|
# done
|
||||||
|
|
||||||
|
echo "TODO: Implement goat upload commands"
|
||||||
|
echo "See scripts/publish-artifact.sh for reference"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Tarballs ready:"
|
echo "After uploading, you'll receive a TAG_HASH from Tangled."
|
||||||
ls -lh *.tar.gz 2>/dev/null || echo "Warning: No tarballs created"
|
echo "Update Formula/docker-credential-atcr.rb with:"
|
||||||
|
echo " VERSION = \"$VERSION_NO_V\""
|
||||||
|
echo " TAG_HASH = \"<hash-from-tangled>\""
|
||||||
|
echo ""
|
||||||
|
echo "Then run: scripts/update-homebrew-formula.sh $VERSION_NO_V <tag-hash>"
|
||||||
|
|
||||||
- name: Generate checksums
|
- name: Generate checksums for verification
|
||||||
command: |
|
command: |
|
||||||
VERSION=$(cat .version)
|
VERSION=$(cat .version)
|
||||||
VERSION_NO_V=${VERSION#v}
|
VERSION_NO_V=${VERSION#v}
|
||||||
|
|
||||||
cd dist
|
cd dist
|
||||||
|
|
||||||
echo ""
|
echo "SHA256 checksums for Homebrew formula:"
|
||||||
echo "=========================================="
|
echo "======================================="
|
||||||
echo "SHA256 Checksums"
|
|
||||||
echo "=========================================="
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Generate checksums file
|
for file in docker-credential-atcr_${VERSION_NO_V}_*.tar.gz; do
|
||||||
sha256sum docker-credential-atcr_${VERSION_NO_V}_*.tar.gz 2>/dev/null | tee checksums.txt || echo "No checksums generated"
|
if [ -f "$file" ]; then
|
||||||
|
sha256sum "$file"
|
||||||
- name: Next steps
|
fi
|
||||||
command: |
|
done
|
||||||
VERSION=$(cat .version)
|
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=========================================="
|
echo "Copy these checksums to Formula/docker-credential-atcr.rb"
|
||||||
echo "Release $VERSION is ready!"
|
|
||||||
echo "=========================================="
|
|
||||||
echo ""
|
|
||||||
echo "Distribution tarballs are in: dist/"
|
|
||||||
echo ""
|
|
||||||
echo "Next steps:"
|
|
||||||
echo ""
|
|
||||||
echo "1. Upload tarballs to your hosting/CDN (or GitHub releases)"
|
|
||||||
echo ""
|
|
||||||
echo "2. For Homebrew users, update the formula:"
|
|
||||||
echo " ./scripts/update-homebrew-formula.sh $VERSION"
|
|
||||||
echo " # Then update Formula/docker-credential-atcr.rb and push to homebrew-tap"
|
|
||||||
echo ""
|
|
||||||
echo "3. For curl/bash installation, users can download directly:"
|
|
||||||
echo " curl -L <your-cdn>/docker-credential-atcr_<version>_<os>_<arch>.tar.gz | tar xz"
|
|
||||||
echo " sudo mv docker-credential-atcr /usr/local/bin/"
|
|
||||||
|
|||||||
@@ -2,55 +2,38 @@
|
|||||||
# Triggers on version tags and builds cross-platform binaries using buildah
|
# Triggers on version tags and builds cross-platform binaries using buildah
|
||||||
|
|
||||||
when:
|
when:
|
||||||
- event: ["push"]
|
- event: ["manual"]
|
||||||
tag: ["v*"]
|
# TODO: Trigger only on version tags (v1.0.0, v2.1.3, etc.)
|
||||||
|
branch: ["main"]
|
||||||
|
|
||||||
engine: "buildah"
|
engine: "nixery"
|
||||||
|
|
||||||
|
dependencies:
|
||||||
|
nixpkgs:
|
||||||
|
- buildah
|
||||||
|
- chroot
|
||||||
|
|
||||||
environment:
|
environment:
|
||||||
IMAGE_REGISTRY: atcr.io
|
IMAGE_REGISTRY: atcr.io
|
||||||
IMAGE_USER: evan.jarrett.net
|
IMAGE_USER: evan.jarrett.net
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Get tag for current commit
|
- name: Setup build environment
|
||||||
command: |
|
command: |
|
||||||
#test
|
if ! grep -q "^root:" /etc/passwd 2>/dev/null; then
|
||||||
# Fetch tags (shallow clone doesn't include them by default)
|
echo "root:x:0:0:root:/root:/bin/sh" >> /etc/passwd
|
||||||
git fetch --tags
|
|
||||||
|
|
||||||
# Find the tag that points to the current commit
|
|
||||||
TAG=$(git tag --points-at HEAD | grep -E '^v[0-9]' | head -n1)
|
|
||||||
|
|
||||||
if [ -z "$TAG" ]; then
|
|
||||||
echo "Error: No version tag found for current commit"
|
|
||||||
echo "Available tags:"
|
|
||||||
git tag
|
|
||||||
echo "Current commit:"
|
|
||||||
git rev-parse HEAD
|
|
||||||
exit 1
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "Building version: $TAG"
|
- name: Login to registry
|
||||||
echo "$TAG" > .version
|
|
||||||
|
|
||||||
- name: Setup registry credentials
|
|
||||||
command: |
|
command: |
|
||||||
mkdir -p ~/.docker
|
echo "${APP_PASSWORD}" | buildah login \
|
||||||
cat > ~/.docker/config.json <<EOF
|
--storage-driver vfs \
|
||||||
{
|
-u "${IMAGE_USER}" \
|
||||||
"auths": {
|
--password-stdin \
|
||||||
"${IMAGE_REGISTRY}": {
|
${IMAGE_REGISTRY}
|
||||||
"auth": "$(echo -n "${IMAGE_USER}:${APP_PASSWORD}" | base64)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
EOF
|
|
||||||
chmod 600 ~/.docker/config.json
|
|
||||||
|
|
||||||
- name: Build and push AppView image
|
- name: Build and push AppView image
|
||||||
command: |
|
command: |
|
||||||
TAG=$(cat .version)
|
|
||||||
|
|
||||||
buildah bud \
|
buildah bud \
|
||||||
--storage-driver vfs \
|
--storage-driver vfs \
|
||||||
--tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:${TAG} \
|
--tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:${TAG} \
|
||||||
@@ -58,18 +41,12 @@ steps:
|
|||||||
--file ./Dockerfile.appview \
|
--file ./Dockerfile.appview \
|
||||||
.
|
.
|
||||||
|
|
||||||
buildah push \
|
|
||||||
--storage-driver vfs \
|
|
||||||
${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:${TAG}
|
|
||||||
|
|
||||||
buildah push \
|
buildah push \
|
||||||
--storage-driver vfs \
|
--storage-driver vfs \
|
||||||
${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:latest
|
${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:latest
|
||||||
|
|
||||||
- name: Build and push Hold image
|
- name: Build and push Hold image
|
||||||
command: |
|
command: |
|
||||||
TAG=$(cat .version)
|
|
||||||
|
|
||||||
buildah bud \
|
buildah bud \
|
||||||
--storage-driver vfs \
|
--storage-driver vfs \
|
||||||
--tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:${TAG} \
|
--tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:${TAG} \
|
||||||
@@ -77,10 +54,6 @@ steps:
|
|||||||
--file ./Dockerfile.hold \
|
--file ./Dockerfile.hold \
|
||||||
.
|
.
|
||||||
|
|
||||||
buildah push \
|
|
||||||
--storage-driver vfs \
|
|
||||||
${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:${TAG}
|
|
||||||
|
|
||||||
buildah push \
|
buildah push \
|
||||||
--storage-driver vfs \
|
--storage-driver vfs \
|
||||||
${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:latest
|
${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:latest
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
when:
|
when:
|
||||||
- event: ["push"]
|
- event: ["push"]
|
||||||
branch: ["main", "test"]
|
branch: ["main"]
|
||||||
|
- event: ["pull_request"]
|
||||||
|
branch: ["main"]
|
||||||
|
|
||||||
engine: "nixery"
|
engine: "nixery"
|
||||||
|
|
||||||
|
|||||||
110
CLAUDE.md
110
CLAUDE.md
@@ -206,62 +206,9 @@ ATCR uses middleware and routing to handle requests:
|
|||||||
- Implements `distribution.Repository`
|
- Implements `distribution.Repository`
|
||||||
- Returns custom `Manifests()` and `Blobs()` implementations
|
- Returns custom `Manifests()` and `Blobs()` implementations
|
||||||
- Routes manifests to ATProto, blobs to S3 or BYOS
|
- Routes manifests to ATProto, blobs to S3 or BYOS
|
||||||
- **IMPORTANT**: RoutingRepository is created fresh on EVERY request (no caching)
|
|
||||||
- Each Docker layer upload is a separate HTTP request (possibly different process)
|
|
||||||
- OAuth sessions can be refreshed/invalidated between requests
|
|
||||||
- The OAuth refresher already caches sessions efficiently (in-memory + DB)
|
|
||||||
- Previous caching of repositories with stale ATProtoClient caused "invalid refresh token" errors
|
|
||||||
|
|
||||||
### Authentication Architecture
|
### Authentication Architecture
|
||||||
|
|
||||||
#### Token Types and Flows
|
|
||||||
|
|
||||||
ATCR uses three distinct token types in its authentication flow:
|
|
||||||
|
|
||||||
**1. OAuth Tokens (Access + Refresh)**
|
|
||||||
- **Issued by:** User's PDS via OAuth flow
|
|
||||||
- **Stored in:** AppView database (`oauth_sessions` table)
|
|
||||||
- **Cached in:** Refresher's in-memory map (per-DID)
|
|
||||||
- **Used for:** AppView → User's PDS communication (write manifests, read profiles)
|
|
||||||
- **Managed by:** Indigo library with DPoP (automatic refresh)
|
|
||||||
- **Lifetime:** Access ~2 hours, Refresh ~90 days (PDS controlled)
|
|
||||||
|
|
||||||
**2. Registry JWTs**
|
|
||||||
- **Issued by:** AppView after OAuth login
|
|
||||||
- **Stored in:** Docker credential helper (`~/.atcr/credential-helper-token.json`)
|
|
||||||
- **Used for:** Docker client → AppView authentication
|
|
||||||
- **Lifetime:** 15 minutes (configurable via `ATCR_TOKEN_EXPIRATION`)
|
|
||||||
- **Format:** JWT with DID claim
|
|
||||||
|
|
||||||
**3. Service Tokens**
|
|
||||||
- **Issued by:** User's PDS via `com.atproto.server.getServiceAuth`
|
|
||||||
- **Stored in:** AppView memory (in-memory cache with ~50s TTL)
|
|
||||||
- **Used for:** AppView → Hold service authentication (acting on behalf of user)
|
|
||||||
- **Lifetime:** 60 seconds (PDS controlled), cached for 50s
|
|
||||||
- **Required:** OAuth session to obtain (catch-22 solved by Refresher)
|
|
||||||
|
|
||||||
**Token Flow Diagram:**
|
|
||||||
```
|
|
||||||
┌─────────────┐ ┌──────────────┐
|
|
||||||
│ Docker │ ─── Registry JWT ──────────────→ │ AppView │
|
|
||||||
│ Client │ │ │
|
|
||||||
└─────────────┘ └──────┬───────┘
|
|
||||||
│
|
|
||||||
│ OAuth tokens
|
|
||||||
│ (access + refresh)
|
|
||||||
↓
|
|
||||||
┌──────────────┐
|
|
||||||
│ User's PDS │
|
|
||||||
└──────┬───────┘
|
|
||||||
│
|
|
||||||
│ Service token
|
|
||||||
│ (via getServiceAuth)
|
|
||||||
↓
|
|
||||||
┌──────────────┐
|
|
||||||
│ Hold Service │
|
|
||||||
└──────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
#### ATProto OAuth with DPoP
|
#### ATProto OAuth with DPoP
|
||||||
|
|
||||||
ATCR implements the full ATProto OAuth specification with mandatory security features:
|
ATCR implements the full ATProto OAuth specification with mandatory security features:
|
||||||
@@ -273,22 +220,13 @@ ATCR implements the full ATProto OAuth specification with mandatory security fea
|
|||||||
|
|
||||||
**Key Components** (`pkg/auth/oauth/`):
|
**Key Components** (`pkg/auth/oauth/`):
|
||||||
|
|
||||||
1. **Client** (`client.go`) - OAuth client configuration and session management
|
1. **Client** (`client.go`) - Core OAuth client with encapsulated configuration
|
||||||
- **ClientApp setup:**
|
- Uses indigo's `NewLocalhostConfig()` for localhost (public client)
|
||||||
- `NewClientApp()` - Creates configured `*oauth.ClientApp` (uses indigo directly, no wrapper)
|
- Uses `NewPublicConfig()` for production base (upgraded to confidential if key provided)
|
||||||
- Uses `NewLocalhostConfig()` for localhost (public client)
|
- `RedirectURI()` - returns `baseURL + "/auth/oauth/callback"`
|
||||||
- Uses `NewPublicConfig()` for production (upgraded to confidential with P-256 key)
|
- `GetDefaultScopes()` - returns ATCR registry scopes
|
||||||
- `GetDefaultScopes()` - Returns ATCR-specific OAuth scopes
|
- `GetConfigRef()` - returns mutable config for `SetClientSecret()` calls
|
||||||
- `ScopesMatch()` - Compares scope lists (order-independent)
|
- All OAuth flows (authorization, token exchange, refresh) in one place
|
||||||
- **Session management (Refresher):**
|
|
||||||
- `NewRefresher()` - Creates session cache manager for AppView
|
|
||||||
- **Purpose:** In-memory cache for `*oauth.ClientSession` objects (performance optimization)
|
|
||||||
- **Why needed:** Saves 1-2 DB queries per request (~2ms) with minimal code complexity
|
|
||||||
- Per-DID locking prevents concurrent database loads
|
|
||||||
- Calls `ClientApp.ResumeSession()` on cache miss
|
|
||||||
- Indigo handles token refresh automatically (transparent to ATCR)
|
|
||||||
- **Performance:** Essential for high-traffic deployments, negligible for low-traffic
|
|
||||||
- **Architecture:** Single file containing both ClientApp helpers and Refresher (combined from previous two-file structure)
|
|
||||||
|
|
||||||
2. **Keys** (`keys.go`) - P-256 key management for confidential clients
|
2. **Keys** (`keys.go`) - P-256 key management for confidential clients
|
||||||
- `GenerateOrLoadClientKey()` - generates or loads P-256 key from disk
|
- `GenerateOrLoadClientKey()` - generates or loads P-256 key from disk
|
||||||
@@ -297,17 +235,21 @@ ATCR implements the full ATProto OAuth specification with mandatory security fea
|
|||||||
- `PrivateKeyToMultibase()` - converts key for `SetClientSecret()` API
|
- `PrivateKeyToMultibase()` - converts key for `SetClientSecret()` API
|
||||||
- **Key type:** P-256 (ES256) for OAuth standard compatibility (not K-256 like PDS keys)
|
- **Key type:** P-256 (ES256) for OAuth standard compatibility (not K-256 like PDS keys)
|
||||||
|
|
||||||
3. **Storage** - Persists OAuth sessions
|
3. **Token Storage** (`store.go`) - Persists OAuth sessions for AppView
|
||||||
- `db/oauth_store.go` - SQLite-backed storage for AppView (in UI database)
|
- SQLite-backed storage in UI database (not file-based)
|
||||||
- `store.go` - File-based storage for CLI tools (`~/.atcr/oauth-sessions.json`)
|
- Client uses `~/.atcr/oauth-token.json` (credential helper)
|
||||||
- Implements indigo's `ClientAuthStore` interface
|
|
||||||
|
|
||||||
4. **Server** (`server.go`) - OAuth authorization endpoints for AppView
|
4. **Refresher** (`refresher.go`) - Token refresh manager for AppView
|
||||||
|
- Caches OAuth sessions with automatic token refresh (handled by indigo library)
|
||||||
|
- Per-DID locking prevents concurrent refresh races
|
||||||
|
- Uses Client methods for consistency
|
||||||
|
|
||||||
|
5. **Server** (`server.go`) - OAuth authorization endpoints for AppView
|
||||||
- `GET /auth/oauth/authorize` - starts OAuth flow
|
- `GET /auth/oauth/authorize` - starts OAuth flow
|
||||||
- `GET /auth/oauth/callback` - handles OAuth callback
|
- `GET /auth/oauth/callback` - handles OAuth callback
|
||||||
- Uses `ClientApp` methods directly (no wrapper)
|
- Uses Client methods for authorization and token exchange
|
||||||
|
|
||||||
5. **Interactive Flow** (`interactive.go`) - Reusable OAuth flow for CLI tools
|
6. **Interactive Flow** (`interactive.go`) - Reusable OAuth flow for CLI tools
|
||||||
- Used by credential helper and hold service registration
|
- Used by credential helper and hold service registration
|
||||||
- Two-phase callback setup ensures PAR metadata availability
|
- Two-phase callback setup ensures PAR metadata availability
|
||||||
|
|
||||||
@@ -407,13 +349,12 @@ Later (subsequent docker push):
|
|||||||
- Implements `distribution.Repository` interface
|
- Implements `distribution.Repository` interface
|
||||||
- Uses RegistryContext to pass DID, PDS endpoint, hold DID, OAuth refresher, etc.
|
- Uses RegistryContext to pass DID, PDS endpoint, hold DID, OAuth refresher, etc.
|
||||||
|
|
||||||
**Database-based hold DID lookups**:
|
**hold_cache.go**: In-memory hold DID cache
|
||||||
- Queries SQLite `manifests` table for hold DID (indexed, fast)
|
- Caches `(DID, repository) → holdDid` for pull operations
|
||||||
- No in-memory caching needed - database IS the cache
|
- TTL: 10 minutes (covers typical pull operations)
|
||||||
- Persistent across restarts, multi-instance safe
|
- Cleanup: Background goroutine runs every 5 minutes
|
||||||
- Pull operations use hold DID from latest manifest (historical reference)
|
- **NOTE:** Simple in-memory cache for MVP. For production: use Redis or similar
|
||||||
- Push operations use fresh discovery from profile/default
|
- Prevents expensive PDS manifest lookups on every blob request during pull
|
||||||
- Function: `db.GetLatestHoldDIDForRepo(did, repository)` in `pkg/appview/db/queries.go`
|
|
||||||
|
|
||||||
**proxy_blob_store.go**: External storage proxy (routes to hold via XRPC)
|
**proxy_blob_store.go**: External storage proxy (routes to hold via XRPC)
|
||||||
- Resolves hold DID → HTTP URL for XRPC requests (did:web resolution)
|
- Resolves hold DID → HTTP URL for XRPC requests (did:web resolution)
|
||||||
@@ -663,8 +604,7 @@ See `.env.hold.example` for all available options. Key environment variables:
|
|||||||
|
|
||||||
**General:**
|
**General:**
|
||||||
- Middleware is in `pkg/appview/middleware/` (auth.go, registry.go)
|
- Middleware is in `pkg/appview/middleware/` (auth.go, registry.go)
|
||||||
- Storage routing is in `pkg/appview/storage/` (routing_repository.go, proxy_blob_store.go)
|
- Storage routing is in `pkg/appview/storage/` (routing_repository.go, proxy_blob_store.go, hold_cache.go)
|
||||||
- Hold DID lookups use database queries (no in-memory caching)
|
|
||||||
- Storage drivers imported as `_ "github.com/distribution/distribution/v3/registry/storage/driver/s3-aws"`
|
- Storage drivers imported as `_ "github.com/distribution/distribution/v3/registry/storage/driver/s3-aws"`
|
||||||
- Hold service reuses distribution's driver factory for multi-backend support
|
- Hold service reuses distribution's driver factory for multi-backend support
|
||||||
|
|
||||||
|
|||||||
@@ -119,11 +119,10 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
|||||||
slog.Info("TEST_MODE enabled - will use HTTP for local DID resolution and transition:generic scope")
|
slog.Info("TEST_MODE enabled - will use HTTP for local DID resolution and transition:generic scope")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create OAuth client app (automatically configures confidential client for production)
|
// Create OAuth app (automatically configures confidential client for production)
|
||||||
desiredScopes := oauth.GetDefaultScopes(defaultHoldDID)
|
oauthApp, err := oauth.NewApp(baseURL, oauthStore, defaultHoldDID, cfg.Server.OAuthKeyPath, cfg.Server.ClientName)
|
||||||
oauthClientApp, err := oauth.NewClientApp(baseURL, oauthStore, desiredScopes, cfg.Server.OAuthKeyPath, cfg.Server.ClientName)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create OAuth client app: %w", err)
|
return fmt.Errorf("failed to create OAuth app: %w", err)
|
||||||
}
|
}
|
||||||
if testMode {
|
if testMode {
|
||||||
slog.Info("Using OAuth scopes with transition:generic (test mode)")
|
slog.Info("Using OAuth scopes with transition:generic (test mode)")
|
||||||
@@ -133,6 +132,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
// Invalidate sessions with mismatched scopes on startup
|
// Invalidate sessions with mismatched scopes on startup
|
||||||
// This ensures all users have the latest required scopes after deployment
|
// This ensures all users have the latest required scopes after deployment
|
||||||
|
desiredScopes := oauth.GetDefaultScopes(defaultHoldDID)
|
||||||
invalidatedCount, err := oauthStore.InvalidateSessionsWithMismatchedScopes(context.Background(), desiredScopes)
|
invalidatedCount, err := oauthStore.InvalidateSessionsWithMismatchedScopes(context.Background(), desiredScopes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("Failed to invalidate sessions with mismatched scopes", "error", err)
|
slog.Warn("Failed to invalidate sessions with mismatched scopes", "error", err)
|
||||||
@@ -141,7 +141,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create oauth token refresher
|
// Create oauth token refresher
|
||||||
refresher := oauth.NewRefresher(oauthClientApp)
|
refresher := oauth.NewRefresher(oauthApp)
|
||||||
|
|
||||||
// Wire up UI session store to refresher so it can invalidate UI sessions on OAuth failures
|
// Wire up UI session store to refresher so it can invalidate UI sessions on OAuth failures
|
||||||
if uiSessionStore != nil {
|
if uiSessionStore != nil {
|
||||||
@@ -189,7 +189,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
|||||||
Database: uiDatabase,
|
Database: uiDatabase,
|
||||||
ReadOnlyDB: uiReadOnlyDB,
|
ReadOnlyDB: uiReadOnlyDB,
|
||||||
SessionStore: uiSessionStore,
|
SessionStore: uiSessionStore,
|
||||||
OAuthClientApp: oauthClientApp,
|
OAuthApp: oauthApp,
|
||||||
OAuthStore: oauthStore,
|
OAuthStore: oauthStore,
|
||||||
Refresher: refresher,
|
Refresher: refresher,
|
||||||
BaseURL: baseURL,
|
BaseURL: baseURL,
|
||||||
@@ -202,7 +202,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create OAuth server
|
// Create OAuth server
|
||||||
oauthServer := oauth.NewServer(oauthClientApp)
|
oauthServer := oauth.NewServer(oauthApp)
|
||||||
// Connect server to refresher for cache invalidation
|
// Connect server to refresher for cache invalidation
|
||||||
oauthServer.SetRefresher(refresher)
|
oauthServer.SetRefresher(refresher)
|
||||||
// Connect UI session store for web login
|
// Connect UI session store for web login
|
||||||
@@ -223,7 +223,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resume OAuth session to get authenticated client
|
// Resume OAuth session to get authenticated client
|
||||||
session, err := oauthClientApp.ResumeSession(ctx, didParsed, sessionID)
|
session, err := oauthApp.ResumeSession(ctx, didParsed, sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("Failed to resume session", "component", "appview/callback", "did", did, "error", err)
|
slog.Warn("Failed to resume session", "component", "appview/callback", "did", did, "error", err)
|
||||||
// Fallback: update user without avatar
|
// Fallback: update user without avatar
|
||||||
@@ -320,12 +320,10 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
// Register crew regardless of migration (outside the migration block)
|
// Register crew regardless of migration (outside the migration block)
|
||||||
// Run in background to avoid blocking OAuth callback if hold is offline
|
// Run in background to avoid blocking OAuth callback if hold is offline
|
||||||
// Use background context - don't inherit request context which gets canceled on response
|
|
||||||
slog.Debug("Attempting crew registration", "component", "appview/callback", "did", did, "hold_did", holdDID)
|
slog.Debug("Attempting crew registration", "component", "appview/callback", "did", did, "hold_did", holdDID)
|
||||||
go func(client *atproto.Client, refresher *oauth.Refresher, holdDID string) {
|
go func(ctx context.Context, client *atproto.Client, refresher *oauth.Refresher, holdDID string) {
|
||||||
ctx := context.Background()
|
|
||||||
storage.EnsureCrewMembership(ctx, client, refresher, holdDID)
|
storage.EnsureCrewMembership(ctx, client, refresher, holdDID)
|
||||||
}(client, refresher, holdDID)
|
}(ctx, client, refresher, holdDID)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -385,7 +383,7 @@ func serveRegistry(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
// OAuth client metadata endpoint
|
// OAuth client metadata endpoint
|
||||||
mainRouter.Get("/client-metadata.json", func(w http.ResponseWriter, r *http.Request) {
|
mainRouter.Get("/client-metadata.json", func(w http.ResponseWriter, r *http.Request) {
|
||||||
config := oauthClientApp.Config
|
config := oauthApp.GetConfig()
|
||||||
metadata := config.ClientMetadata()
|
metadata := config.ClientMetadata()
|
||||||
|
|
||||||
// For confidential clients, ensure JWKS is included
|
// For confidential clients, ensure JWKS is included
|
||||||
|
|||||||
@@ -211,7 +211,7 @@ These components are essential to registry operation and still need coverage.
|
|||||||
|
|
||||||
OAuth implementation has test files but many functions remain untested.
|
OAuth implementation has test files but many functions remain untested.
|
||||||
|
|
||||||
#### client.go - Session Management (Refresher) (Partial coverage)
|
#### refresher.go (Partial coverage)
|
||||||
|
|
||||||
**Well-covered:**
|
**Well-covered:**
|
||||||
- `NewRefresher()` - 100% ✅
|
- `NewRefresher()` - 100% ✅
|
||||||
@@ -227,8 +227,6 @@ OAuth implementation has test files but many functions remain untested.
|
|||||||
- Session retrieval and caching
|
- Session retrieval and caching
|
||||||
- Token refresh flow
|
- Token refresh flow
|
||||||
- Concurrent refresh handling (per-DID locking)
|
- Concurrent refresh handling (per-DID locking)
|
||||||
|
|
||||||
**Note:** Refresher functionality merged into client.go (previously separate refresher.go file)
|
|
||||||
- Cache expiration
|
- Cache expiration
|
||||||
- Error handling for failed refreshes
|
- Error handling for failed refreshes
|
||||||
|
|
||||||
@@ -511,9 +509,8 @@ UI initialization and setup. Low priority.
|
|||||||
**In Progress:**
|
**In Progress:**
|
||||||
9. 🔴 `pkg/appview/db/*` - Database layer (41.2%, needs improvement)
|
9. 🔴 `pkg/appview/db/*` - Database layer (41.2%, needs improvement)
|
||||||
- queries.go, session_store.go, device_store.go
|
- queries.go, session_store.go, device_store.go
|
||||||
10. 🔴 `pkg/auth/oauth/client.go` - Session management (Refresher) (Partial → 70%+)
|
10. 🔴 `pkg/auth/oauth/refresher.go` - Token refresh (Partial → 70%+)
|
||||||
- `GetSession()`, `resumeSession()` (currently 0%)
|
- `GetSession()`, `resumeSession()` (currently 0%)
|
||||||
- Note: Refresher merged into client.go
|
|
||||||
11. 🔴 `pkg/auth/oauth/server.go` - OAuth endpoints (50.7%, continue improvements)
|
11. 🔴 `pkg/auth/oauth/server.go` - OAuth endpoints (50.7%, continue improvements)
|
||||||
- `ServeCallback()` at 16.3% needs major improvement
|
- `ServeCallback()` at 16.3% needs major improvement
|
||||||
12. 🔴 `pkg/appview/storage/crew.go` - Crew validation (11.1% → 80%+)
|
12. 🔴 `pkg/appview/storage/crew.go` - Crew validation (11.1% → 80%+)
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -4,7 +4,7 @@ go 1.24.7
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/aws/aws-sdk-go v1.55.5
|
github.com/aws/aws-sdk-go v1.55.5
|
||||||
github.com/bluesky-social/indigo v0.0.0-20251031012455-0b4bd2478a61
|
github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb
|
||||||
github.com/distribution/distribution/v3 v3.0.0
|
github.com/distribution/distribution/v3 v3.0.0
|
||||||
github.com/distribution/reference v0.6.0
|
github.com/distribution/reference v0.6.0
|
||||||
github.com/earthboundkid/versioninfo/v2 v2.24.1
|
github.com/earthboundkid/versioninfo/v2 v2.24.1
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -20,8 +20,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
|||||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
|
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
|
||||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
|
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
|
||||||
github.com/bluesky-social/indigo v0.0.0-20251031012455-0b4bd2478a61 h1:lU2NnyuvevVWtE35sb4xWBp1AQxa1Sv4XhexiWlrWng=
|
github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb h1:zzyqB1W/itfdIA5cnOZ7IFCJ6QtqwOsXltmLunL4sHw=
|
||||||
github.com/bluesky-social/indigo v0.0.0-20251031012455-0b4bd2478a61/go.mod h1:GuGAU33qKulpZCZNPcUeIQ4RW6KzNvOy7s8MSUXbAng=
|
github.com/bluesky-social/indigo v0.0.0-20251021193747-543ab1124beb/go.mod h1:GuGAU33qKulpZCZNPcUeIQ4RW6KzNvOy7s8MSUXbAng=
|
||||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
||||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
||||||
github.com/bshuster-repo/logrus-logstash-hook v1.0.0 h1:e+C0SB5R1pu//O4MQ3f9cFuPGoOVeF2fE4Og9otCc70=
|
github.com/bshuster-repo/logrus-logstash-hook v1.0.0 h1:e+C0SB5R1pu//O4MQ3f9cFuPGoOVeF2fE4Og9otCc70=
|
||||||
|
|||||||
@@ -112,25 +112,6 @@ func (s *OAuthStore) DeleteSessionsForDID(ctx context.Context, did string) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteOldSessionsForDID removes all sessions for a DID except the specified session to keep
|
|
||||||
// This is used during OAuth callback to clean up stale sessions with expired refresh tokens
|
|
||||||
func (s *OAuthStore) DeleteOldSessionsForDID(ctx context.Context, did string, keepSessionID string) error {
|
|
||||||
result, err := s.db.ExecContext(ctx, `
|
|
||||||
DELETE FROM oauth_sessions WHERE account_did = ? AND session_id != ?
|
|
||||||
`, did, keepSessionID)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete old sessions for DID: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
deleted, _ := result.RowsAffected()
|
|
||||||
if deleted > 0 {
|
|
||||||
slog.Info("Deleted old OAuth sessions for DID", "count", deleted, "did", did, "kept", keepSessionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAuthRequestInfo retrieves authentication request data by state
|
// GetAuthRequestInfo retrieves authentication request data by state
|
||||||
func (s *OAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
|
func (s *OAuthStore) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
|
||||||
var requestDataJSON string
|
var requestDataJSON string
|
||||||
|
|||||||
@@ -724,30 +724,6 @@ func GetNewestManifestForRepo(db *sql.DB, did, repository string) (*Manifest, er
|
|||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestHoldDIDForRepo returns the hold DID from the most recent manifest for a repository
|
|
||||||
// Returns empty string if no manifests exist (e.g., first push)
|
|
||||||
// This is used instead of the in-memory cache to determine which hold to use for blob operations
|
|
||||||
func GetLatestHoldDIDForRepo(db *sql.DB, did, repository string) (string, error) {
|
|
||||||
var holdDID string
|
|
||||||
err := db.QueryRow(`
|
|
||||||
SELECT hold_endpoint
|
|
||||||
FROM manifests
|
|
||||||
WHERE did = ? AND repository = ?
|
|
||||||
ORDER BY created_at DESC
|
|
||||||
LIMIT 1
|
|
||||||
`, did, repository).Scan(&holdDID)
|
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
// No manifests yet - return empty string (first push case)
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return holdDID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRepositoriesForDID returns all unique repository names for a DID
|
// GetRepositoriesForDID returns all unique repository names for a DID
|
||||||
// Used by backfill to reconcile annotations for all repositories
|
// Used by backfill to reconcile annotations for all repositories
|
||||||
func GetRepositoriesForDID(db *sql.DB, did string) ([]string, error) {
|
func GetRepositoriesForDID(db *sql.DB, did string) ([]string, error) {
|
||||||
@@ -1600,11 +1576,6 @@ func (m *MetricsDB) IncrementPushCount(did, repository string) error {
|
|||||||
return IncrementPushCount(m.db, did, repository)
|
return IncrementPushCount(m.db, did, repository)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestHoldDIDForRepo returns the hold DID from the most recent manifest for a repository
|
|
||||||
func (m *MetricsDB) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
|
|
||||||
return GetLatestHoldDIDForRepo(m.db, did, repository)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFeaturedRepositories fetches top repositories sorted by stars and pulls
|
// GetFeaturedRepositories fetches top repositories sorted by stars and pulls
|
||||||
func GetFeaturedRepositories(db *sql.DB, limit int, currentUserDID string) ([]FeaturedRepository, error) {
|
func GetFeaturedRepositories(db *sql.DB, limit int, currentUserDID string) ([]FeaturedRepository, error) {
|
||||||
query := `
|
query := `
|
||||||
|
|||||||
@@ -6,16 +6,15 @@ import (
|
|||||||
|
|
||||||
"atcr.io/pkg/appview/db"
|
"atcr.io/pkg/appview/db"
|
||||||
"atcr.io/pkg/auth/oauth"
|
"atcr.io/pkg/auth/oauth"
|
||||||
indigooauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
|
|
||||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LogoutHandler handles user logout with proper OAuth token revocation
|
// LogoutHandler handles user logout with proper OAuth token revocation
|
||||||
type LogoutHandler struct {
|
type LogoutHandler struct {
|
||||||
OAuthClientApp *indigooauth.ClientApp
|
OAuthApp *oauth.App
|
||||||
Refresher *oauth.Refresher
|
Refresher *oauth.Refresher
|
||||||
SessionStore *db.SessionStore
|
SessionStore *db.SessionStore
|
||||||
OAuthStore *db.OAuthStore
|
OAuthStore *db.OAuthStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -38,13 +37,17 @@ func (h *LogoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Attempt to revoke OAuth tokens on PDS side
|
// Attempt to revoke OAuth tokens on PDS side
|
||||||
if uiSession.OAuthSessionID != "" {
|
if uiSession.OAuthSessionID != "" {
|
||||||
// Call indigo's Logout to revoke tokens on PDS
|
// Call indigo's Logout to revoke tokens on PDS
|
||||||
if err := h.OAuthClientApp.Logout(r.Context(), did, uiSession.OAuthSessionID); err != nil {
|
if err := h.OAuthApp.GetClientApp().Logout(r.Context(), did, uiSession.OAuthSessionID); err != nil {
|
||||||
// Log error but don't block logout - best effort revocation
|
// Log error but don't block logout - best effort revocation
|
||||||
slog.Warn("Failed to revoke OAuth tokens on PDS", "component", "logout", "did", uiSession.DID, "error", err)
|
slog.Warn("Failed to revoke OAuth tokens on PDS", "component", "logout", "did", uiSession.DID, "error", err)
|
||||||
} else {
|
} else {
|
||||||
slog.Info("Successfully revoked OAuth tokens on PDS", "component", "logout", "did", uiSession.DID)
|
slog.Info("Successfully revoked OAuth tokens on PDS", "component", "logout", "did", uiSession.DID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Invalidate refresher cache to clear local access tokens
|
||||||
|
h.Refresher.InvalidateSession(uiSession.DID)
|
||||||
|
slog.Info("Invalidated local OAuth cache", "component", "logout", "did", uiSession.DID)
|
||||||
|
|
||||||
// Delete OAuth session from database (cleanup, might already be done by Logout)
|
// Delete OAuth session from database (cleanup, might already be done by Logout)
|
||||||
if err := h.OAuthStore.DeleteSession(r.Context(), did, uiSession.OAuthSessionID); err != nil {
|
if err := h.OAuthStore.DeleteSession(r.Context(), did, uiSession.OAuthSessionID); err != nil {
|
||||||
slog.Warn("Failed to delete OAuth session from database", "component", "logout", "error", err)
|
slog.Warn("Failed to delete OAuth session from database", "component", "logout", "error", err)
|
||||||
|
|||||||
@@ -107,15 +107,6 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
|
|||||||
// Detect manifest type
|
// Detect manifest type
|
||||||
isManifestList := len(manifestRecord.Manifests) > 0
|
isManifestList := len(manifestRecord.Manifests) > 0
|
||||||
|
|
||||||
// Extract hold DID from manifest (with fallback for legacy manifests)
|
|
||||||
// New manifests use holdDid field (DID format)
|
|
||||||
// Old manifests use holdEndpoint field (URL format) - convert to DID
|
|
||||||
holdDID := manifestRecord.HoldDID
|
|
||||||
if holdDID == "" && manifestRecord.HoldEndpoint != "" {
|
|
||||||
// Legacy manifest - convert URL to DID
|
|
||||||
holdDID = atproto.ResolveHoldDIDFromURL(manifestRecord.HoldEndpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare manifest for insertion (WITHOUT annotation fields)
|
// Prepare manifest for insertion (WITHOUT annotation fields)
|
||||||
manifest := &db.Manifest{
|
manifest := &db.Manifest{
|
||||||
DID: did,
|
DID: did,
|
||||||
@@ -123,7 +114,7 @@ func (p *Processor) ProcessManifest(ctx context.Context, did string, recordData
|
|||||||
Digest: manifestRecord.Digest,
|
Digest: manifestRecord.Digest,
|
||||||
MediaType: manifestRecord.MediaType,
|
MediaType: manifestRecord.MediaType,
|
||||||
SchemaVersion: manifestRecord.SchemaVersion,
|
SchemaVersion: manifestRecord.SchemaVersion,
|
||||||
HoldEndpoint: holdDID,
|
HoldEndpoint: manifestRecord.HoldEndpoint,
|
||||||
CreatedAt: manifestRecord.CreatedAt,
|
CreatedAt: manifestRecord.CreatedAt,
|
||||||
// Annotations removed - stored separately in repository_annotations table
|
// Annotations removed - stored separately in repository_annotations table
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/distribution/distribution/v3"
|
"github.com/distribution/distribution/v3"
|
||||||
"github.com/distribution/distribution/v3/registry/api/errcode"
|
"github.com/distribution/distribution/v3/registry/api/errcode"
|
||||||
@@ -68,6 +69,7 @@ type NamespaceResolver struct {
|
|||||||
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
|
defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io")
|
||||||
baseURL string // Base URL for error messages (e.g., "https://atcr.io")
|
baseURL string // Base URL for error messages (e.g., "https://atcr.io")
|
||||||
testMode bool // If true, fallback to default hold when user's hold is unreachable
|
testMode bool // If true, fallback to default hold when user's hold is unreachable
|
||||||
|
repositories sync.Map // Cache of RoutingRepository instances by key (did:reponame)
|
||||||
refresher *oauth.Refresher // OAuth session manager (copied from global on init)
|
refresher *oauth.Refresher // OAuth session manager (copied from global on init)
|
||||||
database storage.DatabaseMetrics // Metrics database (copied from global on init)
|
database storage.DatabaseMetrics // Metrics database (copied from global on init)
|
||||||
authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init)
|
authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init)
|
||||||
@@ -222,15 +224,20 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
|
|||||||
// Example: "evan.jarrett.net/debian" -> store as "debian"
|
// Example: "evan.jarrett.net/debian" -> store as "debian"
|
||||||
repositoryName := imageName
|
repositoryName := imageName
|
||||||
|
|
||||||
|
// Cache key is DID + repository name
|
||||||
|
cacheKey := did + ":" + repositoryName
|
||||||
|
|
||||||
|
// Check cache first and update service token
|
||||||
|
if cached, ok := nr.repositories.Load(cacheKey); ok {
|
||||||
|
cachedRepo := cached.(*storage.RoutingRepository)
|
||||||
|
// Always update the service token even for cached repos (token may have been renewed)
|
||||||
|
cachedRepo.Ctx.ServiceToken = serviceToken
|
||||||
|
return cachedRepo, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Create routing repository - routes manifests to ATProto, blobs to hold service
|
// Create routing repository - routes manifests to ATProto, blobs to hold service
|
||||||
// The registry is stateless - no local storage is used
|
// The registry is stateless - no local storage is used
|
||||||
// Bundle all context into a single RegistryContext struct
|
// Bundle all context into a single RegistryContext struct
|
||||||
//
|
|
||||||
// NOTE: We create a fresh RoutingRepository on every request (no caching) because:
|
|
||||||
// 1. Each layer upload is a separate HTTP request (possibly different process)
|
|
||||||
// 2. OAuth sessions can be refreshed/invalidated between requests
|
|
||||||
// 3. The refresher already caches sessions efficiently (in-memory + DB)
|
|
||||||
// 4. Caching the repository with a stale ATProtoClient causes refresh token errors
|
|
||||||
registryCtx := &storage.RegistryContext{
|
registryCtx := &storage.RegistryContext{
|
||||||
DID: did,
|
DID: did,
|
||||||
Handle: handle,
|
Handle: handle,
|
||||||
@@ -244,8 +251,12 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
|
|||||||
Refresher: nr.refresher,
|
Refresher: nr.refresher,
|
||||||
ReadmeCache: nr.readmeCache,
|
ReadmeCache: nr.readmeCache,
|
||||||
}
|
}
|
||||||
|
routingRepo := storage.NewRoutingRepository(repo, registryCtx)
|
||||||
|
|
||||||
return storage.NewRoutingRepository(repo, registryCtx), nil
|
// Cache the repository
|
||||||
|
nr.repositories.Store(cacheKey, routingRepo)
|
||||||
|
|
||||||
|
return routingRepo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Repositories delegates to underlying namespace
|
// Repositories delegates to underlying namespace
|
||||||
|
|||||||
@@ -13,22 +13,21 @@ import (
|
|||||||
"atcr.io/pkg/appview/readme"
|
"atcr.io/pkg/appview/readme"
|
||||||
"atcr.io/pkg/auth/oauth"
|
"atcr.io/pkg/auth/oauth"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
indigooauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// UIDependencies contains all dependencies needed for UI route registration
|
// UIDependencies contains all dependencies needed for UI route registration
|
||||||
type UIDependencies struct {
|
type UIDependencies struct {
|
||||||
Database *sql.DB
|
Database *sql.DB
|
||||||
ReadOnlyDB *sql.DB
|
ReadOnlyDB *sql.DB
|
||||||
SessionStore *db.SessionStore
|
SessionStore *db.SessionStore
|
||||||
OAuthClientApp *indigooauth.ClientApp
|
OAuthApp *oauth.App
|
||||||
OAuthStore *db.OAuthStore
|
OAuthStore *db.OAuthStore
|
||||||
Refresher *oauth.Refresher
|
Refresher *oauth.Refresher
|
||||||
BaseURL string
|
BaseURL string
|
||||||
DeviceStore *db.DeviceStore
|
DeviceStore *db.DeviceStore
|
||||||
HealthChecker *holdhealth.Checker
|
HealthChecker *holdhealth.Checker
|
||||||
ReadmeCache *readme.Cache
|
ReadmeCache *readme.Cache
|
||||||
Templates *template.Template
|
Templates *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterUIRoutes registers all web UI and API routes on the provided router
|
// RegisterUIRoutes registers all web UI and API routes on the provided router
|
||||||
@@ -91,7 +90,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
|||||||
router.Get("/api/stats/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
router.Get("/api/stats/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||||
&uihandlers.GetStatsHandler{
|
&uihandlers.GetStatsHandler{
|
||||||
DB: deps.ReadOnlyDB,
|
DB: deps.ReadOnlyDB,
|
||||||
Directory: deps.OAuthClientApp.Dir,
|
Directory: deps.OAuthApp.Directory(),
|
||||||
},
|
},
|
||||||
).ServeHTTP)
|
).ServeHTTP)
|
||||||
|
|
||||||
@@ -99,7 +98,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
|||||||
router.Post("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
|
router.Post("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
|
||||||
&uihandlers.StarRepositoryHandler{
|
&uihandlers.StarRepositoryHandler{
|
||||||
DB: deps.Database, // Needs write access
|
DB: deps.Database, // Needs write access
|
||||||
Directory: deps.OAuthClientApp.Dir,
|
Directory: deps.OAuthApp.Directory(),
|
||||||
Refresher: deps.Refresher,
|
Refresher: deps.Refresher,
|
||||||
},
|
},
|
||||||
).ServeHTTP)
|
).ServeHTTP)
|
||||||
@@ -107,7 +106,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
|||||||
router.Delete("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
|
router.Delete("/api/stars/{handle}/{repository}", middleware.RequireAuth(deps.SessionStore, deps.Database)(
|
||||||
&uihandlers.UnstarRepositoryHandler{
|
&uihandlers.UnstarRepositoryHandler{
|
||||||
DB: deps.Database, // Needs write access
|
DB: deps.Database, // Needs write access
|
||||||
Directory: deps.OAuthClientApp.Dir,
|
Directory: deps.OAuthApp.Directory(),
|
||||||
Refresher: deps.Refresher,
|
Refresher: deps.Refresher,
|
||||||
},
|
},
|
||||||
).ServeHTTP)
|
).ServeHTTP)
|
||||||
@@ -115,7 +114,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
|||||||
router.Get("/api/stars/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
router.Get("/api/stars/{handle}/{repository}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||||
&uihandlers.CheckStarHandler{
|
&uihandlers.CheckStarHandler{
|
||||||
DB: deps.ReadOnlyDB, // Read-only check
|
DB: deps.ReadOnlyDB, // Read-only check
|
||||||
Directory: deps.OAuthClientApp.Dir,
|
Directory: deps.OAuthApp.Directory(),
|
||||||
Refresher: deps.Refresher,
|
Refresher: deps.Refresher,
|
||||||
},
|
},
|
||||||
).ServeHTTP)
|
).ServeHTTP)
|
||||||
@@ -124,7 +123,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
|||||||
router.Get("/api/manifests/{handle}/{repository}/{digest}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
router.Get("/api/manifests/{handle}/{repository}/{digest}", middleware.OptionalAuth(deps.SessionStore, deps.Database)(
|
||||||
&uihandlers.ManifestDetailHandler{
|
&uihandlers.ManifestDetailHandler{
|
||||||
DB: deps.ReadOnlyDB,
|
DB: deps.ReadOnlyDB,
|
||||||
Directory: deps.OAuthClientApp.Dir,
|
Directory: deps.OAuthApp.Directory(),
|
||||||
},
|
},
|
||||||
).ServeHTTP)
|
).ServeHTTP)
|
||||||
|
|
||||||
@@ -146,7 +145,7 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
|||||||
DB: deps.ReadOnlyDB,
|
DB: deps.ReadOnlyDB,
|
||||||
Templates: deps.Templates,
|
Templates: deps.Templates,
|
||||||
RegistryURL: registryURL,
|
RegistryURL: registryURL,
|
||||||
Directory: deps.OAuthClientApp.Dir,
|
Directory: deps.OAuthApp.Directory(),
|
||||||
Refresher: deps.Refresher,
|
Refresher: deps.Refresher,
|
||||||
HealthChecker: deps.HealthChecker,
|
HealthChecker: deps.HealthChecker,
|
||||||
ReadmeCache: deps.ReadmeCache,
|
ReadmeCache: deps.ReadmeCache,
|
||||||
@@ -203,10 +202,10 @@ func RegisterUIRoutes(router chi.Router, deps UIDependencies) {
|
|||||||
// Logout endpoint (supports both GET and POST)
|
// Logout endpoint (supports both GET and POST)
|
||||||
// Properly revokes OAuth tokens on PDS side before clearing local session
|
// Properly revokes OAuth tokens on PDS side before clearing local session
|
||||||
logoutHandler := &uihandlers.LogoutHandler{
|
logoutHandler := &uihandlers.LogoutHandler{
|
||||||
OAuthClientApp: deps.OAuthClientApp,
|
OAuthApp: deps.OAuthApp,
|
||||||
Refresher: deps.Refresher,
|
Refresher: deps.Refresher,
|
||||||
SessionStore: deps.SessionStore,
|
SessionStore: deps.SessionStore,
|
||||||
OAuthStore: deps.OAuthStore,
|
OAuthStore: deps.OAuthStore,
|
||||||
}
|
}
|
||||||
router.Get("/auth/logout", logoutHandler.ServeHTTP)
|
router.Get("/auth/logout", logoutHandler.ServeHTTP)
|
||||||
router.Post("/auth/logout", logoutHandler.ServeHTTP)
|
router.Post("/auth/logout", logoutHandler.ServeHTTP)
|
||||||
|
|||||||
@@ -8,11 +8,10 @@ import (
|
|||||||
"atcr.io/pkg/auth/oauth"
|
"atcr.io/pkg/auth/oauth"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DatabaseMetrics interface for tracking pull/push counts and querying hold DIDs
|
// DatabaseMetrics interface for tracking pull/push counts
|
||||||
type DatabaseMetrics interface {
|
type DatabaseMetrics interface {
|
||||||
IncrementPullCount(did, repository string) error
|
IncrementPullCount(did, repository string) error
|
||||||
IncrementPushCount(did, repository string) error
|
IncrementPushCount(did, repository string) error
|
||||||
GetLatestHoldDIDForRepo(did, repository string) (string, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadmeCache interface for README content caching
|
// ReadmeCache interface for README content caching
|
||||||
|
|||||||
@@ -29,11 +29,6 @@ func (m *mockDatabaseMetrics) IncrementPushCount(did, repository string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockDatabaseMetrics) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
|
|
||||||
// Return empty string for mock - tests can override if needed
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockDatabaseMetrics) getPullCount() int {
|
func (m *mockDatabaseMetrics) getPullCount() int {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|||||||
98
pkg/appview/storage/hold_cache.go
Normal file
98
pkg/appview/storage/hold_cache.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HoldCache caches hold DIDs for (DID, repository) pairs
|
||||||
|
// This avoids expensive ATProto lookups on every blob request during pulls
|
||||||
|
//
|
||||||
|
// NOTE: This is a simple in-memory cache for MVP. For production deployments:
|
||||||
|
// - Use Redis or similar for distributed caching
|
||||||
|
// - Consider implementing cache size limits
|
||||||
|
// - Monitor memory usage under high load
|
||||||
|
type HoldCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
cache map[string]*holdCacheEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type holdCacheEntry struct {
|
||||||
|
holdDID string
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var globalHoldCache = &HoldCache{
|
||||||
|
cache: make(map[string]*holdCacheEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Start background cleanup goroutine
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
globalHoldCache.Cleanup()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGlobalHoldCache returns the global hold cache instance
|
||||||
|
func GetGlobalHoldCache() *HoldCache {
|
||||||
|
return globalHoldCache
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set stores a hold DID for a (DID, repository) pair with a TTL
|
||||||
|
func (c *HoldCache) Set(did, repository, holdDID string, ttl time.Duration) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
key := did + ":" + repository
|
||||||
|
c.cache[key] = &holdCacheEntry{
|
||||||
|
holdDID: holdDID,
|
||||||
|
expiresAt: time.Now().Add(ttl),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a hold DID for a (DID, repository) pair
|
||||||
|
// Returns empty string and false if not found or expired
|
||||||
|
func (c *HoldCache) Get(did, repository string) (string, bool) {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
key := did + ":" + repository
|
||||||
|
entry, ok := c.cache[key]
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if expired
|
||||||
|
if time.Now().After(entry.expiresAt) {
|
||||||
|
// Don't delete here (would need write lock), let cleanup handle it
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.holdDID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup removes expired entries (called automatically every 5 minutes)
|
||||||
|
func (c *HoldCache) Cleanup() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
removed := 0
|
||||||
|
for key, entry := range c.cache {
|
||||||
|
if now.After(entry.expiresAt) {
|
||||||
|
delete(c.cache, key)
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log cleanup stats for monitoring
|
||||||
|
if removed > 0 || len(c.cache) > 100 {
|
||||||
|
// Log if we removed entries OR if cache is growing large
|
||||||
|
// This helps identify if cache size is becoming a concern
|
||||||
|
println("Hold cache cleanup: removed", removed, "entries, remaining", len(c.cache))
|
||||||
|
}
|
||||||
|
}
|
||||||
150
pkg/appview/storage/hold_cache_test.go
Normal file
150
pkg/appview/storage/hold_cache_test.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHoldCache_SetAndGet(t *testing.T) {
|
||||||
|
cache := &HoldCache{
|
||||||
|
cache: make(map[string]*holdCacheEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
did := "did:plc:test123"
|
||||||
|
repo := "myapp"
|
||||||
|
holdDID := "did:web:hold01.atcr.io"
|
||||||
|
ttl := 10 * time.Minute
|
||||||
|
|
||||||
|
// Set a value
|
||||||
|
cache.Set(did, repo, holdDID, ttl)
|
||||||
|
|
||||||
|
// Get the value - should succeed
|
||||||
|
gotHoldDID, ok := cache.Get(did, repo)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected Get to return true, got false")
|
||||||
|
}
|
||||||
|
if gotHoldDID != holdDID {
|
||||||
|
t.Errorf("Expected hold DID %q, got %q", holdDID, gotHoldDID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHoldCache_GetNonExistent(t *testing.T) {
|
||||||
|
cache := &HoldCache{
|
||||||
|
cache: make(map[string]*holdCacheEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get non-existent value
|
||||||
|
_, ok := cache.Get("did:plc:nonexistent", "repo")
|
||||||
|
if ok {
|
||||||
|
t.Error("Expected Get to return false for non-existent key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHoldCache_ExpiredEntry(t *testing.T) {
|
||||||
|
cache := &HoldCache{
|
||||||
|
cache: make(map[string]*holdCacheEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
did := "did:plc:test123"
|
||||||
|
repo := "myapp"
|
||||||
|
holdDID := "did:web:hold01.atcr.io"
|
||||||
|
|
||||||
|
// Set with very short TTL
|
||||||
|
cache.Set(did, repo, holdDID, 10*time.Millisecond)
|
||||||
|
|
||||||
|
// Wait for expiration
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
|
||||||
|
// Get should return false
|
||||||
|
_, ok := cache.Get(did, repo)
|
||||||
|
if ok {
|
||||||
|
t.Error("Expected Get to return false for expired entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHoldCache_Cleanup(t *testing.T) {
|
||||||
|
cache := &HoldCache{
|
||||||
|
cache: make(map[string]*holdCacheEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add multiple entries with different TTLs
|
||||||
|
cache.Set("did:plc:1", "repo1", "hold1", 10*time.Millisecond)
|
||||||
|
cache.Set("did:plc:2", "repo2", "hold2", 1*time.Hour)
|
||||||
|
cache.Set("did:plc:3", "repo3", "hold3", 10*time.Millisecond)
|
||||||
|
|
||||||
|
// Wait for some to expire
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
|
||||||
|
// Run cleanup
|
||||||
|
cache.Cleanup()
|
||||||
|
|
||||||
|
// Verify expired entries are removed
|
||||||
|
if _, ok := cache.Get("did:plc:1", "repo1"); ok {
|
||||||
|
t.Error("Expected expired entry 1 to be removed")
|
||||||
|
}
|
||||||
|
if _, ok := cache.Get("did:plc:3", "repo3"); ok {
|
||||||
|
t.Error("Expected expired entry 3 to be removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify non-expired entry remains
|
||||||
|
if _, ok := cache.Get("did:plc:2", "repo2"); !ok {
|
||||||
|
t.Error("Expected non-expired entry to remain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHoldCache_ConcurrentAccess(t *testing.T) {
|
||||||
|
cache := &HoldCache{
|
||||||
|
cache: make(map[string]*holdCacheEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan bool)
|
||||||
|
|
||||||
|
// Concurrent writes
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
did := "did:plc:concurrent"
|
||||||
|
repo := "repo" + string(rune(id))
|
||||||
|
holdDID := "hold" + string(rune(id))
|
||||||
|
cache.Set(did, repo, holdDID, 1*time.Minute)
|
||||||
|
done <- true
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concurrent reads
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
repo := "repo" + string(rune(id))
|
||||||
|
cache.Get("did:plc:concurrent", repo)
|
||||||
|
done <- true
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all goroutines
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHoldCache_KeyFormat(t *testing.T) {
|
||||||
|
cache := &HoldCache{
|
||||||
|
cache: make(map[string]*holdCacheEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
did := "did:plc:test"
|
||||||
|
repo := "myrepo"
|
||||||
|
holdDID := "did:web:hold"
|
||||||
|
|
||||||
|
cache.Set(did, repo, holdDID, 1*time.Minute)
|
||||||
|
|
||||||
|
// Verify the key is stored correctly (did:repo)
|
||||||
|
expectedKey := did + ":" + repo
|
||||||
|
if _, exists := cache.cache[expectedKey]; !exists {
|
||||||
|
t.Errorf("Expected key %q to exist in cache", expectedKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Add more comprehensive tests:
|
||||||
|
// - Test GetGlobalHoldCache()
|
||||||
|
// - Test cache size monitoring
|
||||||
|
// - Benchmark cache performance under load
|
||||||
|
// - Test cleanup goroutine timing
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
// Package storage implements the storage routing layer for AppView.
|
// Package storage implements the storage routing layer for AppView.
|
||||||
// It routes manifests to ATProto PDS (as io.atcr.manifest records) and
|
// It routes manifests to ATProto PDS (as io.atcr.manifest records) and
|
||||||
// blobs to hold services via XRPC, with database-based hold DID lookups.
|
// blobs to hold services via XRPC, with hold DID caching for efficient pulls.
|
||||||
// All storage operations are proxied - AppView stores nothing locally.
|
// All storage operations are proxied - AppView stores nothing locally.
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/distribution/distribution/v3"
|
"github.com/distribution/distribution/v3"
|
||||||
)
|
)
|
||||||
@@ -49,6 +50,17 @@ func (r *RoutingRepository) Manifests(ctx context.Context, options ...distributi
|
|||||||
manifestStore := r.manifestStore
|
manifestStore := r.manifestStore
|
||||||
r.mu.Unlock()
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
// After any manifest operation, cache the hold DID for blob fetches
|
||||||
|
// We use a goroutine to avoid blocking, and check after a short delay to allow the operation to complete
|
||||||
|
go func() {
|
||||||
|
time.Sleep(100 * time.Millisecond) // Brief delay to let manifest fetch complete
|
||||||
|
if holdDID := manifestStore.GetLastFetchedHoldDID(); holdDID != "" {
|
||||||
|
// Cache for 10 minutes - should cover typical pull operations
|
||||||
|
GetGlobalHoldCache().Set(r.Ctx.DID, r.Ctx.Repository, holdDID, 10*time.Minute)
|
||||||
|
slog.Debug("Cached hold DID", "component", "storage/routing", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", holdDID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return manifestStore, nil
|
return manifestStore, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,23 +76,17 @@ func (r *RoutingRepository) Blobs(ctx context.Context) distribution.BlobStore {
|
|||||||
return blobStore
|
return blobStore
|
||||||
}
|
}
|
||||||
|
|
||||||
// For pull operations, check database for hold DID from the most recent manifest
|
// For pull operations, check if we have a cached hold DID from a recent manifest fetch
|
||||||
// This ensures blobs are fetched from the hold recorded in the manifest, not re-discovered
|
// This ensures blobs are fetched from the hold recorded in the manifest, not re-discovered
|
||||||
holdDID := r.Ctx.HoldDID // Default to discovery-based DID
|
holdDID := r.Ctx.HoldDID // Default to discovery-based DID
|
||||||
holdSource := "discovery"
|
|
||||||
|
|
||||||
if r.Ctx.Database != nil {
|
if cachedHoldDID, ok := GetGlobalHoldCache().Get(r.Ctx.DID, r.Ctx.Repository); ok {
|
||||||
// Query database for the latest manifest's hold DID
|
// Use cached hold DID from manifest
|
||||||
if dbHoldDID, err := r.Ctx.Database.GetLatestHoldDIDForRepo(r.Ctx.DID, r.Ctx.Repository); err == nil && dbHoldDID != "" {
|
holdDID = cachedHoldDID
|
||||||
// Use hold DID from database (pull case - use historical reference)
|
slog.Debug("Using cached hold from manifest", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", cachedHoldDID)
|
||||||
holdDID = dbHoldDID
|
} else {
|
||||||
holdSource = "database"
|
// No cached hold, use discovery-based DID (for push or first pull)
|
||||||
slog.Debug("Using hold from database manifest", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", dbHoldDID)
|
slog.Debug("Using discovery-based hold", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", holdDID)
|
||||||
} else if err != nil {
|
|
||||||
// Log error but don't fail - fall back to discovery-based DID
|
|
||||||
slog.Warn("Failed to query database for hold DID", "component", "storage/blobs", "error", err)
|
|
||||||
}
|
|
||||||
// If dbHoldDID is empty (no manifests yet), fall through to use discovery-based DID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if holdDID == "" {
|
if holdDID == "" {
|
||||||
@@ -88,9 +94,7 @@ func (r *RoutingRepository) Blobs(ctx context.Context) distribution.BlobStore {
|
|||||||
panic("hold DID not set in RegistryContext - ensure default_hold_did is configured in middleware")
|
panic("hold DID not set in RegistryContext - ensure default_hold_did is configured in middleware")
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("Using hold DID for blobs", "component", "storage/blobs", "did", r.Ctx.DID, "repo", r.Ctx.Repository, "hold", holdDID, "source", holdSource)
|
// Update context with the correct hold DID (may be cached or discovered)
|
||||||
|
|
||||||
// Update context with the correct hold DID (may be from database or discovered)
|
|
||||||
r.Ctx.HoldDID = holdDID
|
r.Ctx.HoldDID = holdDID
|
||||||
|
|
||||||
// Create and cache proxy blob store
|
// Create and cache proxy blob store
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/distribution/distribution/v3"
|
"github.com/distribution/distribution/v3"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -12,27 +13,6 @@ import (
|
|||||||
"atcr.io/pkg/atproto"
|
"atcr.io/pkg/atproto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// mockDatabase is a simple mock for testing
|
|
||||||
type mockDatabase struct {
|
|
||||||
holdDID string
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockDatabase) IncrementPullCount(did, repository string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockDatabase) IncrementPushCount(did, repository string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockDatabase) GetLatestHoldDIDForRepo(did, repository string) (string, error) {
|
|
||||||
if m.err != nil {
|
|
||||||
return "", m.err
|
|
||||||
}
|
|
||||||
return m.holdDID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewRoutingRepository(t *testing.T) {
|
func TestNewRoutingRepository(t *testing.T) {
|
||||||
ctx := &RegistryContext{
|
ctx := &RegistryContext{
|
||||||
DID: "did:plc:test123",
|
DID: "did:plc:test123",
|
||||||
@@ -109,36 +89,38 @@ func TestRoutingRepository_ManifestStoreCaching(t *testing.T) {
|
|||||||
assert.NotNil(t, repo.manifestStore)
|
assert.NotNil(t, repo.manifestStore)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRoutingRepository_Blobs_WithDatabase tests blob store with database hold DID
|
// TestRoutingRepository_Blobs_WithCache tests blob store with cached hold DID
|
||||||
func TestRoutingRepository_Blobs_WithDatabase(t *testing.T) {
|
func TestRoutingRepository_Blobs_WithCache(t *testing.T) {
|
||||||
dbHoldDID := "did:web:database.hold.io"
|
// Pre-populate the hold cache
|
||||||
|
cache := GetGlobalHoldCache()
|
||||||
|
cachedHoldDID := "did:web:cached.hold.io"
|
||||||
|
cache.Set("did:plc:test123", "myapp", cachedHoldDID, 10*time.Minute)
|
||||||
|
|
||||||
ctx := &RegistryContext{
|
ctx := &RegistryContext{
|
||||||
DID: "did:plc:test123",
|
DID: "did:plc:test123",
|
||||||
Repository: "myapp",
|
Repository: "myapp",
|
||||||
HoldDID: "did:web:default.hold.io", // Discovery-based hold (should be overridden)
|
HoldDID: "did:web:default.hold.io", // Discovery-based hold (should be overridden)
|
||||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||||
Database: &mockDatabase{holdDID: dbHoldDID},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
repo := NewRoutingRepository(nil, ctx)
|
repo := NewRoutingRepository(nil, ctx)
|
||||||
blobStore := repo.Blobs(context.Background())
|
blobStore := repo.Blobs(context.Background())
|
||||||
|
|
||||||
assert.NotNil(t, blobStore)
|
assert.NotNil(t, blobStore)
|
||||||
// Verify the hold DID was updated to use the database value
|
// Verify the hold DID was updated to use the cached value
|
||||||
assert.Equal(t, dbHoldDID, repo.Ctx.HoldDID, "should use database hold DID")
|
assert.Equal(t, cachedHoldDID, repo.Ctx.HoldDID, "should use cached hold DID")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRoutingRepository_Blobs_WithoutDatabase tests blob store with discovery-based hold
|
// TestRoutingRepository_Blobs_WithoutCache tests blob store with discovery-based hold
|
||||||
func TestRoutingRepository_Blobs_WithoutDatabase(t *testing.T) {
|
func TestRoutingRepository_Blobs_WithoutCache(t *testing.T) {
|
||||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
discoveryHoldDID := "did:web:discovery.hold.io"
|
||||||
|
|
||||||
|
// Use a different DID/repo to avoid cache contamination from other tests
|
||||||
ctx := &RegistryContext{
|
ctx := &RegistryContext{
|
||||||
DID: "did:plc:nocache456",
|
DID: "did:plc:nocache456",
|
||||||
Repository: "uncached-app",
|
Repository: "uncached-app",
|
||||||
HoldDID: discoveryHoldDID,
|
HoldDID: discoveryHoldDID,
|
||||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:nocache456", ""),
|
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:nocache456", ""),
|
||||||
Database: nil, // No database
|
|
||||||
}
|
}
|
||||||
|
|
||||||
repo := NewRoutingRepository(nil, ctx)
|
repo := NewRoutingRepository(nil, ctx)
|
||||||
@@ -149,26 +131,6 @@ func TestRoutingRepository_Blobs_WithoutDatabase(t *testing.T) {
|
|||||||
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should use discovery-based hold DID")
|
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should use discovery-based hold DID")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRoutingRepository_Blobs_DatabaseEmptyFallback tests fallback when database returns empty hold DID
|
|
||||||
func TestRoutingRepository_Blobs_DatabaseEmptyFallback(t *testing.T) {
|
|
||||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
|
||||||
|
|
||||||
ctx := &RegistryContext{
|
|
||||||
DID: "did:plc:test123",
|
|
||||||
Repository: "newapp",
|
|
||||||
HoldDID: discoveryHoldDID,
|
|
||||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
|
||||||
Database: &mockDatabase{holdDID: ""}, // Empty string (no manifests yet)
|
|
||||||
}
|
|
||||||
|
|
||||||
repo := NewRoutingRepository(nil, ctx)
|
|
||||||
blobStore := repo.Blobs(context.Background())
|
|
||||||
|
|
||||||
assert.NotNil(t, blobStore)
|
|
||||||
// Verify the hold DID falls back to discovery-based
|
|
||||||
assert.Equal(t, discoveryHoldDID, repo.Ctx.HoldDID, "should fall back to discovery-based hold DID when database returns empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRoutingRepository_BlobStoreCaching tests that blob store is cached
|
// TestRoutingRepository_BlobStoreCaching tests that blob store is cached
|
||||||
func TestRoutingRepository_BlobStoreCaching(t *testing.T) {
|
func TestRoutingRepository_BlobStoreCaching(t *testing.T) {
|
||||||
ctx := &RegistryContext{
|
ctx := &RegistryContext{
|
||||||
@@ -292,23 +254,26 @@ func TestRoutingRepository_ConcurrentAccess(t *testing.T) {
|
|||||||
assert.NotNil(t, cachedBlobStore)
|
assert.NotNil(t, cachedBlobStore)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRoutingRepository_Blobs_Priority tests that database hold DID takes priority over discovery
|
// TestRoutingRepository_HoldCachePopulation tests that hold DID cache is populated after manifest fetch
|
||||||
func TestRoutingRepository_Blobs_Priority(t *testing.T) {
|
// Note: This test verifies the goroutine behavior with a delay
|
||||||
dbHoldDID := "did:web:database.hold.io"
|
func TestRoutingRepository_HoldCachePopulation(t *testing.T) {
|
||||||
discoveryHoldDID := "did:web:discovery.hold.io"
|
|
||||||
|
|
||||||
ctx := &RegistryContext{
|
ctx := &RegistryContext{
|
||||||
DID: "did:plc:test123",
|
DID: "did:plc:test123",
|
||||||
Repository: "myapp",
|
Repository: "myapp",
|
||||||
HoldDID: discoveryHoldDID, // Discovery-based hold
|
HoldDID: "did:web:hold01.atcr.io",
|
||||||
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
ATProtoClient: atproto.NewClient("https://pds.example.com", "did:plc:test123", ""),
|
||||||
Database: &mockDatabase{holdDID: dbHoldDID}, // Database has a different hold DID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
repo := NewRoutingRepository(nil, ctx)
|
repo := NewRoutingRepository(nil, ctx)
|
||||||
blobStore := repo.Blobs(context.Background())
|
|
||||||
|
|
||||||
assert.NotNil(t, blobStore)
|
// Create manifest store (which triggers the cache population goroutine)
|
||||||
// Database hold DID should take priority over discovery
|
_, err := repo.Manifests(context.Background())
|
||||||
assert.Equal(t, dbHoldDID, repo.Ctx.HoldDID, "database hold DID should take priority over discovery")
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Wait for goroutine to complete (it has a 100ms sleep)
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Note: We can't easily verify the cache was populated without a real manifest fetch
|
||||||
|
// The actual caching happens in GetLastFetchedHoldDID() which requires manifest operations
|
||||||
|
// This test primarily verifies the Manifests() call doesn't panic with the goroutine
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Package oauth provides OAuth client configuration and helper functions for ATCR.
|
// Package oauth provides OAuth client and flow implementation for ATCR.
|
||||||
// It provides helpers for setting up indigo's OAuth library with ATCR-specific
|
// It wraps indigo's OAuth library with ATCR-specific configuration,
|
||||||
// configuration, including default scopes, confidential client setup, and
|
// including default scopes, client metadata, token refreshing, and
|
||||||
// interactive browser-based authentication flows.
|
// interactive browser-based authentication flows.
|
||||||
package oauth
|
package oauth
|
||||||
|
|
||||||
@@ -8,19 +8,31 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"atcr.io/pkg/atproto"
|
"atcr.io/pkg/atproto"
|
||||||
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||||
|
"github.com/bluesky-social/indigo/atproto/identity"
|
||||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewClientApp creates an indigo OAuth ClientApp with ATCR-specific configuration
|
// App wraps indigo's ClientApp with ATCR-specific configuration
|
||||||
|
type App struct {
|
||||||
|
clientApp *oauth.ClientApp
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewApp creates a new OAuth app for ATCR with default scopes
|
||||||
|
func NewApp(baseURL string, store oauth.ClientAuthStore, holdDid string, keyPath string, clientName string) (*App, error) {
|
||||||
|
return NewAppWithScopes(baseURL, store, GetDefaultScopes(holdDid), keyPath, clientName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAppWithScopes creates a new OAuth app for ATCR with custom scopes
|
||||||
// Automatically configures confidential client for production deployments
|
// Automatically configures confidential client for production deployments
|
||||||
// keyPath specifies where to store/load the OAuth client P-256 key (ignored for localhost)
|
// keyPath specifies where to store/load the OAuth client P-256 key (ignored for localhost)
|
||||||
// clientName is added to OAuth client metadata (currently unused, reserved for future)
|
// clientName is added to OAuth client metadata
|
||||||
func NewClientApp(baseURL string, store oauth.ClientAuthStore, scopes []string, keyPath string, clientName string) (*oauth.ClientApp, error) {
|
func NewAppWithScopes(baseURL string, store oauth.ClientAuthStore, scopes []string, keyPath string, clientName string) (*App, error) {
|
||||||
var config oauth.ClientConfig
|
var config oauth.ClientConfig
|
||||||
redirectURI := RedirectURI(baseURL)
|
redirectURI := RedirectURI(baseURL)
|
||||||
|
|
||||||
@@ -56,7 +68,60 @@ func NewClientApp(baseURL string, store oauth.ClientAuthStore, scopes []string,
|
|||||||
clientApp := oauth.NewClientApp(&config, store)
|
clientApp := oauth.NewClientApp(&config, store)
|
||||||
clientApp.Dir = atproto.GetDirectory()
|
clientApp.Dir = atproto.GetDirectory()
|
||||||
|
|
||||||
return clientApp, nil
|
return &App{
|
||||||
|
clientApp: clientApp,
|
||||||
|
baseURL: baseURL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *App) GetConfig() *oauth.ClientConfig {
|
||||||
|
return a.clientApp.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartAuthFlow initiates an OAuth authorization flow for a given handle
|
||||||
|
// Returns the authorization URL (state is stored in the auth store)
|
||||||
|
func (a *App) StartAuthFlow(ctx context.Context, handle string) (authURL string, err error) {
|
||||||
|
// Start auth flow with handle as identifier
|
||||||
|
// Indigo will resolve the handle internally
|
||||||
|
authURL, err = a.clientApp.StartAuthFlow(ctx, handle)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to start auth flow: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return authURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessCallback processes an OAuth callback with authorization code and state
|
||||||
|
// Returns ClientSessionData which contains the session information
|
||||||
|
func (a *App) ProcessCallback(ctx context.Context, params url.Values) (*oauth.ClientSessionData, error) {
|
||||||
|
sessionData, err := a.clientApp.ProcessCallback(ctx, params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to process OAuth callback: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessionData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResumeSession resumes an existing OAuth session
|
||||||
|
// Returns a ClientSession that can be used to make authenticated requests
|
||||||
|
func (a *App) ResumeSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSession, error) {
|
||||||
|
session, err := a.clientApp.ResumeSession(ctx, did, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to resume session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientApp returns the underlying indigo ClientApp
|
||||||
|
// This is useful for advanced use cases that need direct access
|
||||||
|
func (a *App) GetClientApp() *oauth.ClientApp {
|
||||||
|
return a.clientApp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Directory returns the identity directory used by the OAuth app
|
||||||
|
func (a *App) Directory() identity.Directory {
|
||||||
|
return a.clientApp.Dir
|
||||||
}
|
}
|
||||||
|
|
||||||
// RedirectURI returns the OAuth redirect URI for ATCR
|
// RedirectURI returns the OAuth redirect URI for ATCR
|
||||||
@@ -123,111 +188,3 @@ func ScopesMatch(stored, desired []string) bool {
|
|||||||
func isLocalhost(baseURL string) bool {
|
func isLocalhost(baseURL string) bool {
|
||||||
return strings.Contains(baseURL, "127.0.0.1") || strings.Contains(baseURL, "localhost")
|
return strings.Contains(baseURL, "127.0.0.1") || strings.Contains(baseURL, "localhost")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// Session Management
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
// SessionCache represents a cached OAuth session
|
|
||||||
type SessionCache struct {
|
|
||||||
Session *oauth.ClientSession
|
|
||||||
SessionID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// UISessionStore interface for managing UI sessions
|
|
||||||
// Shared between refresher and server
|
|
||||||
type UISessionStore interface {
|
|
||||||
Create(did, handle, pdsEndpoint string, duration time.Duration) (string, error)
|
|
||||||
DeleteByDID(did string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Refresher manages OAuth sessions and token refresh for AppView
|
|
||||||
// Sessions are loaded fresh from database on every request (database is source of truth)
|
|
||||||
type Refresher struct {
|
|
||||||
clientApp *oauth.ClientApp
|
|
||||||
uiSessionStore UISessionStore // For invalidating UI sessions on OAuth failures
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRefresher creates a new session refresher
|
|
||||||
func NewRefresher(clientApp *oauth.ClientApp) *Refresher {
|
|
||||||
return &Refresher{
|
|
||||||
clientApp: clientApp,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUISessionStore sets the UI session store for invalidating sessions on OAuth failures
|
|
||||||
func (r *Refresher) SetUISessionStore(store UISessionStore) {
|
|
||||||
r.uiSessionStore = store
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSession gets a fresh OAuth session for a DID
|
|
||||||
// Loads session from database on every request (database is source of truth)
|
|
||||||
func (r *Refresher) GetSession(ctx context.Context, did string) (*oauth.ClientSession, error) {
|
|
||||||
return r.resumeSession(ctx, did)
|
|
||||||
}
|
|
||||||
|
|
||||||
// resumeSession loads a session from storage
|
|
||||||
func (r *Refresher) resumeSession(ctx context.Context, did string) (*oauth.ClientSession, error) {
|
|
||||||
// Parse DID
|
|
||||||
accountDID, err := syntax.ParseDID(did)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse DID: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the latest session for this DID from SQLite store
|
|
||||||
// The store must implement GetLatestSessionForDID (returns newest by updated_at)
|
|
||||||
type sessionGetter interface {
|
|
||||||
GetLatestSessionForDID(ctx context.Context, did string) (*oauth.ClientSessionData, string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
getter, ok := r.clientApp.Store.(sessionGetter)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("store must implement GetLatestSessionForDID (SQLite store required)")
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionData, sessionID, err := getter.GetLatestSessionForDID(ctx, did)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("no session found for DID: %s", did)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate that session scopes match current desired scopes
|
|
||||||
desiredScopes := r.clientApp.Config.Scopes
|
|
||||||
if !ScopesMatch(sessionData.Scopes, desiredScopes) {
|
|
||||||
slog.Debug("Scope mismatch, deleting session",
|
|
||||||
"did", did,
|
|
||||||
"storedScopes", sessionData.Scopes,
|
|
||||||
"desiredScopes", desiredScopes)
|
|
||||||
|
|
||||||
// Delete the session from database since scopes have changed
|
|
||||||
if err := r.clientApp.Store.DeleteSession(ctx, accountDID, sessionID); err != nil {
|
|
||||||
slog.Warn("Failed to delete session with mismatched scopes", "error", err, "did", did)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("OAuth scopes changed, re-authentication required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resume session
|
|
||||||
session, err := r.clientApp.ResumeSession(ctx, accountDID, sessionID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to resume session: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up callback to persist token updates to SQLite
|
|
||||||
// This ensures that when indigo automatically refreshes tokens,
|
|
||||||
// the new tokens are saved to the database immediately
|
|
||||||
session.PersistSessionCallback = func(callbackCtx context.Context, updatedData *oauth.ClientSessionData) {
|
|
||||||
if err := r.clientApp.Store.SaveSession(callbackCtx, *updatedData); err != nil {
|
|
||||||
slog.Error("Failed to persist OAuth session update",
|
|
||||||
"component", "oauth/refresher",
|
|
||||||
"did", did,
|
|
||||||
"sessionID", sessionID,
|
|
||||||
"error", err)
|
|
||||||
} else {
|
|
||||||
slog.Debug("Persisted OAuth token refresh to database",
|
|
||||||
"component", "oauth/refresher",
|
|
||||||
"did", did,
|
|
||||||
"sessionID", sessionID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewClientApp(t *testing.T) {
|
func TestNewApp(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
storePath := tmpDir + "/oauth-test.json"
|
storePath := tmpDir + "/oauth-test.json"
|
||||||
keyPath := tmpDir + "/oauth-key.bin"
|
keyPath := tmpDir + "/oauth-key.bin"
|
||||||
@@ -15,23 +15,23 @@ func TestNewClientApp(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
baseURL := "http://localhost:5000"
|
baseURL := "http://localhost:5000"
|
||||||
scopes := GetDefaultScopes("*")
|
holdDID := "did:web:hold.example.com"
|
||||||
|
|
||||||
clientApp, err := NewClientApp(baseURL, store, scopes, keyPath, "AT Container Registry")
|
app, err := NewApp(baseURL, store, holdDID, keyPath, "AT Container Registry")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if clientApp == nil {
|
if app == nil {
|
||||||
t.Fatal("Expected non-nil clientApp")
|
t.Fatal("Expected non-nil app")
|
||||||
}
|
}
|
||||||
|
|
||||||
if clientApp.Dir == nil {
|
if app.baseURL != baseURL {
|
||||||
t.Error("Expected directory to be set")
|
t.Errorf("Expected baseURL %q, got %q", baseURL, app.baseURL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewClientAppWithCustomScopes(t *testing.T) {
|
func TestNewAppWithScopes(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
storePath := tmpDir + "/oauth-test.json"
|
storePath := tmpDir + "/oauth-test.json"
|
||||||
keyPath := tmpDir + "/oauth-key.bin"
|
keyPath := tmpDir + "/oauth-key.bin"
|
||||||
@@ -44,20 +44,19 @@ func TestNewClientAppWithCustomScopes(t *testing.T) {
|
|||||||
baseURL := "http://localhost:5000"
|
baseURL := "http://localhost:5000"
|
||||||
scopes := []string{"atproto", "custom:scope"}
|
scopes := []string{"atproto", "custom:scope"}
|
||||||
|
|
||||||
clientApp, err := NewClientApp(baseURL, store, scopes, keyPath, "AT Container Registry")
|
app, err := NewAppWithScopes(baseURL, store, scopes, keyPath, "AT Container Registry")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewAppWithScopes() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if clientApp == nil {
|
if app == nil {
|
||||||
t.Fatal("Expected non-nil clientApp")
|
t.Fatal("Expected non-nil app")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify clientApp was created successfully
|
// Verify scopes are set in config
|
||||||
// (Note: indigo's oauth.ClientApp doesn't expose scopes directly,
|
config := app.GetConfig()
|
||||||
// but we can verify it was created without error)
|
if len(config.Scopes) != len(scopes) {
|
||||||
if clientApp.Dir == nil {
|
t.Errorf("Expected %d scopes, got %d", len(scopes), len(config.Scopes))
|
||||||
t.Error("Expected directory to be set")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,59 +121,3 @@ func TestScopesMatch(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// Session Management (Refresher) Tests
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
func TestNewRefresher(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
storePath := tmpDir + "/oauth-test.json"
|
|
||||||
|
|
||||||
store, err := NewFileStore(storePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewFileStore() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
refresher := NewRefresher(clientApp)
|
|
||||||
if refresher == nil {
|
|
||||||
t.Fatal("Expected non-nil refresher")
|
|
||||||
}
|
|
||||||
|
|
||||||
if refresher.clientApp == nil {
|
|
||||||
t.Error("Expected clientApp to be set")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRefresher_SetUISessionStore(t *testing.T) {
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
storePath := tmpDir + "/oauth-test.json"
|
|
||||||
|
|
||||||
store, err := NewFileStore(storePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewFileStore() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
refresher := NewRefresher(clientApp)
|
|
||||||
|
|
||||||
// Test that SetUISessionStore doesn't panic with nil
|
|
||||||
// Full mock implementation requires implementing the interface
|
|
||||||
refresher.SetUISessionStore(nil)
|
|
||||||
|
|
||||||
// Verify nil is accepted
|
|
||||||
if refresher.uiSessionStore != nil {
|
|
||||||
t.Error("Expected UI session store to be nil after setting nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
type InteractiveResult struct {
|
type InteractiveResult struct {
|
||||||
SessionData *oauth.ClientSessionData
|
SessionData *oauth.ClientSessionData
|
||||||
Session *oauth.ClientSession
|
Session *oauth.ClientSession
|
||||||
ClientApp *oauth.ClientApp
|
App *App
|
||||||
}
|
}
|
||||||
|
|
||||||
// InteractiveFlowWithCallback runs an interactive OAuth flow with explicit callback handling
|
// InteractiveFlowWithCallback runs an interactive OAuth flow with explicit callback handling
|
||||||
@@ -32,16 +32,19 @@ func InteractiveFlowWithCallback(
|
|||||||
return nil, fmt.Errorf("failed to create OAuth store: %w", err)
|
return nil, fmt.Errorf("failed to create OAuth store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create OAuth client app with custom scopes (or defaults if nil)
|
// Create OAuth app with custom scopes (or defaults if nil)
|
||||||
// Interactive flows are typically for production use (credential helper, etc.)
|
// Interactive flows are typically for production use (credential helper, etc.)
|
||||||
|
// so we default to testMode=false
|
||||||
// For CLI tools, we use an empty keyPath since they're typically localhost (public client)
|
// For CLI tools, we use an empty keyPath since they're typically localhost (public client)
|
||||||
// or ephemeral sessions
|
// or ephemeral sessions
|
||||||
if scopes == nil {
|
var app *App
|
||||||
scopes = GetDefaultScopes("*")
|
if scopes != nil {
|
||||||
|
app, err = NewAppWithScopes(baseURL, store, scopes, "", "AT Container Registry")
|
||||||
|
} else {
|
||||||
|
app, err = NewApp(baseURL, store, "*", "", "AT Container Registry")
|
||||||
}
|
}
|
||||||
clientApp, err := NewClientApp(baseURL, store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create OAuth client app: %w", err)
|
return nil, fmt.Errorf("failed to create OAuth app: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Channel to receive callback result
|
// Channel to receive callback result
|
||||||
@@ -51,7 +54,7 @@ func InteractiveFlowWithCallback(
|
|||||||
// Create callback handler
|
// Create callback handler
|
||||||
callbackHandler := func(w http.ResponseWriter, r *http.Request) {
|
callbackHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Process callback
|
// Process callback
|
||||||
sessionData, err := clientApp.ProcessCallback(r.Context(), r.URL.Query())
|
sessionData, err := app.ProcessCallback(r.Context(), r.URL.Query())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorChan <- fmt.Errorf("failed to process callback: %w", err)
|
errorChan <- fmt.Errorf("failed to process callback: %w", err)
|
||||||
http.Error(w, "OAuth callback failed", http.StatusInternalServerError)
|
http.Error(w, "OAuth callback failed", http.StatusInternalServerError)
|
||||||
@@ -59,7 +62,7 @@ func InteractiveFlowWithCallback(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resume session
|
// Resume session
|
||||||
session, err := clientApp.ResumeSession(r.Context(), sessionData.AccountDID, sessionData.SessionID)
|
session, err := app.ResumeSession(r.Context(), sessionData.AccountDID, sessionData.SessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorChan <- fmt.Errorf("failed to resume session: %w", err)
|
errorChan <- fmt.Errorf("failed to resume session: %w", err)
|
||||||
http.Error(w, "Failed to resume session", http.StatusInternalServerError)
|
http.Error(w, "Failed to resume session", http.StatusInternalServerError)
|
||||||
@@ -70,7 +73,7 @@ func InteractiveFlowWithCallback(
|
|||||||
resultChan <- &InteractiveResult{
|
resultChan <- &InteractiveResult{
|
||||||
SessionData: sessionData,
|
SessionData: sessionData,
|
||||||
Session: session,
|
Session: session,
|
||||||
ClientApp: clientApp,
|
App: app,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return success to browser
|
// Return success to browser
|
||||||
@@ -84,7 +87,7 @@ func InteractiveFlowWithCallback(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start auth flow
|
// Start auth flow
|
||||||
authURL, err := clientApp.StartAuthFlow(ctx, handle)
|
authURL, err := app.StartAuthFlow(ctx, handle)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to start auth flow: %w", err)
|
return nil, fmt.Errorf("failed to start auth flow: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
174
pkg/auth/oauth/refresher.go
Normal file
174
pkg/auth/oauth/refresher.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||||
|
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionCache represents a cached OAuth session
|
||||||
|
type SessionCache struct {
|
||||||
|
Session *oauth.ClientSession
|
||||||
|
SessionID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UISessionStore interface for managing UI sessions
|
||||||
|
// Shared between refresher and server
|
||||||
|
type UISessionStore interface {
|
||||||
|
Create(did, handle, pdsEndpoint string, duration time.Duration) (string, error)
|
||||||
|
DeleteByDID(did string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresher manages OAuth sessions and token refresh for AppView
|
||||||
|
type Refresher struct {
|
||||||
|
app *App
|
||||||
|
sessions map[string]*SessionCache // Key: DID string
|
||||||
|
mu sync.RWMutex
|
||||||
|
refreshLocks map[string]*sync.Mutex // Per-DID locks for refresh operations
|
||||||
|
refreshLockMu sync.Mutex // Protects refreshLocks map
|
||||||
|
uiSessionStore UISessionStore // For invalidating UI sessions on OAuth failures
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRefresher creates a new session refresher
|
||||||
|
func NewRefresher(app *App) *Refresher {
|
||||||
|
return &Refresher{
|
||||||
|
app: app,
|
||||||
|
sessions: make(map[string]*SessionCache),
|
||||||
|
refreshLocks: make(map[string]*sync.Mutex),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUISessionStore sets the UI session store for invalidating sessions on OAuth failures
|
||||||
|
func (r *Refresher) SetUISessionStore(store UISessionStore) {
|
||||||
|
r.uiSessionStore = store
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSession gets a fresh OAuth session for a DID
|
||||||
|
// Returns cached session if still valid, otherwise resumes from store
|
||||||
|
func (r *Refresher) GetSession(ctx context.Context, did string) (*oauth.ClientSession, error) {
|
||||||
|
// Check cache first (fast path)
|
||||||
|
r.mu.RLock()
|
||||||
|
cached, ok := r.sessions[did]
|
||||||
|
r.mu.RUnlock()
|
||||||
|
|
||||||
|
if ok && cached.Session != nil {
|
||||||
|
// Session cached, tokens will auto-refresh if needed
|
||||||
|
return cached.Session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session not cached, need to resume from store
|
||||||
|
// Get or create per-DID lock to prevent concurrent resume operations
|
||||||
|
r.refreshLockMu.Lock()
|
||||||
|
didLock, ok := r.refreshLocks[did]
|
||||||
|
if !ok {
|
||||||
|
didLock = &sync.Mutex{}
|
||||||
|
r.refreshLocks[did] = didLock
|
||||||
|
}
|
||||||
|
r.refreshLockMu.Unlock()
|
||||||
|
|
||||||
|
// Acquire DID-specific lock
|
||||||
|
didLock.Lock()
|
||||||
|
defer didLock.Unlock()
|
||||||
|
|
||||||
|
// Double-check cache after acquiring lock (another goroutine might have loaded it)
|
||||||
|
r.mu.RLock()
|
||||||
|
cached, ok = r.sessions[did]
|
||||||
|
r.mu.RUnlock()
|
||||||
|
|
||||||
|
if ok && cached.Session != nil {
|
||||||
|
return cached.Session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Actually resume the session
|
||||||
|
return r.resumeSession(ctx, did)
|
||||||
|
}
|
||||||
|
|
||||||
|
// resumeSession loads a session from storage and caches it
|
||||||
|
func (r *Refresher) resumeSession(ctx context.Context, did string) (*oauth.ClientSession, error) {
|
||||||
|
// Parse DID
|
||||||
|
accountDID, err := syntax.ParseDID(did)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse DID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the latest session for this DID from SQLite store
|
||||||
|
// The store must implement GetLatestSessionForDID (returns newest by updated_at)
|
||||||
|
type sessionGetter interface {
|
||||||
|
GetLatestSessionForDID(ctx context.Context, did string) (*oauth.ClientSessionData, string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
getter, ok := r.app.clientApp.Store.(sessionGetter)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("store must implement GetLatestSessionForDID (SQLite store required)")
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionData, sessionID, err := getter.GetLatestSessionForDID(ctx, did)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("no session found for DID: %s", did)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that session scopes match current desired scopes
|
||||||
|
desiredScopes := r.app.GetConfig().Scopes
|
||||||
|
if !ScopesMatch(sessionData.Scopes, desiredScopes) {
|
||||||
|
slog.Debug("Scope mismatch, deleting session",
|
||||||
|
"did", did,
|
||||||
|
"storedScopes", sessionData.Scopes,
|
||||||
|
"desiredScopes", desiredScopes)
|
||||||
|
|
||||||
|
// Delete the session from database since scopes have changed
|
||||||
|
if err := r.app.clientApp.Store.DeleteSession(ctx, accountDID, sessionID); err != nil {
|
||||||
|
slog.Warn("Failed to delete session with mismatched scopes", "error", err, "did", did)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("OAuth scopes changed, re-authentication required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resume session
|
||||||
|
session, err := r.app.ResumeSession(ctx, accountDID, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to resume session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache the session
|
||||||
|
r.mu.Lock()
|
||||||
|
r.sessions[did] = &SessionCache{
|
||||||
|
Session: session,
|
||||||
|
SessionID: sessionID,
|
||||||
|
}
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateSession removes a cached session for a DID
|
||||||
|
// This is useful when a new OAuth flow creates a fresh session or when OAuth refresh fails
|
||||||
|
// Also invalidates any UI sessions for this DID to force re-authentication
|
||||||
|
func (r *Refresher) InvalidateSession(did string) {
|
||||||
|
r.mu.Lock()
|
||||||
|
delete(r.sessions, did)
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
// Also delete UI sessions to force user to re-authenticate
|
||||||
|
if r.uiSessionStore != nil {
|
||||||
|
r.uiSessionStore.DeleteByDID(did)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSessionID returns the sessionID for a cached session
|
||||||
|
// Returns empty string if session not cached
|
||||||
|
func (r *Refresher) GetSessionID(did string) string {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
cached, ok := r.sessions[did]
|
||||||
|
if !ok || cached == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return cached.SessionID
|
||||||
|
}
|
||||||
66
pkg/auth/oauth/refresher_test.go
Normal file
66
pkg/auth/oauth/refresher_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewRefresher(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
storePath := tmpDir + "/oauth-test.json"
|
||||||
|
|
||||||
|
store, err := NewFileStore(storePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
refresher := NewRefresher(app)
|
||||||
|
if refresher == nil {
|
||||||
|
t.Fatal("Expected non-nil refresher")
|
||||||
|
}
|
||||||
|
|
||||||
|
if refresher.app == nil {
|
||||||
|
t.Error("Expected app to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if refresher.sessions == nil {
|
||||||
|
t.Error("Expected sessions map to be initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if refresher.refreshLocks == nil {
|
||||||
|
t.Error("Expected refreshLocks map to be initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefresher_SetUISessionStore(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
storePath := tmpDir + "/oauth-test.json"
|
||||||
|
|
||||||
|
store, err := NewFileStore(storePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
refresher := NewRefresher(app)
|
||||||
|
|
||||||
|
// Test that SetUISessionStore doesn't panic with nil
|
||||||
|
// Full mock implementation requires implementing the interface
|
||||||
|
refresher.SetUISessionStore(nil)
|
||||||
|
|
||||||
|
// Verify nil is accepted
|
||||||
|
if refresher.uiSessionStore != nil {
|
||||||
|
t.Error("Expected UI session store to be nil after setting nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Full session management tests will be added in comprehensive implementation
|
||||||
|
// Those tests will require mocking OAuth sessions and testing cache behavior
|
||||||
@@ -10,11 +10,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"atcr.io/pkg/atproto"
|
"atcr.io/pkg/atproto"
|
||||||
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// UISessionStore is the interface for UI session management
|
// UISessionStore is the interface for UI session management
|
||||||
// UISessionStore is defined in client.go (session management section)
|
// UISessionStore is defined in refresher.go to avoid duplication
|
||||||
|
|
||||||
// UserStore is the interface for user management
|
// UserStore is the interface for user management
|
||||||
type UserStore interface {
|
type UserStore interface {
|
||||||
@@ -29,16 +28,16 @@ type PostAuthCallback func(ctx context.Context, did, handle, pdsEndpoint, sessio
|
|||||||
|
|
||||||
// Server handles OAuth authorization for the AppView
|
// Server handles OAuth authorization for the AppView
|
||||||
type Server struct {
|
type Server struct {
|
||||||
clientApp *oauth.ClientApp
|
app *App
|
||||||
refresher *Refresher
|
refresher *Refresher
|
||||||
uiSessionStore UISessionStore
|
uiSessionStore UISessionStore
|
||||||
postAuthCallback PostAuthCallback
|
postAuthCallback PostAuthCallback
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new OAuth server
|
// NewServer creates a new OAuth server
|
||||||
func NewServer(clientApp *oauth.ClientApp) *Server {
|
func NewServer(app *App) *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
clientApp: clientApp,
|
app: app,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,7 +74,7 @@ func (s *Server) ServeAuthorize(w http.ResponseWriter, r *http.Request) {
|
|||||||
slog.Debug("Starting OAuth flow", "handle", handle)
|
slog.Debug("Starting OAuth flow", "handle", handle)
|
||||||
|
|
||||||
// Start auth flow via indigo
|
// Start auth flow via indigo
|
||||||
authURL, err := s.clientApp.StartAuthFlow(r.Context(), handle)
|
authURL, err := s.app.StartAuthFlow(r.Context(), handle)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Failed to start auth flow", "error", err, "handle", handle)
|
slog.Error("Failed to start auth flow", "error", err, "handle", handle)
|
||||||
|
|
||||||
@@ -112,7 +111,7 @@ func (s *Server) ServeCallback(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process OAuth callback via indigo (handles state validation internally)
|
// Process OAuth callback via indigo (handles state validation internally)
|
||||||
sessionData, err := s.clientApp.ProcessCallback(r.Context(), r.URL.Query())
|
sessionData, err := s.app.ProcessCallback(r.Context(), r.URL.Query())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.renderError(w, fmt.Sprintf("Failed to process OAuth callback: %v", err))
|
s.renderError(w, fmt.Sprintf("Failed to process OAuth callback: %v", err))
|
||||||
return
|
return
|
||||||
@@ -123,20 +122,10 @@ func (s *Server) ServeCallback(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
slog.Debug("OAuth callback successful", "did", did, "sessionID", sessionID)
|
slog.Debug("OAuth callback successful", "did", did, "sessionID", sessionID)
|
||||||
|
|
||||||
// Clean up old OAuth sessions for this DID BEFORE invalidating cache
|
// Invalidate cached session (if any) since we have a new session with new tokens
|
||||||
// This prevents accumulation of stale sessions with expired refresh tokens
|
if s.refresher != nil {
|
||||||
// Order matters: delete from DB first, then invalidate cache, so when cache reloads
|
s.refresher.InvalidateSession(did)
|
||||||
// it will only find the new session
|
slog.Debug("Invalidated cached session after creating new session", "did", did)
|
||||||
type sessionCleaner interface {
|
|
||||||
DeleteOldSessionsForDID(ctx context.Context, did string, keepSessionID string) error
|
|
||||||
}
|
|
||||||
if cleaner, ok := s.clientApp.Store.(sessionCleaner); ok {
|
|
||||||
if err := cleaner.DeleteOldSessionsForDID(r.Context(), did, sessionID); err != nil {
|
|
||||||
slog.Warn("Failed to clean up old OAuth sessions", "did", did, "error", err)
|
|
||||||
// Non-fatal - log and continue
|
|
||||||
} else {
|
|
||||||
slog.Debug("Cleaned up old OAuth sessions", "did", did, "kept", sessionID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up identity (resolve DID to handle)
|
// Look up identity (resolve DID to handle)
|
||||||
|
|||||||
@@ -19,19 +19,18 @@ func TestNewServer(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
if server == nil {
|
if server == nil {
|
||||||
t.Fatal("Expected non-nil server")
|
t.Fatal("Expected non-nil server")
|
||||||
}
|
}
|
||||||
|
|
||||||
if server.clientApp == nil {
|
if server.app == nil {
|
||||||
t.Error("Expected clientApp to be set")
|
t.Error("Expected app to be set")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,14 +43,13 @@ func TestServer_SetRefresher(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
refresher := NewRefresher(clientApp)
|
refresher := NewRefresher(app)
|
||||||
|
|
||||||
server.SetRefresher(refresher)
|
server.SetRefresher(refresher)
|
||||||
if server.refresher == nil {
|
if server.refresher == nil {
|
||||||
@@ -68,13 +66,12 @@ func TestServer_SetPostAuthCallback(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
|
|
||||||
// Set callback with correct signature
|
// Set callback with correct signature
|
||||||
server.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, sessionID string) error {
|
server.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, sessionID string) error {
|
||||||
@@ -95,13 +92,12 @@ func TestServer_SetUISessionStore(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
mockStore := &mockUISessionStore{}
|
mockStore := &mockUISessionStore{}
|
||||||
|
|
||||||
server.SetUISessionStore(mockStore)
|
server.SetUISessionStore(mockStore)
|
||||||
@@ -159,13 +155,12 @@ func TestServer_ServeAuthorize_MissingHandle(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/auth/oauth/authorize", nil)
|
req := httptest.NewRequest(http.MethodGet, "/auth/oauth/authorize", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -187,13 +182,12 @@ func TestServer_ServeAuthorize_InvalidMethod(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/auth/oauth/authorize?handle=alice.bsky.social", nil)
|
req := httptest.NewRequest(http.MethodPost, "/auth/oauth/authorize?handle=alice.bsky.social", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -217,13 +211,12 @@ func TestServer_ServeCallback_InvalidMethod(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/auth/oauth/callback", nil)
|
req := httptest.NewRequest(http.MethodPost, "/auth/oauth/callback", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -245,13 +238,12 @@ func TestServer_ServeCallback_OAuthError(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/auth/oauth/callback?error=access_denied&error_description=User+denied+access", nil)
|
req := httptest.NewRequest(http.MethodGet, "/auth/oauth/callback?error=access_denied&error_description=User+denied+access", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -278,13 +270,12 @@ func TestServer_ServeCallback_WithPostAuthCallback(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
|
|
||||||
callbackInvoked := false
|
callbackInvoked := false
|
||||||
server.SetPostAuthCallback(func(ctx context.Context, d, h, pds, sessionID string) error {
|
server.SetPostAuthCallback(func(ctx context.Context, d, h, pds, sessionID string) error {
|
||||||
@@ -323,13 +314,12 @@ func TestServer_ServeCallback_UIFlow_SessionCreationLogic(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
server.SetUISessionStore(uiStore)
|
server.SetUISessionStore(uiStore)
|
||||||
|
|
||||||
// Verify UI session store is set
|
// Verify UI session store is set
|
||||||
@@ -353,13 +343,12 @@ func TestServer_RenderError(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
server.renderError(w, "Test error message")
|
server.renderError(w, "Test error message")
|
||||||
@@ -388,13 +377,12 @@ func TestServer_RenderRedirectToSettings(t *testing.T) {
|
|||||||
t.Fatalf("NewFileStore() error = %v", err)
|
t.Fatalf("NewFileStore() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes := GetDefaultScopes("*")
|
app, err := NewApp("http://localhost:5000", store, "*", "", "AT Container Registry")
|
||||||
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClientApp() error = %v", err)
|
t.Fatalf("NewApp() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer(clientApp)
|
server := NewServer(app)
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
server.renderRedirectToSettings(w, "alice.bsky.social")
|
server.renderRedirectToSettings(w, "alice.bsky.social")
|
||||||
|
|||||||
@@ -46,7 +46,8 @@ func GetOrFetchServiceToken(
|
|||||||
|
|
||||||
session, err := refresher.GetSession(ctx, did)
|
session, err := refresher.GetSession(ctx, did)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// OAuth session unavailable - fail
|
// OAuth session unavailable - invalidate and fail
|
||||||
|
refresher.InvalidateSession(did)
|
||||||
InvalidateServiceToken(did, holdDID)
|
InvalidateServiceToken(did, holdDID)
|
||||||
return "", fmt.Errorf("failed to get OAuth session: %w", err)
|
return "", fmt.Errorf("failed to get OAuth session: %w", err)
|
||||||
}
|
}
|
||||||
@@ -72,15 +73,17 @@ func GetOrFetchServiceToken(
|
|||||||
// Use OAuth session to authenticate to PDS (with DPoP)
|
// Use OAuth session to authenticate to PDS (with DPoP)
|
||||||
resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth")
|
resp, err := session.DoWithAuth(session.Client, req, "com.atproto.server.getServiceAuth")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Auth error - may indicate expired tokens or corrupted session
|
// Invalidate session on auth errors (may indicate corrupted session or expired tokens)
|
||||||
|
refresher.InvalidateSession(did)
|
||||||
InvalidateServiceToken(did, holdDID)
|
InvalidateServiceToken(did, holdDID)
|
||||||
return "", fmt.Errorf("OAuth validation failed: %w", err)
|
return "", fmt.Errorf("OAuth validation failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
// Service auth failed
|
// Invalidate session on auth failures
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
refresher.InvalidateSession(did)
|
||||||
InvalidateServiceToken(did, holdDID)
|
InvalidateServiceToken(did, holdDID)
|
||||||
return "", fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
return "", fmt.Errorf("service auth failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user