mirror of
https://tangled.org/evan.jarrett.net/at-container-registry
synced 2026-04-27 03:35:10 +00:00
536 lines
14 KiB
Go
536 lines
14 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// setupSessionTestDB creates an in-memory SQLite database for testing
|
|
func setupSessionTestDB(t *testing.T) *SessionStore {
|
|
t.Helper()
|
|
// Use a named in-memory DB unique to this test to ensure isolation between tests
|
|
safeName := strings.ReplaceAll(t.Name(), "/", "_")
|
|
db, err := InitDB(fmt.Sprintf("file:%s?mode=memory&cache=shared", safeName), LibsqlConfig{})
|
|
if err != nil {
|
|
t.Fatalf("Failed to initialize test database: %v", err)
|
|
}
|
|
// Limit to single connection to avoid race conditions in tests
|
|
db.SetMaxOpenConns(1)
|
|
t.Cleanup(func() {
|
|
db.Close()
|
|
})
|
|
return NewSessionStore(db)
|
|
}
|
|
|
|
// createSessionTestUser creates a test user in the database
|
|
func createSessionTestUser(t *testing.T, store *SessionStore, did, handle string) {
|
|
t.Helper()
|
|
_, err := store.db.Exec(`
|
|
INSERT OR IGNORE INTO users (did, handle, pds_endpoint, last_seen)
|
|
VALUES (?, ?, ?, datetime('now'))
|
|
`, did, handle, "https://pds.example.com")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test user: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSession_Struct(t *testing.T) {
|
|
sess := &Session{
|
|
ID: "test-session",
|
|
DID: "did:plc:test",
|
|
Handle: "alice.bsky.social",
|
|
PDSEndpoint: "https://bsky.social",
|
|
OAuthSessionID: "oauth-123",
|
|
ExpiresAt: time.Now().Add(1 * time.Hour),
|
|
}
|
|
|
|
if sess.DID != "did:plc:test" {
|
|
t.Errorf("Expected DID, got %q", sess.DID)
|
|
}
|
|
}
|
|
|
|
// TestSessionStore_Create tests session creation without OAuth
|
|
func TestSessionStore_Create(t *testing.T) {
|
|
store := setupSessionTestDB(t)
|
|
createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
|
|
if sessionID == "" {
|
|
t.Error("Create() returned empty session ID")
|
|
}
|
|
|
|
// Verify session can be retrieved
|
|
sess, found := store.Get(sessionID)
|
|
if !found {
|
|
t.Error("Created session not found")
|
|
}
|
|
if sess == nil {
|
|
t.Fatal("Session is nil")
|
|
}
|
|
if sess.DID != "did:plc:alice123" {
|
|
t.Errorf("DID = %v, want did:plc:alice123", sess.DID)
|
|
}
|
|
if sess.Handle != "alice.bsky.social" {
|
|
t.Errorf("Handle = %v, want alice.bsky.social", sess.Handle)
|
|
}
|
|
if sess.OAuthSessionID != "" {
|
|
t.Errorf("OAuthSessionID should be empty, got %v", sess.OAuthSessionID)
|
|
}
|
|
}
|
|
|
|
// TestSessionStore_CreateWithOAuth tests session creation with OAuth
|
|
func TestSessionStore_CreateWithOAuth(t *testing.T) {
|
|
store := setupSessionTestDB(t)
|
|
createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
oauthSessionID := "oauth-123"
|
|
sessionID, err := store.CreateWithOAuth("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", oauthSessionID, 1*time.Hour)
|
|
if err != nil {
|
|
t.Fatalf("CreateWithOAuth() error = %v", err)
|
|
}
|
|
|
|
if sessionID == "" {
|
|
t.Error("CreateWithOAuth() returned empty session ID")
|
|
}
|
|
|
|
// Verify session has OAuth session ID
|
|
sess, found := store.Get(sessionID)
|
|
if !found {
|
|
t.Error("Created session not found")
|
|
}
|
|
if sess.OAuthSessionID != oauthSessionID {
|
|
t.Errorf("OAuthSessionID = %v, want %v", sess.OAuthSessionID, oauthSessionID)
|
|
}
|
|
}
|
|
|
|
// TestSessionStore_Get tests retrieving sessions
|
|
func TestSessionStore_Get(t *testing.T) {
|
|
store := setupSessionTestDB(t)
|
|
createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
// Create a valid session
|
|
validID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
|
|
// Create a session and manually expire it
|
|
expiredID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
|
|
// Manually update expiration to the past
|
|
_, err = store.db.Exec(`
|
|
UPDATE ui_sessions
|
|
SET expires_at = datetime('now', '-1 hour')
|
|
WHERE id = ?
|
|
`, expiredID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to update expiration: %v", err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
sessionID string
|
|
wantFound bool
|
|
}{
|
|
{
|
|
name: "valid session",
|
|
sessionID: validID,
|
|
wantFound: true,
|
|
},
|
|
{
|
|
name: "expired session",
|
|
sessionID: expiredID,
|
|
wantFound: false,
|
|
},
|
|
{
|
|
name: "non-existent session",
|
|
sessionID: "non-existent-id",
|
|
wantFound: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
sess, found := store.Get(tt.sessionID)
|
|
if found != tt.wantFound {
|
|
t.Errorf("Get() found = %v, want %v", found, tt.wantFound)
|
|
}
|
|
if tt.wantFound && sess == nil {
|
|
t.Error("Expected session, got nil")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestSessionStore_Extend tests extending session expiration
|
|
func TestSessionStore_Extend(t *testing.T) {
|
|
store := setupSessionTestDB(t)
|
|
createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
|
|
// Get initial expiration
|
|
sess1, _ := store.Get(sessionID)
|
|
initialExpiry := sess1.ExpiresAt
|
|
|
|
// Wait a bit to ensure time difference
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
// Extend session
|
|
err = store.Extend(sessionID, 2*time.Hour)
|
|
if err != nil {
|
|
t.Errorf("Extend() error = %v", err)
|
|
}
|
|
|
|
// Verify expiration was updated
|
|
sess2, found := store.Get(sessionID)
|
|
if !found {
|
|
t.Fatal("Session not found after extend")
|
|
}
|
|
if !sess2.ExpiresAt.After(initialExpiry) {
|
|
t.Error("ExpiresAt should be later after extend")
|
|
}
|
|
|
|
// Test extending non-existent session
|
|
err = store.Extend("non-existent-id", 1*time.Hour)
|
|
if err == nil {
|
|
t.Error("Expected error when extending non-existent session")
|
|
}
|
|
if err != nil && !strings.Contains(err.Error(), "not found") {
|
|
t.Errorf("Expected 'not found' error, got %v", err)
|
|
}
|
|
}
|
|
|
|
// TestSessionStore_Delete tests deleting a session
|
|
func TestSessionStore_Delete(t *testing.T) {
|
|
store := setupSessionTestDB(t)
|
|
createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
|
|
// Verify session exists
|
|
_, found := store.Get(sessionID)
|
|
if !found {
|
|
t.Fatal("Session should exist before delete")
|
|
}
|
|
|
|
// Delete session
|
|
store.Delete(sessionID)
|
|
|
|
// Verify session is gone
|
|
_, found = store.Get(sessionID)
|
|
if found {
|
|
t.Error("Session should not exist after delete")
|
|
}
|
|
|
|
// Deleting non-existent session should not error
|
|
store.Delete("non-existent-id")
|
|
}
|
|
|
|
// TestSessionStore_DeleteByDID tests deleting all sessions for a DID
|
|
func TestSessionStore_DeleteByDID(t *testing.T) {
|
|
store := setupSessionTestDB(t)
|
|
did := "did:plc:alice123"
|
|
createSessionTestUser(t, store, did, "alice.bsky.social")
|
|
createSessionTestUser(t, store, "did:plc:bob123", "bob.bsky.social")
|
|
|
|
// Create multiple sessions for alice
|
|
sessionIDs := make([]string, 3)
|
|
for i := range 3 {
|
|
id, err := store.Create(did, "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
sessionIDs[i] = id
|
|
}
|
|
|
|
// Create a session for bob
|
|
bobSessionID, err := store.Create("did:plc:bob123", "bob.bsky.social", "https://pds.example.com", 1*time.Hour)
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
|
|
// Delete all sessions for alice
|
|
store.DeleteByDID(did)
|
|
|
|
// Verify alice's sessions are gone
|
|
for _, id := range sessionIDs {
|
|
_, found := store.Get(id)
|
|
if found {
|
|
t.Errorf("Session %v should have been deleted", id)
|
|
}
|
|
}
|
|
|
|
// Verify bob's session still exists
|
|
_, found := store.Get(bobSessionID)
|
|
if !found {
|
|
t.Error("Bob's session should still exist")
|
|
}
|
|
|
|
// Deleting sessions for non-existent DID should not error
|
|
store.DeleteByDID("did:plc:nonexistent")
|
|
}
|
|
|
|
// TestSessionStore_Cleanup tests removing expired sessions
|
|
func TestSessionStore_Cleanup(t *testing.T) {
|
|
store := setupSessionTestDB(t)
|
|
createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
// Create valid session by inserting directly with SQLite datetime format
|
|
validID := "valid-session-id"
|
|
_, err := store.db.Exec(`
|
|
INSERT INTO ui_sessions (id, did, handle, pds_endpoint, oauth_session_id, expires_at, created_at)
|
|
VALUES (?, ?, ?, ?, ?, datetime('now', '+1 hour'), datetime('now'))
|
|
`, validID, "did:plc:alice123", "alice.bsky.social", "https://pds.example.com", "")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid session: %v", err)
|
|
}
|
|
|
|
// Create expired session
|
|
expiredID := "expired-session-id"
|
|
_, err = store.db.Exec(`
|
|
INSERT INTO ui_sessions (id, did, handle, pds_endpoint, oauth_session_id, expires_at, created_at)
|
|
VALUES (?, ?, ?, ?, ?, datetime('now', '-1 hour'), datetime('now'))
|
|
`, expiredID, "did:plc:alice123", "alice.bsky.social", "https://pds.example.com", "")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create expired session: %v", err)
|
|
}
|
|
|
|
// Verify we have 2 sessions before cleanup
|
|
var countBefore int
|
|
err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions").Scan(&countBefore)
|
|
if err != nil {
|
|
t.Fatalf("Query error: %v", err)
|
|
}
|
|
if countBefore != 2 {
|
|
t.Fatalf("Expected 2 sessions before cleanup, got %d", countBefore)
|
|
}
|
|
|
|
// Run cleanup
|
|
store.Cleanup()
|
|
|
|
// Verify valid session still exists in database
|
|
var countValid int
|
|
err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", validID).Scan(&countValid)
|
|
if err != nil {
|
|
t.Fatalf("Query error: %v", err)
|
|
}
|
|
if countValid != 1 {
|
|
t.Errorf("Valid session should still exist in database, count = %d", countValid)
|
|
}
|
|
|
|
// Verify expired session was cleaned up
|
|
var countExpired int
|
|
err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", expiredID).Scan(&countExpired)
|
|
if err != nil {
|
|
t.Fatalf("Query error: %v", err)
|
|
}
|
|
if countExpired != 0 {
|
|
t.Error("Expired session should have been deleted from database")
|
|
}
|
|
|
|
// Verify we can still get the valid session
|
|
_, found := store.Get(validID)
|
|
if !found {
|
|
t.Error("Valid session should be retrievable after cleanup")
|
|
}
|
|
}
|
|
|
|
// TestSessionStore_CleanupContext tests context-aware cleanup
|
|
func TestSessionStore_CleanupContext(t *testing.T) {
|
|
store := setupSessionTestDB(t)
|
|
createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
// Create a session and manually expire it
|
|
expiredID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
|
|
// Manually update expiration to the past
|
|
_, err = store.db.Exec(`
|
|
UPDATE ui_sessions
|
|
SET expires_at = datetime('now', '-1 hour')
|
|
WHERE id = ?
|
|
`, expiredID)
|
|
if err != nil {
|
|
t.Fatalf("Failed to update expiration: %v", err)
|
|
}
|
|
|
|
// Run context-aware cleanup
|
|
ctx := context.Background()
|
|
err = store.CleanupContext(ctx)
|
|
if err != nil {
|
|
t.Errorf("CleanupContext() error = %v", err)
|
|
}
|
|
|
|
// Verify expired session was cleaned up
|
|
var count int
|
|
err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", expiredID).Scan(&count)
|
|
if err != nil {
|
|
t.Fatalf("Query error: %v", err)
|
|
}
|
|
if count != 0 {
|
|
t.Error("Expired session should have been deleted from database")
|
|
}
|
|
}
|
|
|
|
// TestSetCookie tests setting session cookie
|
|
func TestSetCookie(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
sessionID := "test-session-id"
|
|
maxAge := 3600
|
|
|
|
SetCookie(w, sessionID, maxAge)
|
|
|
|
cookies := w.Result().Cookies()
|
|
if len(cookies) != 1 {
|
|
t.Fatalf("Expected 1 cookie, got %d", len(cookies))
|
|
}
|
|
|
|
cookie := cookies[0]
|
|
if cookie.Name != "atcr_session" {
|
|
t.Errorf("Name = %v, want atcr_session", cookie.Name)
|
|
}
|
|
if cookie.Value != sessionID {
|
|
t.Errorf("Value = %v, want %v", cookie.Value, sessionID)
|
|
}
|
|
if cookie.MaxAge != maxAge {
|
|
t.Errorf("MaxAge = %v, want %v", cookie.MaxAge, maxAge)
|
|
}
|
|
if !cookie.HttpOnly {
|
|
t.Error("HttpOnly should be true")
|
|
}
|
|
if !cookie.Secure {
|
|
t.Error("Secure should be true")
|
|
}
|
|
if cookie.SameSite != http.SameSiteLaxMode {
|
|
t.Errorf("SameSite = %v, want Lax", cookie.SameSite)
|
|
}
|
|
if cookie.Path != "/" {
|
|
t.Errorf("Path = %v, want /", cookie.Path)
|
|
}
|
|
}
|
|
|
|
// TestClearCookie tests clearing session cookie
|
|
func TestClearCookie(t *testing.T) {
|
|
w := httptest.NewRecorder()
|
|
|
|
ClearCookie(w)
|
|
|
|
cookies := w.Result().Cookies()
|
|
if len(cookies) != 1 {
|
|
t.Fatalf("Expected 1 cookie, got %d", len(cookies))
|
|
}
|
|
|
|
cookie := cookies[0]
|
|
if cookie.Name != "atcr_session" {
|
|
t.Errorf("Name = %v, want atcr_session", cookie.Name)
|
|
}
|
|
if cookie.Value != "" {
|
|
t.Errorf("Value should be empty, got %v", cookie.Value)
|
|
}
|
|
if cookie.MaxAge != -1 {
|
|
t.Errorf("MaxAge = %v, want -1", cookie.MaxAge)
|
|
}
|
|
if !cookie.HttpOnly {
|
|
t.Error("HttpOnly should be true")
|
|
}
|
|
if !cookie.Secure {
|
|
t.Error("Secure should be true")
|
|
}
|
|
}
|
|
|
|
// TestGetSessionID tests retrieving session ID from cookie
|
|
func TestGetSessionID(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cookie *http.Cookie
|
|
wantID string
|
|
wantFound bool
|
|
}{
|
|
{
|
|
name: "valid cookie",
|
|
cookie: &http.Cookie{
|
|
Name: "atcr_session",
|
|
Value: "test-session-id",
|
|
},
|
|
wantID: "test-session-id",
|
|
wantFound: true,
|
|
},
|
|
{
|
|
name: "no cookie",
|
|
cookie: nil,
|
|
wantID: "",
|
|
wantFound: false,
|
|
},
|
|
{
|
|
name: "wrong cookie name",
|
|
cookie: &http.Cookie{
|
|
Name: "other_cookie",
|
|
Value: "test-value",
|
|
},
|
|
wantID: "",
|
|
wantFound: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
if tt.cookie != nil {
|
|
req.AddCookie(tt.cookie)
|
|
}
|
|
|
|
id, found := GetSessionID(req)
|
|
if found != tt.wantFound {
|
|
t.Errorf("GetSessionID() found = %v, want %v", found, tt.wantFound)
|
|
}
|
|
if id != tt.wantID {
|
|
t.Errorf("GetSessionID() id = %v, want %v", id, tt.wantID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestSessionStore_SessionIDUniqueness tests that generated session IDs are unique
|
|
func TestSessionStore_SessionIDUniqueness(t *testing.T) {
|
|
store := setupSessionTestDB(t)
|
|
createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
// Generate multiple session IDs
|
|
ids := make(map[string]bool)
|
|
for range 100 {
|
|
id, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
|
|
if err != nil {
|
|
t.Fatalf("Create() error = %v", err)
|
|
}
|
|
if ids[id] {
|
|
t.Errorf("Duplicate session ID generated: %v", id)
|
|
}
|
|
ids[id] = true
|
|
}
|
|
|
|
if len(ids) != 100 {
|
|
t.Errorf("Expected 100 unique IDs, got %d", len(ids))
|
|
}
|
|
}
|