192 lines
5.9 KiB
Go
192 lines
5.9 KiB
Go
// Package auth provides authentication and authorization for ATCR, including
|
|
// ATProto session validation, hold authorization (captain/crew membership),
|
|
// scope parsing, and token caching for OAuth and service tokens.
|
|
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"atcr.io/pkg/atproto"
|
|
)
|
|
|
|
// Sentinel errors reported by SessionValidator. Some are returned wrapped with
// extra context, so callers should compare them using errors.Is.
var (
	// ErrIdentityResolution means the handle or DID could not be resolved.
	ErrIdentityResolution = errors.New("identity resolution failed")

	// ErrInvalidCredentials means the PDS answered HTTP 401, i.e. the password
	// or app password was rejected.
	ErrInvalidCredentials = errors.New("invalid credentials")

	// ErrPDSUnavailable means the PDS could not be reached, or it answered with
	// a server error or another unexpected status.
	ErrPDSUnavailable = errors.New("PDS unavailable")
)
|
|
|
|
// CachedSession is a session obtained earlier from a PDS. It is kept in memory
// so that repeating a login with the same credentials can skip the PDS call.
type CachedSession struct {
	// DID and Handle identify the authenticated account.
	DID, Handle string
	// PDS is the endpoint the session was created against.
	PDS string
	// AccessToken is the access JWT returned by createSession.
	AccessToken string
	// ExpiresAt is when this cache entry lapses. Lookups stop using the entry
	// a little before this instant.
	ExpiresAt time.Time
}
|
|
|
|
// SessionValidator validates ATProto credentials
|
|
type SessionValidator struct {
|
|
httpClient *http.Client
|
|
cache map[string]*CachedSession
|
|
cacheMu sync.RWMutex
|
|
}
|
|
|
|
// NewSessionValidator creates a new ATProto session validator
|
|
func NewSessionValidator() *SessionValidator {
|
|
return &SessionValidator{
|
|
httpClient: &http.Client{},
|
|
cache: make(map[string]*CachedSession),
|
|
}
|
|
}
|
|
|
|
// getCacheKey derives the session-cache key for a credential pair. The key is
// the hex-encoded SHA-256 of the username's byte length, the username, and
// the password.
//
// The length prefix keeps the encoding unambiguous. Joining the two values
// with ":" alone would give ("a:b", "c") and ("a", "b:c") the same key, and
// identifiers such as DIDs routinely contain colons.
func getCacheKey(username, password string) string {
	sum := sha256.Sum256([]byte(fmt.Sprintf("%d:%s:%s", len(username), username, password)))
	return hex.EncodeToString(sum[:])
}
|
|
|
|
// getCachedSession retrieves a cached session if valid
|
|
func (v *SessionValidator) getCachedSession(cacheKey string) (*CachedSession, bool) {
|
|
v.cacheMu.RLock()
|
|
defer v.cacheMu.RUnlock()
|
|
|
|
session, ok := v.cache[cacheKey]
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
|
|
// Check if expired (with 5 minute buffer)
|
|
if time.Now().After(session.ExpiresAt.Add(-5 * time.Minute)) {
|
|
return nil, false
|
|
}
|
|
|
|
return session, true
|
|
}
|
|
|
|
// setCachedSession stores a session in the cache
|
|
func (v *SessionValidator) setCachedSession(cacheKey string, session *CachedSession) {
|
|
v.cacheMu.Lock()
|
|
defer v.cacheMu.Unlock()
|
|
v.cache[cacheKey] = session
|
|
}
|
|
|
|
// SessionResponse is the JSON body decoded from a
// com.atproto.server.createSession reply.
type SessionResponse struct {
	DID         string `json:"did"`                    // account DID
	Handle      string `json:"handle"`                 // account handle
	AccessJWT   string `json:"accessJwt"`              // access token (JWT)
	RefreshJWT  string `json:"refreshJwt"`             // refresh token (JWT)
	Email       string `json:"email,omitempty"`        // account email, when present
	AccessToken string `json:"access_token,omitempty"` // alternative field name for the access token
}
|
|
|
|
// CreateSessionAndGetToken creates a session and returns the DID, handle, and access token
|
|
func (v *SessionValidator) CreateSessionAndGetToken(ctx context.Context, identifier, password string) (did, handle, accessToken string, err error) {
|
|
// Check cache first
|
|
cacheKey := getCacheKey(identifier, password)
|
|
if cached, ok := v.getCachedSession(cacheKey); ok {
|
|
slog.Debug("Using cached session", "identifier", identifier, "did", cached.DID)
|
|
return cached.DID, cached.Handle, cached.AccessToken, nil
|
|
}
|
|
|
|
slog.Debug("No cached session, creating new session", "identifier", identifier)
|
|
|
|
// Resolve identifier to PDS endpoint
|
|
_, _, pds, err := atproto.ResolveIdentity(ctx, identifier)
|
|
if err != nil {
|
|
return "", "", "", fmt.Errorf("%w: %v", ErrIdentityResolution, err)
|
|
}
|
|
|
|
// Create session
|
|
sessionResp, err := v.createSession(ctx, pds, identifier, password)
|
|
if err != nil {
|
|
// Pass through typed errors from createSession
|
|
return "", "", "", err
|
|
}
|
|
|
|
// Cache the session (ATProto sessions typically last 2 hours)
|
|
v.setCachedSession(cacheKey, &CachedSession{
|
|
DID: sessionResp.DID,
|
|
Handle: sessionResp.Handle,
|
|
PDS: pds,
|
|
AccessToken: sessionResp.AccessJWT,
|
|
ExpiresAt: time.Now().Add(2 * time.Hour),
|
|
})
|
|
slog.Debug("Cached session (expires in 2 hours)", "identifier", identifier, "did", sessionResp.DID)
|
|
|
|
return sessionResp.DID, sessionResp.Handle, sessionResp.AccessJWT, nil
|
|
}
|
|
|
|
// createSession calls com.atproto.server.createSession
|
|
func (v *SessionValidator) createSession(ctx context.Context, pdsEndpoint, identifier, password string) (*SessionResponse, error) {
|
|
payload := map[string]string{
|
|
"identifier": identifier,
|
|
"password": password,
|
|
}
|
|
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
url := fmt.Sprintf("%s%s", pdsEndpoint, atproto.ServerCreateSession)
|
|
slog.Debug("Creating ATProto session", "url", url)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := v.httpClient.Do(req)
|
|
if err != nil {
|
|
slog.Debug("Session creation HTTP request failed", "error", err)
|
|
return nil, fmt.Errorf("%w: %v", ErrPDSUnavailable, err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
slog.Debug("Received session creation response", "status", resp.StatusCode)
|
|
|
|
if resp.StatusCode == http.StatusUnauthorized {
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
slog.Debug("Session creation unauthorized", "response", string(bodyBytes))
|
|
return nil, ErrInvalidCredentials
|
|
}
|
|
|
|
if resp.StatusCode >= 500 {
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
slog.Debug("PDS server error", "status", resp.StatusCode, "response", string(bodyBytes))
|
|
return nil, fmt.Errorf("%w: server returned %d", ErrPDSUnavailable, resp.StatusCode)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
slog.Debug("Session creation failed", "status", resp.StatusCode, "response", string(bodyBytes))
|
|
return nil, fmt.Errorf("%w: unexpected status %d: %s", ErrPDSUnavailable, resp.StatusCode, string(bodyBytes))
|
|
}
|
|
|
|
var sessionResp SessionResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil {
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
return &sessionResp, nil
|
|
}
|