566 lines
17 KiB
Go
566 lines
17 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"atcr.io/pkg/atproto"
|
|
)
|
|
|
|
// RemoteHoldAuthorizer authorizes AppView access to remote holds by asking
// each hold's PDS over XRPC.
//
// Captain records and crew-membership results are cached (in the database
// when one is configured, plus an in-memory map for first denials) so that
// repeated checks do not each cost an XRPC round trip.
type RemoteHoldAuthorizer struct {
	// db backs the persistent caches; when nil, caching is skipped.
	db *sql.DB

	// httpClient issues the XRPC requests to holds.
	httpClient *http.Client

	// cacheTTL is how long a cached captain record is served without refetching.
	cacheTTL time.Duration

	// recentDenials maps "holdDID:userDID" to a denialEntry for first
	// denials, which are tracked only in memory.
	recentDenials sync.Map

	// stopCleanup is the signal that stops the background cleanup goroutine.
	stopCleanup chan struct{}

	// testMode: if true, use HTTP for local DIDs.
	testMode bool

	// firstDenialBackoff is how long an in-memory first denial blocks
	// re-querying the hold (default: 10s).
	firstDenialBackoff time.Duration

	// cleanupInterval is the period of the in-memory denial sweep (default: 10s).
	cleanupInterval time.Duration

	// cleanupGracePeriod is how long past firstDenialBackoff an entry is kept
	// before the sweep removes it (default: 5s).
	cleanupGracePeriod time.Duration

	// dbBackoffDurations is the escalating backoff for second and later
	// denials, persisted in the database (default: [1m, 5m, 15m, 1h]).
	dbBackoffDurations []time.Duration
}
|
|
|
|
// denialEntry is the value kept in RemoteHoldAuthorizer.recentDenials; it
// remembers when an in-memory first denial was recorded.
type denialEntry struct {
	// timestamp is the moment the first denial was stored.
	timestamp time.Time
}
|
|
|
|
// NewRemoteHoldAuthorizer creates a new remote authorizer for AppView with production defaults
|
|
func NewRemoteHoldAuthorizer(db *sql.DB, testMode bool) HoldAuthorizer {
|
|
return NewRemoteHoldAuthorizerWithBackoffs(db, testMode,
|
|
10*time.Second, // firstDenialBackoff
|
|
10*time.Second, // cleanupInterval
|
|
5*time.Second, // cleanupGracePeriod
|
|
[]time.Duration{ // dbBackoffDurations
|
|
1 * time.Minute,
|
|
5 * time.Minute,
|
|
15 * time.Minute,
|
|
60 * time.Minute,
|
|
},
|
|
)
|
|
}
|
|
|
|
// NewRemoteHoldAuthorizerWithBackoffs creates a new remote authorizer with custom backoff durations
|
|
// Used for testing to avoid long sleeps
|
|
func NewRemoteHoldAuthorizerWithBackoffs(db *sql.DB, testMode bool, firstDenialBackoff, cleanupInterval, cleanupGracePeriod time.Duration, dbBackoffDurations []time.Duration) HoldAuthorizer {
|
|
a := &RemoteHoldAuthorizer{
|
|
db: db,
|
|
httpClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
cacheTTL: 1 * time.Hour, // 1 hour cache TTL
|
|
stopCleanup: make(chan struct{}),
|
|
testMode: testMode,
|
|
firstDenialBackoff: firstDenialBackoff,
|
|
cleanupInterval: cleanupInterval,
|
|
cleanupGracePeriod: cleanupGracePeriod,
|
|
dbBackoffDurations: dbBackoffDurations,
|
|
}
|
|
|
|
// Start cleanup goroutine for in-memory denials
|
|
go a.cleanupRecentDenials()
|
|
|
|
return a
|
|
}
|
|
|
|
// cleanupRecentDenials runs periodically to remove expired first-denial entries
|
|
func (a *RemoteHoldAuthorizer) cleanupRecentDenials() {
|
|
ticker := time.NewTicker(a.cleanupInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
now := time.Now()
|
|
a.recentDenials.Range(func(key, value any) bool {
|
|
entry := value.(denialEntry)
|
|
// Remove entries older than backoff + grace period
|
|
if now.Sub(entry.timestamp) > a.firstDenialBackoff+a.cleanupGracePeriod {
|
|
a.recentDenials.Delete(key)
|
|
}
|
|
return true
|
|
})
|
|
case <-a.stopCleanup:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetCaptainRecord retrieves a captain record with caching
|
|
// 1. Check database cache
|
|
// 2. If cache miss or expired, query hold's XRPC endpoint
|
|
// 3. Update cache
|
|
func (a *RemoteHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
|
|
// Try cache first
|
|
if a.db != nil {
|
|
cached, err := a.getCachedCaptainRecord(holdDID)
|
|
if err == nil && cached != nil {
|
|
// Cache hit - check if still valid
|
|
if time.Since(cached.UpdatedAt) < a.cacheTTL {
|
|
return cached.CaptainRecord, nil
|
|
}
|
|
// Cache expired - continue to fetch fresh data
|
|
}
|
|
}
|
|
|
|
// Cache miss or expired - query XRPC endpoint
|
|
record, err := a.fetchCaptainRecordFromXRPC(ctx, holdDID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Update cache
|
|
if a.db != nil {
|
|
if err := a.setCachedCaptainRecord(holdDID, record); err != nil {
|
|
// Log error but don't fail - caching is best-effort
|
|
slog.Warn("Failed to cache captain record", "error", err, "holdDID", holdDID)
|
|
}
|
|
}
|
|
|
|
return record, nil
|
|
}
|
|
|
|
// captainRecordWithMeta includes UpdatedAt for cache management
|
|
type captainRecordWithMeta struct {
|
|
*atproto.CaptainRecord
|
|
UpdatedAt time.Time
|
|
}
|
|
|
|
// getCachedCaptainRecord retrieves a captain record from database cache
|
|
func (a *RemoteHoldAuthorizer) getCachedCaptainRecord(holdDID string) (*captainRecordWithMeta, error) {
|
|
query := `
|
|
SELECT owner_did, public, allow_all_crew, deployed_at, region, provider, updated_at
|
|
FROM hold_captain_records
|
|
WHERE hold_did = ?
|
|
`
|
|
|
|
var record atproto.CaptainRecord
|
|
var deployedAt, region, provider sql.NullString
|
|
var updatedAt time.Time
|
|
|
|
err := a.db.QueryRow(query, holdDID).Scan(
|
|
&record.Owner,
|
|
&record.Public,
|
|
&record.AllowAllCrew,
|
|
&deployedAt,
|
|
®ion,
|
|
&provider,
|
|
&updatedAt,
|
|
)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil // Cache miss
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cache query failed: %w", err)
|
|
}
|
|
|
|
// Handle nullable fields
|
|
if deployedAt.Valid {
|
|
record.DeployedAt = deployedAt.String
|
|
}
|
|
if region.Valid {
|
|
record.Region = region.String
|
|
}
|
|
if provider.Valid {
|
|
record.Provider = provider.String
|
|
}
|
|
|
|
return &captainRecordWithMeta{
|
|
CaptainRecord: &record,
|
|
UpdatedAt: updatedAt,
|
|
}, nil
|
|
}
|
|
|
|
// setCachedCaptainRecord stores a captain record in database cache
|
|
func (a *RemoteHoldAuthorizer) setCachedCaptainRecord(holdDID string, record *atproto.CaptainRecord) error {
|
|
query := `
|
|
INSERT INTO hold_captain_records (
|
|
hold_did, owner_did, public, allow_all_crew,
|
|
deployed_at, region, provider, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(hold_did) DO UPDATE SET
|
|
owner_did = excluded.owner_did,
|
|
public = excluded.public,
|
|
allow_all_crew = excluded.allow_all_crew,
|
|
deployed_at = excluded.deployed_at,
|
|
region = excluded.region,
|
|
provider = excluded.provider,
|
|
updated_at = excluded.updated_at
|
|
`
|
|
|
|
_, err := a.db.Exec(query,
|
|
holdDID,
|
|
record.Owner,
|
|
record.Public,
|
|
record.AllowAllCrew,
|
|
nullString(record.DeployedAt),
|
|
nullString(record.Region),
|
|
nullString(record.Provider),
|
|
time.Now(),
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
// fetchCaptainRecordFromXRPC queries the hold's XRPC endpoint for captain record
|
|
func (a *RemoteHoldAuthorizer) fetchCaptainRecordFromXRPC(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
|
|
// Resolve DID to URL
|
|
holdURL := atproto.ResolveHoldURL(holdDID)
|
|
|
|
// Build XRPC request URL
|
|
// GET /xrpc/com.atproto.repo.getRecord?repo={did}&collection=io.atcr.hold.captain&rkey=self
|
|
xrpcURL := fmt.Sprintf("%s%s?repo=%s&collection=%s&rkey=self",
|
|
holdURL, atproto.RepoGetRecord, url.QueryEscape(holdDID), url.QueryEscape(atproto.CaptainCollection))
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", xrpcURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp, err := a.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("XRPC request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("XRPC request failed: status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
// Parse response
|
|
var xrpcResp struct {
|
|
URI string `json:"uri"`
|
|
CID string `json:"cid"`
|
|
Value struct {
|
|
Type string `json:"$type"`
|
|
Owner string `json:"owner"`
|
|
Public bool `json:"public"`
|
|
AllowAllCrew bool `json:"allowAllCrew"`
|
|
DeployedAt string `json:"deployedAt"`
|
|
Region string `json:"region,omitempty"`
|
|
Provider string `json:"provider,omitempty"`
|
|
} `json:"value"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&xrpcResp); err != nil {
|
|
return nil, fmt.Errorf("failed to decode XRPC response: %w", err)
|
|
}
|
|
|
|
// Convert to our type
|
|
record := &atproto.CaptainRecord{
|
|
Type: atproto.CaptainCollection,
|
|
Owner: xrpcResp.Value.Owner,
|
|
Public: xrpcResp.Value.Public,
|
|
AllowAllCrew: xrpcResp.Value.AllowAllCrew,
|
|
DeployedAt: xrpcResp.Value.DeployedAt,
|
|
Region: xrpcResp.Value.Region,
|
|
Provider: xrpcResp.Value.Provider,
|
|
}
|
|
|
|
return record, nil
|
|
}
|
|
|
|
// IsCrewMember checks if userDID is a crew member with caching
|
|
// 1. Check approval cache (15min TTL)
|
|
// 2. Check denial cache with exponential backoff
|
|
// 3. If cache miss, query XRPC endpoint and update cache
|
|
func (a *RemoteHoldAuthorizer) IsCrewMember(ctx context.Context, holdDID, userDID string) (bool, error) {
|
|
// Skip caching if no database
|
|
if a.db == nil {
|
|
return a.isCrewMemberNoCache(ctx, holdDID, userDID)
|
|
}
|
|
|
|
// Check approval cache first (15min TTL)
|
|
if approved, err := a.getCachedApproval(holdDID, userDID); err == nil && approved {
|
|
slog.Debug("Using cached crew approval", "holdDID", holdDID, "userDID", userDID)
|
|
return true, nil
|
|
}
|
|
|
|
// Check denial cache with backoff
|
|
if blocked, err := a.isBlockedByDenialBackoff(holdDID, userDID); err == nil && blocked {
|
|
// Still in backoff period - don't query again
|
|
slog.Debug("Blocked by denial backoff cache", "holdDID", holdDID, "userDID", userDID)
|
|
return false, nil
|
|
}
|
|
|
|
// Cache miss or expired - query XRPC endpoint
|
|
slog.Debug("Crew membership cache miss, querying hold", "holdDID", holdDID, "userDID", userDID)
|
|
isCrew, err := a.isCrewMemberNoCache(ctx, holdDID, userDID)
|
|
if err != nil {
|
|
slog.Warn("Crew membership query error", "error", err, "holdDID", holdDID, "userDID", userDID)
|
|
return false, err
|
|
}
|
|
|
|
// Update cache based on result
|
|
if isCrew {
|
|
// Cache approval for 15 minutes
|
|
slog.Debug("Crew membership approved, caching for 15min", "holdDID", holdDID, "userDID", userDID)
|
|
_ = a.cacheApproval(holdDID, userDID, 15*time.Minute)
|
|
} else {
|
|
// Cache denial with exponential backoff
|
|
slog.Debug("Crew membership denied, caching with backoff", "holdDID", holdDID, "userDID", userDID)
|
|
_ = a.cacheDenial(holdDID, userDID)
|
|
}
|
|
|
|
return isCrew, nil
|
|
}
|
|
|
|
// isCrewMemberNoCache queries XRPC without caching (internal helper)
|
|
func (a *RemoteHoldAuthorizer) isCrewMemberNoCache(ctx context.Context, holdDID, userDID string) (bool, error) {
|
|
// Resolve DID to URL
|
|
holdURL := atproto.ResolveHoldURL(holdDID)
|
|
|
|
// Build XRPC request URL
|
|
// GET /xrpc/com.atproto.repo.listRecords?repo={did}&collection=io.atcr.hold.crew
|
|
xrpcURL := fmt.Sprintf("%s%s?repo=%s&collection=%s",
|
|
holdURL, atproto.RepoListRecords, url.QueryEscape(holdDID), url.QueryEscape(atproto.CrewCollection))
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", xrpcURL, nil)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
resp, err := a.httpClient.Do(req)
|
|
if err != nil {
|
|
return false, fmt.Errorf("XRPC request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return false, fmt.Errorf("XRPC request failed: status %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
// Parse response
|
|
var xrpcResp struct {
|
|
Records []struct {
|
|
URI string `json:"uri"`
|
|
CID string `json:"cid"`
|
|
Value struct {
|
|
Type string `json:"$type"`
|
|
Member string `json:"member"`
|
|
Role string `json:"role"`
|
|
Permissions []string `json:"permissions"`
|
|
AddedAt string `json:"addedAt"`
|
|
} `json:"value"`
|
|
} `json:"records"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&xrpcResp); err != nil {
|
|
return false, fmt.Errorf("failed to decode XRPC response: %w", err)
|
|
}
|
|
|
|
// Check if userDID is in the crew list
|
|
for _, record := range xrpcResp.Records {
|
|
if record.Value.Member == userDID {
|
|
// TODO: Check expiration if set
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// CheckReadAccess implements read authorization using shared logic
|
|
func (a *RemoteHoldAuthorizer) CheckReadAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
|
|
captain, err := a.GetCaptainRecord(ctx, holdDID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return CheckReadAccessWithCaptain(captain, userDID), nil
|
|
}
|
|
|
|
// CheckWriteAccess implements write authorization using shared logic
|
|
func (a *RemoteHoldAuthorizer) CheckWriteAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
|
|
captain, err := a.GetCaptainRecord(ctx, holdDID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
isCrew, err := a.IsCrewMember(ctx, holdDID, userDID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return CheckWriteAccessWithCaptain(captain, userDID, isCrew), nil
|
|
}
|
|
|
|
// nullString maps s to a sql.NullString: the empty string becomes SQL NULL
// (Valid == false), and any other value is stored as a valid string.
func nullString(s string) sql.NullString {
	return sql.NullString{String: s, Valid: s != ""}
}
|
|
|
|
// getCachedApproval checks if user has a cached crew approval
|
|
func (a *RemoteHoldAuthorizer) getCachedApproval(holdDID, userDID string) (bool, error) {
|
|
query := `
|
|
SELECT expires_at
|
|
FROM hold_crew_approvals
|
|
WHERE hold_did = ? AND user_did = ?
|
|
`
|
|
|
|
var expiresAt time.Time
|
|
err := a.db.QueryRow(query, holdDID, userDID).Scan(&expiresAt)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return false, nil // Cache miss
|
|
}
|
|
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
// Check if approval has expired
|
|
if time.Now().After(expiresAt) {
|
|
// Expired - clean up
|
|
_ = a.deleteCachedApproval(holdDID, userDID)
|
|
return false, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// cacheApproval stores a crew approval with TTL
|
|
func (a *RemoteHoldAuthorizer) cacheApproval(holdDID, userDID string, ttl time.Duration) error {
|
|
query := `
|
|
INSERT INTO hold_crew_approvals (hold_did, user_did, approved_at, expires_at)
|
|
VALUES (?, ?, ?, ?)
|
|
ON CONFLICT(hold_did, user_did) DO UPDATE SET
|
|
approved_at = excluded.approved_at,
|
|
expires_at = excluded.expires_at
|
|
`
|
|
|
|
now := time.Now()
|
|
expiresAt := now.Add(ttl)
|
|
|
|
_, err := a.db.Exec(query, holdDID, userDID, now, expiresAt)
|
|
return err
|
|
}
|
|
|
|
// deleteCachedApproval removes an expired approval
|
|
func (a *RemoteHoldAuthorizer) deleteCachedApproval(holdDID, userDID string) error {
|
|
query := `DELETE FROM hold_crew_approvals WHERE hold_did = ? AND user_did = ?`
|
|
_, err := a.db.Exec(query, holdDID, userDID)
|
|
return err
|
|
}
|
|
|
|
// isBlockedByDenialBackoff checks if user is in denial backoff period
|
|
// Checks in-memory cache first (for 10s first denials), then DB (for longer backoffs)
|
|
func (a *RemoteHoldAuthorizer) isBlockedByDenialBackoff(holdDID, userDID string) (bool, error) {
|
|
// Check in-memory cache first (first denials with configurable backoff)
|
|
key := fmt.Sprintf("%s:%s", holdDID, userDID)
|
|
if val, ok := a.recentDenials.Load(key); ok {
|
|
entry := val.(denialEntry)
|
|
// Check if still within first denial backoff period
|
|
if time.Since(entry.timestamp) < a.firstDenialBackoff {
|
|
return true, nil // Still blocked by in-memory first denial
|
|
}
|
|
}
|
|
|
|
// Check database for longer backoffs (second+ denials)
|
|
query := `
|
|
SELECT next_retry_at
|
|
FROM hold_crew_denials
|
|
WHERE hold_did = ? AND user_did = ?
|
|
`
|
|
|
|
var nextRetryAt time.Time
|
|
err := a.db.QueryRow(query, holdDID, userDID).Scan(&nextRetryAt)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return false, nil // No denial record
|
|
}
|
|
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
// Check if still in backoff period
|
|
if time.Now().Before(nextRetryAt) {
|
|
return true, nil // Still blocked
|
|
}
|
|
|
|
// Backoff period expired - can retry
|
|
return false, nil
|
|
}
|
|
|
|
// cacheDenial stores or updates a denial with exponential backoff
|
|
// First denial: in-memory only (configurable backoff, default 10s)
|
|
// Second+ denial: database with exponential backoff (configurable, default 1m/5m/15m/1h)
|
|
func (a *RemoteHoldAuthorizer) cacheDenial(holdDID, userDID string) error {
|
|
key := fmt.Sprintf("%s:%s", holdDID, userDID)
|
|
|
|
// Check if this is a first denial (not in memory, not in DB)
|
|
_, inMemory := a.recentDenials.Load(key)
|
|
|
|
var denialCount int
|
|
query := `SELECT denial_count FROM hold_crew_denials WHERE hold_did = ? AND user_did = ?`
|
|
err := a.db.QueryRow(query, holdDID, userDID).Scan(&denialCount)
|
|
|
|
inDB := err != sql.ErrNoRows
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return err
|
|
}
|
|
|
|
// If not in memory and not in DB, this is the first denial
|
|
if !inMemory && !inDB {
|
|
// First denial: store only in memory with configurable backoff
|
|
a.recentDenials.Store(key, denialEntry{timestamp: time.Now()})
|
|
return nil
|
|
}
|
|
|
|
// Second+ denial: persist to database with exponential backoff
|
|
denialCount++
|
|
backoff := a.getBackoffDuration(denialCount)
|
|
now := time.Now()
|
|
nextRetry := now.Add(backoff)
|
|
|
|
// Upsert denial record
|
|
upsertQuery := `
|
|
INSERT INTO hold_crew_denials (hold_did, user_did, denial_count, next_retry_at, last_denied_at)
|
|
VALUES (?, ?, ?, ?, ?)
|
|
ON CONFLICT(hold_did, user_did) DO UPDATE SET
|
|
denial_count = excluded.denial_count,
|
|
next_retry_at = excluded.next_retry_at,
|
|
last_denied_at = excluded.last_denied_at
|
|
`
|
|
|
|
_, err = a.db.Exec(upsertQuery, holdDID, userDID, denialCount, nextRetry, now)
|
|
|
|
// Remove from in-memory cache since we're now tracking in DB
|
|
a.recentDenials.Delete(key)
|
|
|
|
return err
|
|
}
|
|
|
|
// getBackoffDuration returns the backoff duration based on denial count
|
|
// Note: First denial is in-memory only and not tracked by this function
|
|
// This function handles second+ denials using configurable durations
|
|
func (a *RemoteHoldAuthorizer) getBackoffDuration(denialCount int) time.Duration {
|
|
backoffs := a.dbBackoffDurations
|
|
|
|
idx := denialCount - 1
|
|
if idx >= len(backoffs) {
|
|
idx = len(backoffs) - 1
|
|
}
|
|
|
|
return backoffs[idx]
|
|
}
|