mirror of
https://tangled.org/evan.jarrett.net/at-container-registry
synced 2026-04-20 16:40:29 +00:00
223 lines
5.5 KiB
Go
223 lines
5.5 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"time"
|
|
)
|
|
|
|
// Session is a single UI login session as stored in the ui_sessions table.
//
// Compatible with pkg/appview/session.Session.
type Session struct {
	// ID is the random identifier generated by SessionStore.CreateWithOAuth.
	ID string
	// DID identifies the user the session belongs to (DeleteByDID matches on it).
	DID string
	// Handle is the user's handle as recorded when the session was created.
	Handle string
	// PDSEndpoint is the PDS endpoint recorded when the session was created.
	PDSEndpoint string
	// OAuthSessionID links to oauth_sessions.session_id; SessionStore.Create
	// stores it as "" for sessions without an OAuth link.
	OAuthSessionID string
	// ExpiresAt is the instant after which SessionStore.Get reports the
	// session as absent.
	ExpiresAt time.Time
}
|
|
|
|
// SessionStoreInterface defines the session storage interface
|
|
// Both db.SessionStore and session.Store implement this
|
|
type SessionStoreInterface interface {
|
|
Create(did, handle, pdsEndpoint string, duration time.Duration) (string, error)
|
|
CreateWithOAuth(did, handle, pdsEndpoint, oauthSessionID string, duration time.Duration) (string, error)
|
|
Get(id string) (*Session, bool)
|
|
Delete(id string)
|
|
Cleanup()
|
|
}
|
|
|
|
// SessionStore is the SQLite-backed session store. Every session lives as a
// row in the ui_sessions table reached through db; the struct itself keeps no
// other state.
type SessionStore struct {
	db *sql.DB // database handle supplied to NewSessionStore
}
|
|
|
|
// NewSessionStore creates a new SQLite-backed session store
|
|
func NewSessionStore(db *sql.DB) *SessionStore {
|
|
return &SessionStore{db: db}
|
|
}
|
|
|
|
// Create creates a new session and returns the session ID
|
|
func (s *SessionStore) Create(did, handle, pdsEndpoint string, duration time.Duration) (string, error) {
|
|
return s.CreateWithOAuth(did, handle, pdsEndpoint, "", duration)
|
|
}
|
|
|
|
// CreateWithOAuth creates a new session with OAuth sessionID and returns the session ID
|
|
func (s *SessionStore) CreateWithOAuth(did, handle, pdsEndpoint, oauthSessionID string, duration time.Duration) (string, error) {
|
|
// Generate random session ID
|
|
b := make([]byte, 32)
|
|
if _, err := rand.Read(b); err != nil {
|
|
return "", fmt.Errorf("failed to generate session ID: %w", err)
|
|
}
|
|
|
|
sessionID := base64.URLEncoding.EncodeToString(b)
|
|
expiresAt := time.Now().Add(duration)
|
|
|
|
_, err := s.db.Exec(`
|
|
INSERT INTO ui_sessions (id, did, handle, pds_endpoint, oauth_session_id, expires_at, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, datetime('now'))
|
|
`, sessionID, did, handle, pdsEndpoint, oauthSessionID, expiresAt)
|
|
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create session: %w", err)
|
|
}
|
|
|
|
return sessionID, nil
|
|
}
|
|
|
|
// Get retrieves a session by ID
|
|
func (s *SessionStore) Get(id string) (*Session, bool) {
|
|
var sess Session
|
|
|
|
err := s.db.QueryRow(`
|
|
SELECT id, did, handle, pds_endpoint, oauth_session_id, expires_at
|
|
FROM ui_sessions
|
|
WHERE id = ?
|
|
`, id).Scan(&sess.ID, &sess.DID, &sess.Handle, &sess.PDSEndpoint, &sess.OAuthSessionID, &sess.ExpiresAt)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return nil, false
|
|
}
|
|
if err != nil {
|
|
slog.Warn("Failed to query session", "error", err)
|
|
return nil, false
|
|
}
|
|
|
|
// Check if expired
|
|
if time.Now().After(sess.ExpiresAt) {
|
|
return nil, false
|
|
}
|
|
|
|
return &sess, true
|
|
}
|
|
|
|
// Extend extends a session's expiration time
|
|
func (s *SessionStore) Extend(id string, duration time.Duration) error {
|
|
expiresAt := time.Now().Add(duration)
|
|
|
|
result, err := s.db.Exec(`
|
|
UPDATE ui_sessions
|
|
SET expires_at = ?
|
|
WHERE id = ?
|
|
`, expiresAt, id)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to extend session: %w", err)
|
|
}
|
|
|
|
rows, _ := result.RowsAffected()
|
|
if rows == 0 {
|
|
return fmt.Errorf("session not found: %s", id)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Delete removes a session
|
|
func (s *SessionStore) Delete(id string) {
|
|
_, err := s.db.Exec(`
|
|
DELETE FROM ui_sessions WHERE id = ?
|
|
`, id)
|
|
|
|
if err != nil {
|
|
slog.Warn("Failed to delete session", "error", err)
|
|
}
|
|
}
|
|
|
|
// DeleteByDID removes all sessions for a given DID
|
|
// This is useful when OAuth refresh fails and we need to force re-authentication
|
|
func (s *SessionStore) DeleteByDID(did string) {
|
|
result, err := s.db.Exec(`
|
|
DELETE FROM ui_sessions WHERE did = ?
|
|
`, did)
|
|
|
|
if err != nil {
|
|
slog.Warn("Failed to delete sessions for DID", "did", did, "error", err)
|
|
return
|
|
}
|
|
|
|
deleted, _ := result.RowsAffected()
|
|
if deleted > 0 {
|
|
slog.Info("Deleted UI sessions for DID due to OAuth failure", "count", deleted, "did", did)
|
|
}
|
|
}
|
|
|
|
// Cleanup removes expired sessions
|
|
func (s *SessionStore) Cleanup() {
|
|
result, err := s.db.Exec(`
|
|
DELETE FROM ui_sessions
|
|
WHERE expires_at < datetime('now')
|
|
`)
|
|
|
|
if err != nil {
|
|
slog.Warn("Failed to cleanup sessions", "error", err)
|
|
return
|
|
}
|
|
|
|
deleted, _ := result.RowsAffected()
|
|
if deleted > 0 {
|
|
slog.Info("Cleaned up expired UI sessions", "count", deleted)
|
|
}
|
|
}
|
|
|
|
// CleanupContext is a context-aware version of Cleanup for background workers
|
|
func (s *SessionStore) CleanupContext(ctx context.Context) error {
|
|
result, err := s.db.ExecContext(ctx, `
|
|
DELETE FROM ui_sessions
|
|
WHERE expires_at < datetime('now')
|
|
`)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to cleanup sessions: %w", err)
|
|
}
|
|
|
|
deleted, _ := result.RowsAffected()
|
|
if deleted > 0 {
|
|
slog.Info("Cleaned up expired UI sessions", "count", deleted)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Cookie helper functions (compatible with pkg/appview/session package)
|
|
|
|
// SetCookie writes the atcr_session cookie carrying sessionID.
//
// maxAge is passed through as the cookie's Max-Age in seconds, with net/http
// semantics: 0 omits the attribute and a negative value deletes the cookie.
// The cookie is scoped to "/" and marked HttpOnly, Secure and SameSite=Lax.
func SetCookie(w http.ResponseWriter, sessionID string, maxAge int) {
	c := &http.Cookie{
		Name:  "atcr_session",
		Value: sessionID,
		Path:  "/",
	}
	c.MaxAge = maxAge
	c.HttpOnly, c.Secure = true, true
	c.SameSite = http.SameSiteLaxMode
	http.SetCookie(w, c)
}
|
|
|
|
// ClearCookie tells the client to drop the session cookie. It re-sends
// atcr_session with an empty value and MaxAge -1, which net/http emits as
// "Max-Age=0". Path, HttpOnly, Secure and SameSite mirror the values used
// when the cookie is set, so the same cookie is targeted.
func ClearCookie(w http.ResponseWriter) {
	expired := http.Cookie{
		Name:     "atcr_session",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, &expired)
}
|
|
|
|
// GetSessionID returns the value of the request's atcr_session cookie. The
// bool is false only when the cookie is absent; a present-but-empty cookie
// yields ("", true).
func GetSessionID(r *http.Request) (string, bool) {
	if c, err := r.Cookie("atcr_session"); err == nil {
		return c.Value, true
	}
	return "", false
}
|