Files
at-container-registry/pkg/auth/hold_remote.go

566 lines
17 KiB
Go

package auth
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"sync"
"time"
"atcr.io/pkg/atproto"
)
// RemoteHoldAuthorizer authorizes AppView access to remote holds by querying
// each hold's PDS over XRPC.
//
// Captain records are cached in the database for cacheTTL. Crew lookups use a
// database approval cache plus a two-tier denial backoff: a user's first
// denial is remembered only in memory (recentDenials), while repeat denials
// are persisted with escalating backoffs (dbBackoffDurations).
type RemoteHoldAuthorizer struct {
	// External dependencies.
	db         *sql.DB      // cache store; nil disables all caching
	httpClient *http.Client // client used for XRPC calls to holds

	// Captain record caching.
	cacheTTL time.Duration // how long a cached captain record is trusted

	// Crew denial backoff.
	recentDenials      sync.Map        // "holdDID:userDID" -> denialEntry for first denials
	firstDenialBackoff time.Duration   // how long an in-memory first denial blocks re-queries
	dbBackoffDurations []time.Duration // backoffs for persisted denials, indexed by denial count

	// Background pruning of recentDenials.
	stopCleanup        chan struct{} // receiving from this stops the cleanup goroutine
	cleanupInterval    time.Duration // tick interval of the cleanup goroutine
	cleanupGracePeriod time.Duration // extra time an entry survives past firstDenialBackoff

	// If true, use HTTP for local DIDs.
	// NOTE(review): not read by any method in this file; presumably consumed elsewhere.
	testMode bool
}
// denialEntry is the value stored in RemoteHoldAuthorizer.recentDenials for a
// user's first (in-memory only) crew denial.
type denialEntry struct {
	timestamp time.Time // when the first denial was recorded
}
// NewRemoteHoldAuthorizer returns an AppView HoldAuthorizer with production
// settings: a 10s in-memory first-denial backoff, a 10s cleanup tick with a
// 5s grace period, and persisted repeat-denial backoffs of 1m, 5m, 15m and 1h.
func NewRemoteHoldAuthorizer(db *sql.DB, testMode bool) HoldAuthorizer {
	const (
		firstDenialBackoff = 10 * time.Second
		cleanupInterval    = 10 * time.Second
		cleanupGracePeriod = 5 * time.Second
	)
	dbBackoffs := []time.Duration{
		time.Minute,
		5 * time.Minute,
		15 * time.Minute,
		time.Hour,
	}
	return NewRemoteHoldAuthorizerWithBackoffs(db, testMode,
		firstDenialBackoff, cleanupInterval, cleanupGracePeriod, dbBackoffs)
}
// NewRemoteHoldAuthorizerWithBackoffs creates a remote authorizer with custom
// backoff durations. It mainly exists so tests can use short durations
// instead of the production defaults.
//
// Parameters:
//   - firstDenialBackoff: how long a first (in-memory) denial blocks re-queries.
//   - cleanupInterval: tick interval of the background cleanup goroutine.
//     time.NewTicker panics on non-positive durations, so such values fall
//     back to 10s.
//   - cleanupGracePeriod: extra time a first-denial entry is kept after its
//     backoff has elapsed before cleanup removes it.
//   - dbBackoffDurations: backoffs for persisted repeat denials, indexed by
//     denial count (the last entry is reused once the count runs past the
//     list). The slice is copied, so later mutations by the caller do not
//     affect the authorizer. An empty slice cannot be indexed and falls back
//     to 1m/5m/15m/1h.
//
// The returned authorizer starts a goroutine that prunes expired first-denial
// entries and exits only when stopCleanup is signalled.
// NOTE(review): nothing in this file signals stopCleanup.
func NewRemoteHoldAuthorizerWithBackoffs(db *sql.DB, testMode bool, firstDenialBackoff, cleanupInterval, cleanupGracePeriod time.Duration, dbBackoffDurations []time.Duration) HoldAuthorizer {
	if cleanupInterval <= 0 {
		cleanupInterval = 10 * time.Second
	}

	// Defensive copy: the caller keeps ownership of its slice.
	backoffs := append([]time.Duration(nil), dbBackoffDurations...)
	if len(backoffs) == 0 {
		backoffs = []time.Duration{
			1 * time.Minute,
			5 * time.Minute,
			15 * time.Minute,
			60 * time.Minute,
		}
	}

	a := &RemoteHoldAuthorizer{
		db: db,
		httpClient: &http.Client{
			Timeout: 10 * time.Second,
		},
		cacheTTL:           1 * time.Hour, // 1 hour cache TTL
		stopCleanup:        make(chan struct{}),
		testMode:           testMode,
		firstDenialBackoff: firstDenialBackoff,
		cleanupInterval:    cleanupInterval,
		cleanupGracePeriod: cleanupGracePeriod,
		dbBackoffDurations: backoffs,
	}

	// Start cleanup goroutine for in-memory denials
	go a.cleanupRecentDenials()
	return a
}
// cleanupRecentDenials periodically removes first-denial entries from
// recentDenials once they are older than firstDenialBackoff plus
// cleanupGracePeriod. It blocks until a value is received from stopCleanup,
// so it is meant to run in its own goroutine.
func (a *RemoteHoldAuthorizer) cleanupRecentDenials() {
	interval := a.cleanupInterval
	if interval <= 0 {
		// time.NewTicker panics on non-positive durations. A panic in this
		// goroutine would crash the whole process, so use a sane default.
		interval = 10 * time.Second
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			now := time.Now()
			a.recentDenials.Range(func(key, value any) bool {
				entry := value.(denialEntry)
				// Remove entries older than backoff + grace period
				if now.Sub(entry.timestamp) > a.firstDenialBackoff+a.cleanupGracePeriod {
					a.recentDenials.Delete(key)
				}
				return true
			})
		case <-a.stopCleanup:
			return
		}
	}
}
// GetCaptainRecord returns the captain record for holdDID.
//
// When a database is configured and holds a cached record younger than
// cacheTTL, that record is returned. In every other case (no database, cache
// miss, stale entry, or cache read error) the record is fetched from the
// hold's XRPC endpoint. When a database is configured, the fetched record is
// then written back to the cache on a best-effort basis.
func (a *RemoteHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
	useCache := a.db != nil

	if useCache {
		hit, lookupErr := a.getCachedCaptainRecord(holdDID)
		fresh := lookupErr == nil && hit != nil && time.Since(hit.UpdatedAt) < a.cacheTTL
		if fresh {
			return hit.CaptainRecord, nil
		}
	}

	captain, err := a.fetchCaptainRecordFromXRPC(ctx, holdDID)
	if err != nil {
		return nil, err
	}

	if useCache {
		// Best-effort: a cache write failure is logged, never returned.
		if cacheErr := a.setCachedCaptainRecord(holdDID, captain); cacheErr != nil {
			slog.Warn("Failed to cache captain record", "error", cacheErr, "holdDID", holdDID)
		}
	}
	return captain, nil
}
// captainRecordWithMeta pairs a cached captain record with the time its cache
// row was last written, so callers can apply a TTL to it.
type captainRecordWithMeta struct {
	*atproto.CaptainRecord

	// UpdatedAt is the updated_at column of the hold_captain_records row.
	UpdatedAt time.Time
}
// getCachedCaptainRecord loads the cached captain record for holdDID from the
// hold_captain_records table.
//
// It returns (nil, nil) when no row exists and a wrapped error for any other
// query failure. NULL deployed_at, region or provider columns leave the
// corresponding record fields empty.
func (a *RemoteHoldAuthorizer) getCachedCaptainRecord(holdDID string) (*captainRecordWithMeta, error) {
	const selectCaptain = `
SELECT owner_did, public, allow_all_crew, deployed_at, region, provider, updated_at
FROM hold_captain_records
WHERE hold_did = ?
`
	var (
		captain                        atproto.CaptainRecord
		deployedCol, regionCol, provCol sql.NullString
		rowUpdatedAt                   time.Time
	)

	scanErr := a.db.QueryRow(selectCaptain, holdDID).Scan(
		&captain.Owner,
		&captain.Public,
		&captain.AllowAllCrew,
		&deployedCol,
		&regionCol,
		&provCol,
		&rowUpdatedAt,
	)
	switch {
	case scanErr == sql.ErrNoRows:
		return nil, nil // Cache miss
	case scanErr != nil:
		return nil, fmt.Errorf("cache query failed: %w", scanErr)
	}

	// Copy only non-NULL optional columns into the record.
	if deployedCol.Valid {
		captain.DeployedAt = deployedCol.String
	}
	if regionCol.Valid {
		captain.Region = regionCol.String
	}
	if provCol.Valid {
		captain.Provider = provCol.String
	}

	meta := &captainRecordWithMeta{UpdatedAt: rowUpdatedAt}
	meta.CaptainRecord = &captain
	return meta, nil
}
// setCachedCaptainRecord upserts record into hold_captain_records under
// holdDID and stamps updated_at with the current time. Empty optional fields
// (deployed_at, region, provider) are stored as NULL. The Exec error, if any,
// is returned unchanged.
func (a *RemoteHoldAuthorizer) setCachedCaptainRecord(holdDID string, record *atproto.CaptainRecord) error {
	const upsertCaptain = `
INSERT INTO hold_captain_records (
hold_did, owner_did, public, allow_all_crew,
deployed_at, region, provider, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(hold_did) DO UPDATE SET
owner_did = excluded.owner_did,
public = excluded.public,
allow_all_crew = excluded.allow_all_crew,
deployed_at = excluded.deployed_at,
region = excluded.region,
provider = excluded.provider,
updated_at = excluded.updated_at
`
	args := []any{
		holdDID,
		record.Owner,
		record.Public,
		record.AllowAllCrew,
		nullString(record.DeployedAt),
		nullString(record.Region),
		nullString(record.Provider),
		time.Now(),
	}
	if _, err := a.db.Exec(upsertCaptain, args...); err != nil {
		return err
	}
	return nil
}
// fetchCaptainRecordFromXRPC fetches the captain record (rkey "self") for
// holdDID directly from the hold's XRPC endpoint, bypassing any cache.
//
// It returns an error when the request cannot be built or sent, when the hold
// answers with a non-200 status, or when the response does not decode. At most
// maxErrorBodyBytes of an error response body are read and included in the
// returned error.
func (a *RemoteHoldAuthorizer) fetchCaptainRecordFromXRPC(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
	// The hold is a remote, untrusted service. Cap how much of an error body
	// we buffer so a misbehaving server cannot make us read (and embed in an
	// error string) an arbitrarily large response.
	const maxErrorBodyBytes = 4 << 10

	// Resolve DID to URL
	holdURL := atproto.ResolveHoldURL(holdDID)

	// GET /xrpc/com.atproto.repo.getRecord?repo={did}&collection=io.atcr.hold.captain&rkey=self
	xrpcURL := fmt.Sprintf("%s%s?repo=%s&collection=%s&rkey=self",
		holdURL, atproto.RepoGetRecord, url.QueryEscape(holdDID), url.QueryEscape(atproto.CaptainCollection))

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, xrpcURL, nil)
	if err != nil {
		return nil, err
	}

	resp, err := a.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("XRPC request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Read error ignored: the body only enriches the error message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return nil, fmt.Errorf("XRPC request failed: status %d: %s", resp.StatusCode, string(body))
	}

	// Parse response
	var xrpcResp struct {
		URI   string `json:"uri"`
		CID   string `json:"cid"`
		Value struct {
			Type         string `json:"$type"`
			Owner        string `json:"owner"`
			Public       bool   `json:"public"`
			AllowAllCrew bool   `json:"allowAllCrew"`
			DeployedAt   string `json:"deployedAt"`
			Region       string `json:"region,omitempty"`
			Provider     string `json:"provider,omitempty"`
		} `json:"value"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&xrpcResp); err != nil {
		return nil, fmt.Errorf("failed to decode XRPC response: %w", err)
	}

	// Convert to our type
	record := &atproto.CaptainRecord{
		Type:         atproto.CaptainCollection,
		Owner:        xrpcResp.Value.Owner,
		Public:       xrpcResp.Value.Public,
		AllowAllCrew: xrpcResp.Value.AllowAllCrew,
		DeployedAt:   xrpcResp.Value.DeployedAt,
		Region:       xrpcResp.Value.Region,
		Provider:     xrpcResp.Value.Provider,
	}
	return record, nil
}
// IsCrewMember reports whether userDID is a crew member of holdDID.
//
// Without a database it always queries the hold directly. With a database it:
//  1. returns true for an unexpired cached approval;
//  2. returns false while the user is inside a denial backoff window
//     (in-memory first denial or persisted repeat denial);
//  3. otherwise queries the hold's XRPC endpoint, then caches an approval for
//     15 minutes or records a denial with escalating backoff.
//
// Cache reads and writes are best-effort. Their failures are logged and never
// returned, so a broken cache degrades to a live query rather than an error.
// Errors from the live query itself are returned to the caller.
func (a *RemoteHoldAuthorizer) IsCrewMember(ctx context.Context, holdDID, userDID string) (bool, error) {
	const approvalTTL = 15 * time.Minute

	// Skip caching if no database
	if a.db == nil {
		return a.isCrewMemberNoCache(ctx, holdDID, userDID)
	}

	// Check approval cache first.
	switch approved, err := a.getCachedApproval(holdDID, userDID); {
	case err != nil:
		slog.Warn("Crew approval cache lookup failed", "error", err, "holdDID", holdDID, "userDID", userDID)
	case approved:
		slog.Debug("Using cached crew approval", "holdDID", holdDID, "userDID", userDID)
		return true, nil
	}

	// Check denial cache with backoff; blocked users are denied without a query.
	switch blocked, err := a.isBlockedByDenialBackoff(holdDID, userDID); {
	case err != nil:
		slog.Warn("Crew denial cache lookup failed", "error", err, "holdDID", holdDID, "userDID", userDID)
	case blocked:
		slog.Debug("Blocked by denial backoff cache", "holdDID", holdDID, "userDID", userDID)
		return false, nil
	}

	// Cache miss or expired - query XRPC endpoint
	slog.Debug("Crew membership cache miss, querying hold", "holdDID", holdDID, "userDID", userDID)
	isCrew, err := a.isCrewMemberNoCache(ctx, holdDID, userDID)
	if err != nil {
		slog.Warn("Crew membership query error", "error", err, "holdDID", holdDID, "userDID", userDID)
		return false, err
	}

	// Update cache based on result. Failures are logged rather than returned:
	// the membership answer is already known and caching is an optimization.
	if isCrew {
		slog.Debug("Crew membership approved, caching for 15min", "holdDID", holdDID, "userDID", userDID)
		if cacheErr := a.cacheApproval(holdDID, userDID, approvalTTL); cacheErr != nil {
			slog.Warn("Failed to cache crew approval", "error", cacheErr, "holdDID", holdDID, "userDID", userDID)
		}
	} else {
		slog.Debug("Crew membership denied, caching with backoff", "holdDID", holdDID, "userDID", userDID)
		if cacheErr := a.cacheDenial(holdDID, userDID); cacheErr != nil {
			slog.Warn("Failed to cache crew denial", "error", cacheErr, "holdDID", holdDID, "userDID", userDID)
		}
	}
	return isCrew, nil
}
// isCrewMemberNoCache reports whether userDID is listed as a member in the
// hold's crew collection. It queries the hold's XRPC endpoint directly and
// does no caching.
//
// com.atproto.repo.listRecords is paginated: each response holds at most
// `limit` records and, when more exist, a `cursor` for the next page. Every
// page is inspected, so members beyond the first page are still found. The
// walk stops at a page with no cursor, an empty page, or a cursor that does
// not advance. If more than maxCrewPages pages are returned, an error is
// reported instead of a possibly wrong denial.
func (a *RemoteHoldAuthorizer) isCrewMemberNoCache(ctx context.Context, holdDID, userDID string) (bool, error) {
	const (
		crewPageLimit     = 100     // maximum page size allowed by the listRecords lexicon
		maxCrewPages      = 1000    // hard stop in case a server keeps returning cursors
		maxErrorBodyBytes = 4 << 10 // cap on how much of an error body is read
	)

	type crewPage struct {
		Cursor  string `json:"cursor"`
		Records []struct {
			URI   string `json:"uri"`
			CID   string `json:"cid"`
			Value struct {
				Type        string   `json:"$type"`
				Member      string   `json:"member"`
				Role        string   `json:"role"`
				Permissions []string `json:"permissions"`
				AddedAt     string   `json:"addedAt"`
			} `json:"value"`
		} `json:"records"`
	}

	// Resolve DID to URL
	holdURL := atproto.ResolveHoldURL(holdDID)

	// fetchPage retrieves one page of crew records. It is a closure so that the
	// deferred Body.Close runs after each page, not when the whole walk ends.
	fetchPage := func(cursor string) (*crewPage, error) {
		// GET /xrpc/com.atproto.repo.listRecords?repo={did}&collection=io.atcr.hold.crew&limit=N[&cursor=...]
		xrpcURL := fmt.Sprintf("%s%s?repo=%s&collection=%s&limit=%d",
			holdURL, atproto.RepoListRecords, url.QueryEscape(holdDID), url.QueryEscape(atproto.CrewCollection), crewPageLimit)
		if cursor != "" {
			xrpcURL += "&cursor=" + url.QueryEscape(cursor)
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, xrpcURL, nil)
		if err != nil {
			return nil, err
		}
		resp, err := a.httpClient.Do(req)
		if err != nil {
			return nil, fmt.Errorf("XRPC request failed: %w", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			// Bounded read: the hold is remote and the body only enriches the error.
			body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
			return nil, fmt.Errorf("XRPC request failed: status %d: %s", resp.StatusCode, string(body))
		}

		var page crewPage
		if err := json.NewDecoder(resp.Body).Decode(&page); err != nil {
			return nil, fmt.Errorf("failed to decode XRPC response: %w", err)
		}
		return &page, nil
	}

	cursor := ""
	for i := 0; i < maxCrewPages; i++ {
		page, err := fetchPage(cursor)
		if err != nil {
			return false, err
		}

		// Check if userDID is in this page of the crew list
		for _, record := range page.Records {
			if record.Value.Member == userDID {
				// TODO: Check expiration if set
				return true, nil
			}
		}

		// Last page reached (or the server is not advancing): not a member.
		if page.Cursor == "" || len(page.Records) == 0 || page.Cursor == cursor {
			return false, nil
		}
		cursor = page.Cursor
	}
	return false, fmt.Errorf("crew list for %s exceeded %d pages", holdDID, maxCrewPages)
}
// CheckReadAccess reports whether userDID may read from holdDID. It applies
// the shared CheckReadAccessWithCaptain rules to the hold's captain record
// (possibly served from cache). Errors fetching that record are returned
// unchanged.
func (a *RemoteHoldAuthorizer) CheckReadAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
	record, fetchErr := a.GetCaptainRecord(ctx, holdDID)
	if fetchErr == nil {
		return CheckReadAccessWithCaptain(record, userDID), nil
	}
	return false, fetchErr
}
// CheckWriteAccess reports whether userDID may write to holdDID. It first
// fetches the hold's captain record, then resolves crew membership, and
// finally applies the shared CheckWriteAccessWithCaptain rules. The first
// error encountered is returned unchanged.
func (a *RemoteHoldAuthorizer) CheckWriteAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
	record, fetchErr := a.GetCaptainRecord(ctx, holdDID)
	if fetchErr != nil {
		return false, fetchErr
	}

	member, crewErr := a.IsCrewMember(ctx, holdDID, userDID)
	if crewErr != nil {
		return false, crewErr
	}

	return CheckWriteAccessWithCaptain(record, userDID, member), nil
}
// nullString maps s to a sql.NullString, treating the empty string as NULL.
func nullString(s string) sql.NullString {
	return sql.NullString{String: s, Valid: s != ""}
}
// getCachedApproval reports whether hold_crew_approvals holds an unexpired
// approval for (holdDID, userDID).
//
// A missing row yields (false, nil). An expired row also yields (false, nil)
// and is deleted on a best-effort basis. Any other query error is returned.
func (a *RemoteHoldAuthorizer) getCachedApproval(holdDID, userDID string) (bool, error) {
	const selectApproval = `
SELECT expires_at
FROM hold_crew_approvals
WHERE hold_did = ? AND user_did = ?
`
	var expiry time.Time
	switch scanErr := a.db.QueryRow(selectApproval, holdDID, userDID).Scan(&expiry); {
	case scanErr == sql.ErrNoRows:
		return false, nil // Cache miss
	case scanErr != nil:
		return false, scanErr
	}

	if !time.Now().After(expiry) {
		return true, nil
	}

	// Expired: drop the stale row. Best-effort; the delete error is ignored.
	_ = a.deleteCachedApproval(holdDID, userDID)
	return false, nil
}
// cacheApproval upserts a crew approval for (holdDID, userDID). It records
// the current time as approved_at and that time plus ttl as expires_at, and
// returns the Exec error, if any.
func (a *RemoteHoldAuthorizer) cacheApproval(holdDID, userDID string, ttl time.Duration) error {
	const upsertApproval = `
INSERT INTO hold_crew_approvals (hold_did, user_did, approved_at, expires_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(hold_did, user_did) DO UPDATE SET
approved_at = excluded.approved_at,
expires_at = excluded.expires_at
`
	approvedAt := time.Now()
	if _, err := a.db.Exec(upsertApproval, holdDID, userDID, approvedAt, approvedAt.Add(ttl)); err != nil {
		return err
	}
	return nil
}
// deleteCachedApproval removes any cached crew approval row for
// (holdDID, userDID) and returns the Exec error, if any.
func (a *RemoteHoldAuthorizer) deleteCachedApproval(holdDID, userDID string) error {
	const deleteApproval = `DELETE FROM hold_crew_approvals WHERE hold_did = ? AND user_did = ?`
	if _, err := a.db.Exec(deleteApproval, holdDID, userDID); err != nil {
		return err
	}
	return nil
}
// isBlockedByDenialBackoff reports whether (holdDID, userDID) is still inside
// a denial backoff window.
//
// The in-memory first-denial entry is consulted first. It blocks while
// younger than firstDenialBackoff. Otherwise the persisted hold_crew_denials
// row blocks until its next_retry_at. A missing row means not blocked, and
// any other query error is returned.
func (a *RemoteHoldAuthorizer) isBlockedByDenialBackoff(holdDID, userDID string) (bool, error) {
	memKey := holdDID + ":" + userDID
	if raw, found := a.recentDenials.Load(memKey); found {
		if time.Since(raw.(denialEntry).timestamp) < a.firstDenialBackoff {
			return true, nil // still inside the first-denial window
		}
	}

	const selectDenial = `
SELECT next_retry_at
FROM hold_crew_denials
WHERE hold_did = ? AND user_did = ?
`
	var retryAt time.Time
	scanErr := a.db.QueryRow(selectDenial, holdDID, userDID).Scan(&retryAt)
	switch {
	case scanErr == sql.ErrNoRows:
		return false, nil // No denial record
	case scanErr != nil:
		return false, scanErr
	default:
		return time.Now().Before(retryAt), nil
	}
}
// cacheDenial records a crew denial for (holdDID, userDID).
//
// If the pair has neither an in-memory entry nor a hold_crew_denials row, the
// denial is stored only in recentDenials with the current timestamp.
// Otherwise the persisted denial_count is incremented and upserted together
// with next_retry_at (now plus getBackoffDuration of the new count) and
// last_denied_at. The in-memory entry is then removed, even if the upsert
// failed. Errors from the count lookup (other than no rows) or from the
// upsert are returned.
func (a *RemoteHoldAuthorizer) cacheDenial(holdDID, userDID string) error {
	memKey := holdDID + ":" + userDID
	_, seenInMemory := a.recentDenials.Load(memKey)

	const selectCount = `SELECT denial_count FROM hold_crew_denials WHERE hold_did = ? AND user_did = ?`
	var priorCount int
	lookupErr := a.db.QueryRow(selectCount, holdDID, userDID).Scan(&priorCount)
	if lookupErr != nil && lookupErr != sql.ErrNoRows {
		return lookupErr
	}
	seenInDB := lookupErr == nil

	if !(seenInMemory || seenInDB) {
		// First denial: remember it in memory only.
		a.recentDenials.Store(memKey, denialEntry{timestamp: time.Now()})
		return nil
	}

	// Repeat denial: persist with escalating backoff.
	count := priorCount + 1
	backoff := a.getBackoffDuration(count)
	deniedAt := time.Now()

	const upsertDenial = `
INSERT INTO hold_crew_denials (hold_did, user_did, denial_count, next_retry_at, last_denied_at)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(hold_did, user_did) DO UPDATE SET
denial_count = excluded.denial_count,
next_retry_at = excluded.next_retry_at,
last_denied_at = excluded.last_denied_at
`
	_, execErr := a.db.Exec(upsertDenial, holdDID, userDID, count, deniedAt.Add(backoff), deniedAt)

	// The database now owns this denial, so drop the in-memory entry.
	a.recentDenials.Delete(memKey)
	return execErr
}
// getBackoffDuration returns how long a persisted (second or later) denial
// blocks re-queries.
//
// denialCount is the persisted denial count; count 1 maps to
// dbBackoffDurations[0]. The index is clamped into range. Counts past the end
// of the list reuse its last entry, and non-positive counts use the first
// entry. This avoids the out-of-range panic the unclamped index caused. If no
// durations are configured, firstDenialBackoff is used so repeat denials are
// still throttled.
func (a *RemoteHoldAuthorizer) getBackoffDuration(denialCount int) time.Duration {
	backoffs := a.dbBackoffDurations
	if len(backoffs) == 0 {
		return a.firstDenialBackoff
	}

	idx := denialCount - 1
	if idx < 0 {
		idx = 0
	}
	if idx >= len(backoffs) {
		idx = len(backoffs) - 1
	}
	return backoffs[idx]
}