Files
at-container-registry/pkg/auth/session/handler.go
Evan Jarrett 31e235a2a1 cleanup oauth
2025-10-04 13:50:28 -05:00

171 lines
4.4 KiB
Go

package session
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"strings"
"time"
)
// SessionClaims represents the data stored in a session token.
//
// Manager.Create serializes it to JSON and base64url-encodes it to form the
// first segment of a token; Manager.Validate decodes it back after checking
// the signature. Field order determines the key order of the encoded JSON.
type SessionClaims struct {
DID string `json:"did"` // DID passed to Manager.Create
Handle string `json:"handle"` // handle passed to Manager.Create
IssuedAt time.Time `json:"issued_at"` // time the token was created
ExpiresAt time.Time `json:"expires_at"` // IssuedAt plus the manager's ttl; Validate rejects tokens once this instant has passed
}
// Manager handles session token creation and validation.
//
// Tokens are signed with an HMAC-SHA256 key (secret) and each new token is
// given a fixed lifetime (ttl).
type Manager struct {
	secret []byte        // HMAC key used to sign and verify tokens
	ttl    time.Duration // lifetime added to the issue time of each new token
}

// NewManager creates a session manager that signs tokens with secret and
// gives each new token a lifetime of ttl.
//
// The secret is copied, so the caller may reuse or zero its slice afterwards
// without affecting the returned Manager. NewManager does not validate the
// secret: an empty secret yields an empty HMAC key, which offers no security.
func NewManager(secret []byte, ttl time.Duration) *Manager {
	return &Manager{
		// Defensive copy: the signing key must not change underneath the
		// Manager if the caller mutates its buffer.
		secret: append([]byte(nil), secret...),
		ttl:    ttl,
	}
}

// NewManagerWithRandomSecret creates a session manager with a fresh random
// 32-byte secret read from crypto/rand. The secret is kept only in memory.
func NewManagerWithRandomSecret(ttl time.Duration) (*Manager, error) {
	secret := make([]byte, 32)
	if _, err := rand.Read(secret); err != nil {
		return nil, fmt.Errorf("failed to generate secret: %w", err)
	}
	return NewManager(secret, ttl), nil
}

// NewManagerWithPersistentSecret creates a session manager with a persistent
// secret. The secret is stored at secretPath and reused across restarts.
//
// If secretPath exists, its entire contents are used as the secret; an empty
// file is rejected because it would produce an empty HMAC key. If it does not
// exist, a random 32-byte secret is generated and written with permissions
// 0600 (subject to umask). Any other read error is returned.
func NewManagerWithPersistentSecret(secretPath string, ttl time.Duration) (*Manager, error) {
	data, err := os.ReadFile(secretPath)
	switch {
	case err == nil:
		// A zero-length file (e.g. left behind by an interrupted write)
		// would silently give an empty signing key, letting anyone forge
		// tokens. Fail loudly instead of accepting it.
		if len(data) == 0 {
			return nil, fmt.Errorf("session secret file %s is empty", secretPath)
		}
		fmt.Printf("Loaded existing session secret from %s\n", secretPath)
		return NewManager(data, ttl), nil
	case !os.IsNotExist(err):
		return nil, fmt.Errorf("failed to read secret file: %w", err)
	}

	// No secret on disk yet: generate one and save it for future restarts.
	secret := make([]byte, 32)
	if _, err := rand.Read(secret); err != nil {
		return nil, fmt.Errorf("failed to generate secret: %w", err)
	}
	if err := os.WriteFile(secretPath, secret, 0600); err != nil {
		return nil, fmt.Errorf("failed to save secret: %w", err)
	}
	fmt.Printf("Generated and saved new session secret to %s\n", secretPath)
	return NewManager(secret, ttl), nil
}
// Create issues a signed session token for did and handle.
//
// The token has the form "<claims>.<signature>": <claims> is the unpadded
// base64url encoding of the JSON-serialized SessionClaims, and <signature>
// is the unpadded base64url encoding of the HMAC-SHA256 (see sign) computed
// over the encoded claims string. The token expires m.ttl after creation.
func (m *Manager) Create(did, handle string) (string, error) {
	issued := time.Now()
	payload, err := json.Marshal(SessionClaims{
		DID:       did,
		Handle:    handle,
		IssuedAt:  issued,
		ExpiresAt: issued.Add(m.ttl),
	})
	if err != nil {
		return "", fmt.Errorf("failed to marshal claims: %w", err)
	}

	enc := base64.RawURLEncoding
	encodedClaims := enc.EncodeToString(payload)
	// The signature covers the encoded claims text, not the raw JSON.
	signature := enc.EncodeToString(m.sign(encodedClaims))
	return encodedClaims + "." + signature, nil
}
// Validate checks a token produced by Create and returns its claims.
//
// A token is rejected when:
//   - it does not contain exactly one "." separator,
//   - its signature segment is not valid unpadded base64url,
//   - the signature does not match (compared in constant time via hmac.Equal),
//   - its claims segment is not valid unpadded base64url JSON, or
//   - the current time is after its ExpiresAt.
func (m *Manager) Validate(token string) (*SessionClaims, error) {
	claimsB64, sigB64, found := strings.Cut(token, ".")
	// Exactly one separator is required, so a second "." is also invalid.
	if !found || strings.Contains(sigB64, ".") {
		return nil, fmt.Errorf("invalid token format")
	}

	got, err := base64.RawURLEncoding.DecodeString(sigB64)
	if err != nil {
		return nil, fmt.Errorf("invalid signature encoding: %w", err)
	}
	if want := m.sign(claimsB64); !hmac.Equal(want, got) {
		return nil, fmt.Errorf("invalid signature")
	}

	raw, err := base64.RawURLEncoding.DecodeString(claimsB64)
	if err != nil {
		return nil, fmt.Errorf("invalid claims encoding: %w", err)
	}
	claims := new(SessionClaims)
	if err := json.Unmarshal(raw, claims); err != nil {
		return nil, fmt.Errorf("invalid claims format: %w", err)
	}

	if claims.ExpiresAt.Before(time.Now()) {
		return nil, fmt.Errorf("token expired")
	}
	return claims, nil
}
// sign returns the raw (unencoded) HMAC-SHA256 of data keyed with m.secret.
func (m *Manager) sign(data string) []byte {
	mac := hmac.New(sha256.New, m.secret)
	// hash.Hash's Write never returns an error, so the result is discarded.
	_, _ = mac.Write([]byte(data))
	return mac.Sum(nil)
}
// GetDID extracts the DID from a token's claims without validating it.
//
// Neither the signature nor the expiry is checked, so the returned DID is
// untrusted input. It is meant only for logging and debugging; use Validate
// for anything that makes an authorization decision.
func (m *Manager) GetDID(token string) (string, error) {
	head, tail, ok := strings.Cut(token, ".")
	// Same shape rule as Validate: exactly one "." separator.
	if !ok || strings.Contains(tail, ".") {
		return "", fmt.Errorf("invalid token format")
	}

	raw, err := base64.RawURLEncoding.DecodeString(head)
	if err != nil {
		return "", fmt.Errorf("invalid claims encoding: %w", err)
	}

	var claims SessionClaims
	if err = json.Unmarshal(raw, &claims); err != nil {
		return "", fmt.Errorf("invalid claims format: %w", err)
	}
	return claims.DID, nil
}