Files
at-container-registry/pkg/appview/db/session_store.go
2025-10-25 13:30:07 -05:00

223 lines
5.5 KiB
Go

package db
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"fmt"
"log/slog"
"net/http"
"time"
)
// Session is one UI session row from the ui_sessions table.
// Its fields are kept compatible with pkg/appview/session.Session.
type Session struct {
	// ID is the random session identifier; it is also what the session
	// cookie carries.
	ID string
	// DID identifies the account that owns the session.
	DID string
	// Handle is the account handle recorded when the session was created.
	Handle string
	// PDSEndpoint is the PDS endpoint recorded when the session was created.
	PDSEndpoint string
	// OAuthSessionID links to oauth_sessions.session_id; it is empty for
	// sessions created without an OAuth session.
	OAuthSessionID string
	// ExpiresAt is the instant after which the session is no longer valid.
	ExpiresAt time.Time
}
// SessionStoreInterface is the session storage contract shared by the UI
// stores: both db.SessionStore and session.Store implement it, so callers
// can depend on this interface instead of a concrete store.
type SessionStoreInterface interface {
	// Create starts a session with no OAuth link and returns its ID.
	Create(did, handle, pdsEndpoint string, duration time.Duration) (string, error)
	// CreateWithOAuth starts a session linked to oauthSessionID and returns its ID.
	CreateWithOAuth(did, handle, pdsEndpoint, oauthSessionID string, duration time.Duration) (string, error)
	// Get returns the session for id and whether a usable one was found.
	Get(id string) (*Session, bool)
	// Delete removes the session for id.
	Delete(id string)
	// Cleanup removes expired sessions.
	Cleanup()
}
// SessionStore keeps UI sessions in the ui_sessions table of a SQLite
// database.
type SessionStore struct {
	// db is the caller-supplied handle; the store does not close it.
	db *sql.DB
}

