Files
at-container-registry/pkg/auth/session.go
2025-10-26 23:08:03 -05:00

174 lines
5.1 KiB
Go

// Package auth provides authentication and authorization for ATCR, including
// ATProto session validation, hold authorization (captain/crew membership),
// scope parsing, and token caching for OAuth and service tokens.
package auth
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"sync"
"time"
"atcr.io/pkg/atproto"
)
// CachedSession is one entry of the validator's in-memory session cache: the
// result of a successful createSession call plus the time after which the
// entry must no longer be served.
type CachedSession struct {
	// DID is the account DID reported by the session response.
	DID string
	// Handle is the account handle reported by the session response.
	Handle string
	// PDS is the PDS endpoint the session was created against.
	PDS string
	// AccessToken is the access JWT issued for the session.
	AccessToken string
	// ExpiresAt is the instant at which the cached entry is considered expired.
	ExpiresAt time.Time
}
// SessionValidator validates ATProto credentials by creating sessions against
// a PDS, and remembers successful sessions in an in-memory cache so repeated
// logins with the same credentials do not hit the PDS again.
//
// A SessionValidator contains a mutex and must not be copied; use the pointer
// returned by NewSessionValidator.
type SessionValidator struct {
	// httpClient performs the createSession requests.
	httpClient *http.Client
	// cache maps a credential-derived key (see getCacheKey) to its session.
	cache map[string]*CachedSession
	// cacheMu guards cache.
	cacheMu sync.RWMutex
}
// NewSessionValidator creates a new ATProto session validator with an empty
// session cache.
//
// The validator's HTTP client carries an overall request timeout so that a
// PDS which accepts the connection but never answers cannot block a caller
// forever when the caller's context has no deadline of its own.
func NewSessionValidator() *SessionValidator {
	// Upper bound for one complete createSession round trip (connect,
	// request, and reading the response body).
	const httpTimeout = 30 * time.Second

	return &SessionValidator{
		httpClient: &http.Client{Timeout: httpTimeout},
		cache:      make(map[string]*CachedSession),
	}
}
// getCacheKey derives the session-cache key for a credential pair.
//
// The key is the hex-encoded SHA-256 digest of
// "<len(username)>:<username>:<password>". The length prefix keeps the
// encoding unambiguous: hashing a bare "username:password" concatenation
// would give ("a:b", "c") and ("a", "b:c") the same key, letting one
// credential pair be served another pair's cached session.
func getCacheKey(username, password string) string {
	h := sha256.New()
	// hash.Hash.Write is documented to never return an error, so the
	// Fprintf result carries no information and is intentionally ignored.
	fmt.Fprintf(h, "%d:%s:%s", len(username), username, password)
	return hex.EncodeToString(h.Sum(nil))
}
// getCachedSession looks up cacheKey in the session cache under the read lock.
//
// It returns (session, true) only when an entry exists and is more than five
// minutes away from its ExpiresAt; otherwise it returns (nil, false). Entries
// that fail the expiry check are left in the map, not removed.
func (v *SessionValidator) getCachedSession(cacheKey string) (*CachedSession, bool) {
	// Entries this close to expiry are treated as already expired.
	const expiryBuffer = 5 * time.Minute

	v.cacheMu.RLock()
	defer v.cacheMu.RUnlock()

	entry, found := v.cache[cacheKey]
	if !found {
		return nil, false
	}

	usableUntil := entry.ExpiresAt.Add(-expiryBuffer)
	if time.Now().After(usableUntil) {
		return nil, false
	}
	return entry, true
}
// setCachedSession records session under cacheKey while holding the write
// lock. An existing entry for the same key is replaced.
func (v *SessionValidator) setCachedSession(cacheKey string, session *CachedSession) {
	mu := &v.cacheMu
	mu.Lock()
	defer mu.Unlock()

	v.cache[cacheKey] = session
}
// SessionResponse is the JSON body returned by com.atproto.server.createSession,
// as decoded by createSession.
type SessionResponse struct {
	// DID is decoded from the "did" field.
	DID string `json:"did"`
	// Handle is decoded from the "handle" field.
	Handle string `json:"handle"`
	// AccessJWT is decoded from the "accessJwt" field.
	AccessJWT string `json:"accessJwt"`
	// RefreshJWT is decoded from the "refreshJwt" field.
	RefreshJWT string `json:"refreshJwt"`
	// Email is decoded from the optional "email" field.
	Email string `json:"email,omitempty"`
	// AccessToken is decoded from "access_token", an alternative field name
	// for the access token.
	AccessToken string `json:"access_token,omitempty"`
}
// CreateSessionAndGetToken authenticates identifier/password and returns the
// account DID, handle, and access token.
//
// A session cached for the same credentials is returned without any network
// call. Otherwise the identifier is resolved to its PDS endpoint, a session is
// created there, and the result is cached with a fixed two-hour lifetime
// before being returned.
//
// Identity-resolution errors are returned as-is; session-creation errors are
// wrapped with "authentication failed". On error all string results are empty.
func (v *SessionValidator) CreateSessionAndGetToken(ctx context.Context, identifier, password string) (did, handle, accessToken string, err error) {
	key := getCacheKey(identifier, password)

	// Fast path: reuse a still-valid session for these exact credentials.
	if hit, found := v.getCachedSession(key); found {
		slog.Debug("Using cached session", "identifier", identifier, "did", hit.DID)
		return hit.DID, hit.Handle, hit.AccessToken, nil
	}
	slog.Debug("No cached session, creating new session", "identifier", identifier)

	// Only the PDS endpoint from the resolution result is needed here.
	_, _, pdsEndpoint, err := atproto.ResolveIdentity(ctx, identifier)
	if err != nil {
		return "", "", "", err
	}

	created, err := v.createSession(ctx, pdsEndpoint, identifier, password)
	if err != nil {
		return "", "", "", fmt.Errorf("authentication failed: %w", err)
	}

	// ATProto sessions typically last 2 hours, so cache for that long.
	fresh := &CachedSession{
		DID:         created.DID,
		Handle:      created.Handle,
		PDS:         pdsEndpoint,
		AccessToken: created.AccessJWT,
		ExpiresAt:   time.Now().Add(2 * time.Hour),
	}
	v.setCachedSession(key, fresh)
	slog.Debug("Cached session (expires in 2 hours)", "identifier", identifier, "did", created.DID)

	return created.DID, created.Handle, created.AccessJWT, nil
}
// createSession calls com.atproto.server.createSession on pdsEndpoint with the
// given identifier and password and decodes the JSON response.
//
// It returns an "invalid credentials" error on HTTP 401, an error carrying the
// status code and response body on any other non-200 status, and a wrapped
// error when the request cannot be built or sent or the body cannot be decoded.
//
// The response body is produced by the remote server, so every read of it is
// capped to keep a misbehaving server from making this process buffer an
// arbitrarily large payload.
func (v *SessionValidator) createSession(ctx context.Context, pdsEndpoint, identifier, password string) (*SessionResponse, error) {
	// Maximum number of response-body bytes that will be read.
	const maxResponseBytes = 1 << 20 // 1 MiB

	body, err := json.Marshal(map[string]string{
		"identifier": identifier,
		"password":   password,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	endpoint := fmt.Sprintf("%s%s", pdsEndpoint, atproto.ServerCreateSession)
	slog.Debug("Creating ATProto session", "url", endpoint)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := v.httpClient.Do(req)
	if err != nil {
		slog.Debug("Session creation HTTP request failed", "error", err)
		return nil, fmt.Errorf("failed to create session: %w", err)
	}
	defer resp.Body.Close()
	slog.Debug("Received session creation response", "status", resp.StatusCode)

	limited := io.LimitReader(resp.Body, maxResponseBytes)

	switch resp.StatusCode {
	case http.StatusOK:
		// Fall through to decoding below.
	case http.StatusUnauthorized:
		// The body is read for diagnostics only; a read error just means
		// less detail in the debug log.
		bodyBytes, _ := io.ReadAll(limited)
		slog.Debug("Session creation unauthorized", "response", string(bodyBytes))
		return nil, fmt.Errorf("invalid credentials")
	default:
		// Best-effort read: the status code alone is still reported if the
		// body cannot be read.
		bodyBytes, _ := io.ReadAll(limited)
		slog.Debug("Session creation failed", "status", resp.StatusCode, "response", string(bodyBytes))
		return nil, fmt.Errorf("create session failed with status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	var sessionResp SessionResponse
	if err := json.NewDecoder(limited).Decode(&sessionResp); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	return &sessionResp, nil
}