274 lines
9.6 KiB
Go
274 lines
9.6 KiB
Go
package token
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"atcr.io/pkg/appview/db"
|
|
"atcr.io/pkg/atproto"
|
|
"atcr.io/pkg/auth"
|
|
"github.com/go-chi/render"
|
|
)
|
|
|
|
// PostAuthCallback is invoked by Handler after an app-password Basic Auth
// login succeeds. It receives the request context, the authenticated user's
// DID and handle, the PDS endpoint resolved for that user, and the PDS access
// token obtained during login.
//
// The hook lets AppView run its own follow-up work (profile creation and
// similar) without the token package importing AppView-specific code. The
// handler only logs a returned error; it never fails authentication because
// of it.
type PostAuthCallback func(
	ctx context.Context,
	did string,
	handle string,
	pdsEndpoint string,
	accessToken string,
) error
|
|
|
|
// OAuthSessionValidator lets the token handler confirm that a user's OAuth
// session is actually usable (not merely recorded) before it issues a token
// for device-secret authentication.
//
// Expressing this as an interface keeps the token package decoupled from the
// concrete OAuth implementation.
type OAuthSessionValidator interface {
	// ValidateSession tries to load (and refresh, if needed) the OAuth
	// session for did. It returns nil when the session can be used, and a
	// non-nil error when the session is invalid, expired, or requires the
	// user to re-authenticate.
	ValidateSession(ctx context.Context, did string) error
}
|
|
|
|
// Handler handles /auth/token requests
|
|
type Handler struct {
|
|
issuer *Issuer
|
|
validator *auth.SessionValidator
|
|
deviceStore *db.DeviceStore // For validating device secrets
|
|
postAuthCallback PostAuthCallback
|
|
oauthSessionValidator OAuthSessionValidator
|
|
}
|
|
|
|
// NewHandler creates a new token handler
|
|
func NewHandler(issuer *Issuer, deviceStore *db.DeviceStore) *Handler {
|
|
return &Handler{
|
|
issuer: issuer,
|
|
validator: auth.NewSessionValidator(),
|
|
deviceStore: deviceStore,
|
|
}
|
|
}
|
|
|
|
// SetPostAuthCallback sets the callback to be invoked after successful Basic Auth authentication
|
|
// This allows AppView to inject business logic without coupling the token package
|
|
func (h *Handler) SetPostAuthCallback(callback PostAuthCallback) {
|
|
h.postAuthCallback = callback
|
|
}
|
|
|
|
// SetOAuthSessionValidator sets the OAuth session validator for validating device auth
|
|
// When set, the handler will validate OAuth sessions are usable before issuing tokens for device auth
|
|
// This prevents the flood of errors that occurs when a stale session is discovered during push
|
|
func (h *Handler) SetOAuthSessionValidator(validator OAuthSessionValidator) {
|
|
h.oauthSessionValidator = validator
|
|
}
|
|
|
|
// TokenResponse is the JSON body returned by /auth/token on success.
//
// The handler puts the same JWT in both Token ("token", the older field name)
// and AccessToken ("access_token", the standard one), so clients using either
// convention can find it. Empty fields are left out of the JSON.
type TokenResponse struct {
	Token       string `json:"token,omitempty"`        // legacy field name
	AccessToken string `json:"access_token,omitempty"` // standard field name
	ExpiresIn   int    `json:"expires_in,omitempty"`   // token lifetime in seconds
	IssuedAt    string `json:"issued_at,omitempty"`    // issue time formatted as RFC 3339
}
|
|
|
|
// getBaseURL reconstructs the externally visible base URL ("scheme://host")
// of r, honoring reverse-proxy headers.
//
// The host comes from X-Forwarded-Host when present, otherwise from r.Host.
// When a proxy chain has appended several comma-separated values, only the
// first (client-facing) entry is used. The scheme is https if the connection
// itself is TLS or X-Forwarded-Proto (first entry, case-insensitive) is
// "https"; otherwise it is http. A host value that already starts with
// "http://" or "https://" is returned unchanged.
func getBaseURL(r *http.Request) string {
	// strings.Cut returns the whole string when no comma is present.
	host, _, _ := strings.Cut(r.Header.Get("X-Forwarded-Host"), ",")
	host = strings.TrimSpace(host)
	if host == "" {
		host = r.Host
	}

	// Require a full scheme prefix: a bare "http" prefix check would
	// misclassify hosts such as "httpbin.org" as already having a scheme.
	if strings.HasPrefix(host, "http://") || strings.HasPrefix(host, "https://") {
		return host
	}

	proto, _, _ := strings.Cut(r.Header.Get("X-Forwarded-Proto"), ",")
	if r.TLS != nil || strings.EqualFold(strings.TrimSpace(proto), "https") {
		return "https://" + host
	}
	return "http://" + host
}
|
|
|
|
// sendAuthError sends a formatted authentication error response
|
|
func sendAuthError(w http.ResponseWriter, r *http.Request, message string) {
|
|
baseURL := getBaseURL(r)
|
|
w.Header().Set("WWW-Authenticate", `Basic realm="ATCR Registry"`)
|
|
http.Error(w, fmt.Sprintf(`%s
|
|
|
|
To authenticate:
|
|
1. Install credential helper: %s/install
|
|
2. Or run: docker login %s
|
|
(use your ATProto handle + app-password)`, message, baseURL, r.Host), http.StatusUnauthorized)
|
|
}
|
|
|
|
// AuthErrorResponse is the JSON body sent for authentication failures that
// the credential helper knows how to handle, such as an expired OAuth session.
type AuthErrorResponse struct {
	Error    string `json:"error"`               // machine-readable code, e.g. "oauth_session_expired"
	Message  string `json:"message"`             // human-readable explanation
	LoginURL string `json:"login_url,omitempty"` // where to re-authenticate; omitted when empty
}
|
|
|
|
// sendOAuthSessionExpiredError sends a JSON error response when OAuth session is missing
|
|
// This allows the credential helper to detect this specific error and open the browser
|
|
func sendOAuthSessionExpiredError(w http.ResponseWriter, r *http.Request) {
|
|
baseURL := getBaseURL(r)
|
|
loginURL := baseURL + "/auth/oauth/login"
|
|
|
|
w.Header().Set("WWW-Authenticate", `Basic realm="ATCR Registry"`)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
|
|
resp := AuthErrorResponse{
|
|
Error: "oauth_session_expired",
|
|
Message: "OAuth session expired or invalidated. Please re-authenticate in your browser.",
|
|
LoginURL: loginURL,
|
|
}
|
|
render.JSON(w, r, resp)
|
|
}
|
|
|
|
// ServeHTTP handles the token request
|
|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
slog.Debug("Received token request", "method", r.Method, "path", r.URL.Path)
|
|
|
|
// Only accept GET requests (per Docker spec)
|
|
if r.Method != http.MethodGet {
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// Extract Basic auth credentials
|
|
username, password, ok := r.BasicAuth()
|
|
if !ok {
|
|
slog.Debug("No Basic auth credentials provided")
|
|
sendAuthError(w, r, "authentication required")
|
|
return
|
|
}
|
|
|
|
slog.Debug("Got Basic auth credentials", "username", username, "passwordLength", len(password))
|
|
|
|
// Parse query parameters
|
|
_ = r.URL.Query().Get("service") // service parameter - validated by issuer
|
|
scopeParam := r.URL.Query().Get("scope")
|
|
|
|
// Parse scopes
|
|
var scopes []string
|
|
if scopeParam != "" {
|
|
scopes = strings.Split(scopeParam, " ")
|
|
}
|
|
|
|
access, err := auth.ParseScope(scopes)
|
|
if err != nil {
|
|
http.Error(w, fmt.Sprintf("invalid scope: %v", err), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var did string
|
|
var handle string
|
|
var accessToken string
|
|
var authMethod string
|
|
|
|
// 1. Check if it's a device secret (starts with "atcr_device_")
|
|
if strings.HasPrefix(password, "atcr_device_") {
|
|
device, err := h.deviceStore.ValidateDeviceSecret(password)
|
|
if err != nil {
|
|
slog.Debug("Device secret validation failed", "error", err)
|
|
sendAuthError(w, r, "authentication failed")
|
|
return
|
|
}
|
|
|
|
// Validate OAuth session is usable (not just exists)
|
|
// Device secrets are permanent, but they require a working OAuth session to push
|
|
// By validating here, we prevent the flood of errors that occurs when a stale
|
|
// session is discovered during parallel layer uploads
|
|
if h.oauthSessionValidator != nil {
|
|
if err := h.oauthSessionValidator.ValidateSession(r.Context(), device.DID); err != nil {
|
|
slog.Debug("OAuth session validation failed", "did", device.DID, "error", err)
|
|
sendOAuthSessionExpiredError(w, r)
|
|
return
|
|
}
|
|
}
|
|
|
|
did = device.DID
|
|
handle = device.Handle
|
|
authMethod = AuthMethodOAuth
|
|
// Device is linked to OAuth session via DID
|
|
// OAuth refresher will provide access token when needed via middleware
|
|
} else {
|
|
// 2. Try app password (direct PDS authentication)
|
|
slog.Debug("Trying app password authentication", "username", username)
|
|
did, handle, accessToken, err = h.validator.CreateSessionAndGetToken(r.Context(), username, password)
|
|
if err != nil {
|
|
// Log at WARN level with specific error type
|
|
if errors.Is(err, auth.ErrIdentityResolution) {
|
|
slog.Warn("Identity resolution failed", "error", err, "username", username)
|
|
sendAuthError(w, r, "authentication failed: could not resolve handle")
|
|
} else if errors.Is(err, auth.ErrInvalidCredentials) {
|
|
slog.Warn("Invalid credentials", "username", username)
|
|
sendAuthError(w, r, "authentication failed: invalid credentials")
|
|
} else if errors.Is(err, auth.ErrPDSUnavailable) {
|
|
slog.Warn("PDS unavailable", "error", err, "username", username)
|
|
sendAuthError(w, r, "authentication failed: PDS unavailable")
|
|
} else {
|
|
slog.Warn("Authentication failed", "error", err, "username", username)
|
|
sendAuthError(w, r, "authentication failed")
|
|
}
|
|
return
|
|
}
|
|
|
|
authMethod = AuthMethodAppPassword
|
|
|
|
slog.Debug("App password validated successfully",
|
|
"did", did,
|
|
"handle", handle,
|
|
"accessTokenLength", len(accessToken))
|
|
|
|
// Cache the access token for later use (e.g., when pushing manifests)
|
|
// TTL of 2 hours (ATProto tokens typically last longer)
|
|
auth.GetGlobalTokenCache().Set(did, accessToken, 2*time.Hour)
|
|
slog.Debug("Cached access token", "did", did)
|
|
|
|
// Call post-auth callback for AppView business logic (profile management, etc.)
|
|
if h.postAuthCallback != nil {
|
|
// Resolve PDS endpoint for callback
|
|
_, _, pdsEndpoint, err := atproto.ResolveIdentity(r.Context(), username)
|
|
if err != nil {
|
|
// Log error but don't fail auth - profile management is not critical
|
|
slog.Warn("Failed to resolve PDS for callback", "error", err, "username", username)
|
|
} else {
|
|
if err := h.postAuthCallback(r.Context(), did, handle, pdsEndpoint, accessToken); err != nil {
|
|
// Log error but don't fail auth - business logic is non-critical
|
|
slog.Warn("Post-auth callback failed", "error", err, "did", did)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Validate that the user has permission for the requested access
|
|
// Use the actual handle from the validated credentials, not the Basic Auth username
|
|
if err := auth.ValidateAccess(did, handle, access); err != nil {
|
|
slog.Debug("Access validation failed", "error", err, "did", did)
|
|
http.Error(w, fmt.Sprintf("access denied: %v", err), http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Issue JWT token
|
|
tokenString, err := h.issuer.Issue(did, access, authMethod)
|
|
if err != nil {
|
|
slog.Error("Failed to issue token", "error", err, "did", did)
|
|
http.Error(w, fmt.Sprintf("failed to issue token: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
slog.Debug("Issued JWT token", "tokenLength", len(tokenString), "did", did, "authMethod", authMethod)
|
|
|
|
// Return token response
|
|
now := time.Now()
|
|
expiresIn := int(h.issuer.expiration.Seconds())
|
|
|
|
resp := TokenResponse{
|
|
Token: tokenString,
|
|
AccessToken: tokenString,
|
|
ExpiresIn: expiresIn,
|
|
IssuedAt: now.Format(time.RFC3339),
|
|
}
|
|
|
|
render.JSON(w, r, resp)
|
|
}
|