// NewSessionStore returns a SessionStore that reads and writes sessions
// through db.
func NewSessionStore(db *sql.DB) *SessionStore {
	store := new(SessionStore)
	store.db = db
	return store
}
// Create starts a session that is not linked to any OAuth session and
// returns the new session ID. It is CreateWithOAuth with an empty
// oauthSessionID.
func (s *SessionStore) Create(did, handle, pdsEndpoint string, duration time.Duration) (string, error) {
	const noOAuthSession = ""
	id, err := s.CreateWithOAuth(did, handle, pdsEndpoint, noOAuthSession, duration)
	return id, err
}
// CreateWithOAuth starts a session linked to oauthSessionID (which may be
// empty) and returns the new session ID.
//
// The ID is 32 bytes from crypto/rand encoded with URL-safe base64. The
// session expires duration from now; created_at is set by SQLite.
//
// expires_at is written in UTC. Cleanup compares it with SQLite's
// datetime('now'), which is always UTC. The driver presumably serializes
// time.Time as text, so a local-zone value would skew that comparison by the
// zone offset and purge sessions early or late.
func (s *SessionStore) CreateWithOAuth(did, handle, pdsEndpoint, oauthSessionID string, duration time.Duration) (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("failed to generate session ID: %w", err)
	}
	sessionID := base64.URLEncoding.EncodeToString(b)
	expiresAt := time.Now().UTC().Add(duration)
	_, err := s.db.Exec(`
INSERT INTO ui_sessions (id, did, handle, pds_endpoint, oauth_session_id, expires_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, datetime('now'))
`, sessionID, did, handle, pdsEndpoint, oauthSessionID, expiresAt)
	if err != nil {
		return "", fmt.Errorf("failed to create session: %w", err)
	}
	return sessionID, nil
}
// Get looks up the session with the given ID.
//
// It returns (nil, false) in these cases:
//   - no row matches;
//   - the row is already expired (expired rows are left for Cleanup);
//   - the query fails, in which case the error is logged, not returned.
//
// oauth_session_id is read through COALESCE. Scanning a NULL into a Go
// string fails, which would otherwise make such a session unreadable.
// Create stores "" rather than NULL, but rows written some other way might
// not. NOTE(review): confirm the column's nullability in the schema.
func (s *SessionStore) Get(id string) (*Session, bool) {
	var sess Session
	err := s.db.QueryRow(`
SELECT id, did, handle, pds_endpoint, COALESCE(oauth_session_id, ''), expires_at
FROM ui_sessions
WHERE id = ?
`, id).Scan(&sess.ID, &sess.DID, &sess.Handle, &sess.PDSEndpoint, &sess.OAuthSessionID, &sess.ExpiresAt)
	switch {
	case err == sql.ErrNoRows:
		return nil, false
	case err != nil:
		slog.Warn("Failed to query session", "error", err)
		return nil, false
	case time.Now().After(sess.ExpiresAt):
		return nil, false
	}
	return &sess, true
}
// Extend moves the session's expiry to duration from now.
//
// It returns an error if the update fails or no session has the given ID.
// The new expiry is stored in UTC so it compares correctly against SQLite's
// datetime('now') (always UTC) in the expiry cleanup query. The session ID is
// deliberately left out of the error text: it is a bearer credential, and
// errors tend to end up in logs.
func (s *SessionStore) Extend(id string, duration time.Duration) error {
	expiresAt := time.Now().UTC().Add(duration)
	result, err := s.db.Exec(`
UPDATE ui_sessions
SET expires_at = ?
WHERE id = ?
`, expiresAt, id)
	if err != nil {
		return fmt.Errorf("failed to extend session: %w", err)
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to extend session: %w", err)
	}
	if rows == 0 {
		return fmt.Errorf("session not found")
	}
	return nil
}
// Delete removes the session with the given ID. A missing session is not an
// error; database failures are logged and otherwise ignored.
func (s *SessionStore) Delete(id string) {
	const deleteByID = `
DELETE FROM ui_sessions WHERE id = ?
`
	if _, err := s.db.Exec(deleteByID, id); err != nil {
		slog.Warn("Failed to delete session", "error", err)
	}
}
// DeleteByDID removes every session owned by did, forcing that account to
// sign in again (used, for example, when an OAuth refresh fails). Failures
// are logged rather than returned; a non-zero removal count is logged at
// info level.
func (s *SessionStore) DeleteByDID(did string) {
	const deleteForDID = `
DELETE FROM ui_sessions WHERE did = ?
`
	res, err := s.db.Exec(deleteForDID, did)
	if err != nil {
		slog.Warn("Failed to delete sessions for DID", "did", did, "error", err)
		return
	}
	// RowsAffected only feeds the log line, so its error is not worth surfacing.
	if n, _ := res.RowsAffected(); n > 0 {
		slog.Info("Deleted UI sessions for DID due to OAuth failure", "count", n, "did", did)
	}
}
// Cleanup removes expired sessions, logging (rather than returning) any
// failure. It runs CleanupContext with a background context, so the
// expiry query and its reporting live in one place.
func (s *SessionStore) Cleanup() {
	if err := s.CleanupContext(context.Background()); err != nil {
		slog.Warn("Failed to cleanup sessions", "error", err)
	}
}
// CleanupContext deletes every session whose expires_at is earlier than
// SQLite's datetime('now'). It passes ctx to the database call and returns
// a wrapped error instead of logging it, which suits background workers.
// A non-zero removal count is logged at info level.
func (s *SessionStore) CleanupContext(ctx context.Context) error {
	const purgeExpired = `
DELETE FROM ui_sessions
WHERE expires_at < datetime('now')
`
	res, err := s.db.ExecContext(ctx, purgeExpired)
	if err != nil {
		return fmt.Errorf("failed to cleanup sessions: %w", err)
	}
	// RowsAffected only feeds the log line, so its error is ignored.
	if removed, _ := res.RowsAffected(); removed > 0 {
		slog.Info("Cleaned up expired UI sessions", "count", removed)
	}
	return nil
}
// Cookie helper functions (compatible with pkg/appview/session package)

// uiSessionCookieName is the name of the cookie that carries the UI session ID.
const uiSessionCookieName = "atcr_session"

// SetCookie writes the session cookie carrying sessionID. maxAge follows
// http.Cookie.MaxAge: positive is a lifetime in seconds, 0 omits Max-Age,
// and negative asks the browser to delete the cookie. The cookie is
// HttpOnly, Secure, SameSite=Lax and scoped to "/".
func SetCookie(w http.ResponseWriter, sessionID string, maxAge int) {
	http.SetCookie(w, &http.Cookie{
		Name:     uiSessionCookieName,
		Value:    sessionID,
		Path:     "/",
		MaxAge:   maxAge,
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteLaxMode,
	})
}

// ClearCookie tells the browser to delete the session cookie. It sends the
// same attributes as SetCookie, with an empty value and a negative MaxAge.
func ClearCookie(w http.ResponseWriter) {
	SetCookie(w, "", -1)
}

// GetSessionID returns the session ID from the request's session cookie.
// It reports false when the cookie is absent or has an empty value; an
// empty ID can never name a session.
func GetSessionID(r *http.Request) (string, bool) {
	cookie, err := r.Cookie(uiSessionCookieName)
	if err != nil || cookie.Value == "" {
		return "", false
	}
	return cookie.Value, true
}