Files
2025-10-28 20:39:57 -05:00

396 lines
11 KiB
Go

package middleware
import (
"database/sql"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"atcr.io/pkg/appview/db"
)
// TestGetUser_NoContext verifies that GetUser reports no user for a request
// that never passed through an auth middleware.
func TestGetUser_NoContext(t *testing.T) {
	r := httptest.NewRequest("GET", "/test", nil)
	if got := GetUser(r); got != nil {
		t.Error("Expected nil user when no context is set")
	}
}
// setupTestDB opens a fresh in-memory SQLite database through db.InitDB and
// registers a cleanup that closes it when the test (and its subtests) finish.
// Any initialization error fails the calling test immediately.
//
// NOTE(review): with a plain ":memory:" DSN, SQLite gives every new pooled
// connection its own empty database, so this helper is only suitable while
// the code under test uses one connection at a time.
func setupTestDB(t *testing.T) *sql.DB {
	// Attribute failures to the caller's line rather than this helper.
	t.Helper()

	database, err := db.InitDB(":memory:", true)
	require.NoError(t, err)

	t.Cleanup(func() {
		// The database is discarded with the test; a close error is irrelevant.
		_ = database.Close()
	})
	return database
}
// TestRequireAuth_ValidSession checks that RequireAuth lets a request carrying
// a valid atcr_session cookie reach the wrapped handler, and that the
// session's user (DID and handle) is available to the handler via GetUser.
func TestRequireAuth_ValidSession(t *testing.T) {
	database := setupTestDB(t)
	store := db.NewSessionStore(database)

	// Create a user first (required by foreign key)
	_, err := database.Exec(
		"INSERT INTO users (did, handle, pds_endpoint, last_seen) VALUES (?, ?, ?, ?)",
		"did:plc:test123", "alice.bsky.social", "https://pds.example.com", time.Now(),
	)
	require.NoError(t, err)

	sessionID, err := store.Create("did:plc:test123", "alice.bsky.social", "https://pds.example.com", 24*time.Hour)
	require.NoError(t, err)

	handlerCalled := false
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		user := GetUser(r)
		// assert.NotNil does not stop the test; only read fields once the
		// user is known to be non-nil, so a missing user fails cleanly
		// instead of panicking on a nil dereference.
		if assert.NotNil(t, user) {
			assert.Equal(t, "did:plc:test123", user.DID)
			assert.Equal(t, "alice.bsky.social", user.Handle)
		}
		w.WriteHeader(http.StatusOK)
	})

	wrappedHandler := RequireAuth(store, database)(handler)

	req := httptest.NewRequest("GET", "/test", nil)
	req.AddCookie(&http.Cookie{
		Name:  "atcr_session",
		Value: sessionID,
	})
	w := httptest.NewRecorder()
	wrappedHandler.ServeHTTP(w, req)

	assert.True(t, handlerCalled, "handler should have been called")
	assert.Equal(t, http.StatusOK, w.Code)
}
// TestRequireAuth_MissingSession checks that a request with no session cookie
// never reaches the wrapped handler and is instead redirected (302) to the
// OAuth login page, with the original path URL-encoded in return_to.
func TestRequireAuth_MissingSession(t *testing.T) {
	database := setupTestDB(t)
	sessions := db.NewSessionStore(database)

	reached := false
	protected := RequireAuth(sessions, database)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))

	// Deliberately attach no atcr_session cookie.
	req := httptest.NewRequest("GET", "/protected", nil)
	rec := httptest.NewRecorder()
	protected.ServeHTTP(rec, req)

	assert.False(t, reached, "handler should not have been called")
	assert.Equal(t, http.StatusFound, rec.Code)

	location := rec.Header().Get("Location")
	assert.Contains(t, location, "/auth/oauth/login")
	assert.Contains(t, location, "return_to=%2Fprotected")
}
// TestRequireAuth_InvalidSession checks that a session cookie naming an
// unknown session is treated like no session at all: the handler is skipped
// and the client is redirected (302) to the OAuth login page.
func TestRequireAuth_InvalidSession(t *testing.T) {
	database := setupTestDB(t)
	sessions := db.NewSessionStore(database)

	var reached bool
	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})
	protected := RequireAuth(sessions, database)(inner)

	// The cookie is present but refers to a session that was never created.
	req := httptest.NewRequest("GET", "/protected", nil)
	req.AddCookie(&http.Cookie{Name: "atcr_session", Value: "invalid-session-id"})

	rec := httptest.NewRecorder()
	protected.ServeHTTP(rec, req)

	assert.False(t, reached, "handler should not have been called")
	assert.Equal(t, http.StatusFound, rec.Code)
	assert.Contains(t, rec.Header().Get("Location"), "/auth/oauth/login")
}
// TestRequireAuth_WithQueryParams checks that when an unauthenticated request
// carries a query string, the login redirect's return_to value still includes
// that query (URL-encoded, so "foo=bar" appears as "foo%3Dbar").
func TestRequireAuth_WithQueryParams(t *testing.T) {
	database := setupTestDB(t)
	sessions := db.NewSessionStore(database)

	protected := RequireAuth(sessions, database)(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		}),
	)

	// No session cookie, but a non-empty query string.
	rec := httptest.NewRecorder()
	protected.ServeHTTP(rec, httptest.NewRequest("GET", "/protected?foo=bar&baz=qux", nil))

	assert.Equal(t, http.StatusFound, rec.Code)

	loc := rec.Header().Get("Location")
	for _, want := range []string{"/auth/oauth/login", "return_to=", "foo%3Dbar"} {
		assert.Contains(t, loc, want)
	}
}
// TestRequireAuth_DatabaseFallback checks the user exposed by RequireAuth when
// the user's row in the users table has an empty avatar: the handler must
// still see the session's DID and handle, and the user's Avatar must be empty.
func TestRequireAuth_DatabaseFallback(t *testing.T) {
	database := setupTestDB(t)
	store := db.NewSessionStore(database)

	// Create a user with an explicitly empty avatar (the row is required by
	// the foreign key on sessions).
	_, err := database.Exec(
		"INSERT INTO users (did, handle, pds_endpoint, last_seen, avatar) VALUES (?, ?, ?, ?, ?)",
		"did:plc:test123", "alice.bsky.social", "https://pds.example.com", time.Now(), "",
	)
	require.NoError(t, err)

	sessionID, err := store.Create("did:plc:test123", "alice.bsky.social", "https://pds.example.com", 24*time.Hour)
	require.NoError(t, err)

	handlerCalled := false
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		user := GetUser(r)
		// assert.NotNil does not stop the test; guard the field reads so a
		// missing user fails cleanly instead of panicking.
		if assert.NotNil(t, user) {
			assert.Equal(t, "did:plc:test123", user.DID)
			assert.Equal(t, "alice.bsky.social", user.Handle)
			assert.Empty(t, user.Avatar, "avatar should be empty when not set in DB")
		}
		w.WriteHeader(http.StatusOK)
	})

	wrappedHandler := RequireAuth(store, database)(handler)

	req := httptest.NewRequest("GET", "/test", nil)
	req.AddCookie(&http.Cookie{
		Name:  "atcr_session",
		Value: sessionID,
	})
	w := httptest.NewRecorder()
	wrappedHandler.ServeHTTP(w, req)

	assert.True(t, handlerCalled)
	assert.Equal(t, http.StatusOK, w.Code)
}
// TestOptionalAuth_ValidSession checks that OptionalAuth attaches the
// session's user to the request context when a valid atcr_session cookie is
// present, and that the wrapped handler runs normally.
func TestOptionalAuth_ValidSession(t *testing.T) {
	database := setupTestDB(t)
	store := db.NewSessionStore(database)

	// Create a user first (required by foreign key)
	_, err := database.Exec(
		"INSERT INTO users (did, handle, pds_endpoint, last_seen) VALUES (?, ?, ?, ?)",
		"did:plc:test123", "alice.bsky.social", "https://pds.example.com", time.Now(),
	)
	require.NoError(t, err)

	sessionID, err := store.Create("did:plc:test123", "alice.bsky.social", "https://pds.example.com", 24*time.Hour)
	require.NoError(t, err)

	handlerCalled := false
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		handlerCalled = true
		user := GetUser(r)
		// Guard the field read: assert.NotNil does not stop the test, and a
		// nil user would otherwise panic on user.DID.
		if assert.NotNil(t, user, "user should be set when session is valid") {
			assert.Equal(t, "did:plc:test123", user.DID)
		}
		w.WriteHeader(http.StatusOK)
	})

	wrappedHandler := OptionalAuth(store, database)(handler)

	req := httptest.NewRequest("GET", "/test", nil)
	req.AddCookie(&http.Cookie{
		Name:  "atcr_session",
		Value: sessionID,
	})
	w := httptest.NewRecorder()
	wrappedHandler.ServeHTTP(w, req)

	assert.True(t, handlerCalled)
	assert.Equal(t, http.StatusOK, w.Code)
}
// TestOptionalAuth_NoSession checks that OptionalAuth does not block a request
// that has no session cookie: the handler runs with no user in its context.
func TestOptionalAuth_NoSession(t *testing.T) {
	database := setupTestDB(t)
	sessions := db.NewSessionStore(database)

	reached := false
	wrapped := OptionalAuth(sessions, database)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		assert.Nil(t, GetUser(r), "user should be nil when no session")
		w.WriteHeader(http.StatusOK)
	}))

	// Deliberately attach no atcr_session cookie.
	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, httptest.NewRequest("GET", "/test", nil))

	assert.True(t, reached, "handler should still be called")
	assert.Equal(t, http.StatusOK, rec.Code)
}
// TestOptionalAuth_InvalidSession checks that OptionalAuth treats a cookie
// naming an unknown session as anonymous: the handler still runs, but with no
// user in its context.
func TestOptionalAuth_InvalidSession(t *testing.T) {
	database := setupTestDB(t)
	sessions := db.NewSessionStore(database)

	var reached bool
	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		assert.Nil(t, GetUser(r), "user should be nil when session is invalid")
		w.WriteHeader(http.StatusOK)
	})
	wrapped := OptionalAuth(sessions, database)(inner)

	// The cookie is present but refers to a session that was never created.
	req := httptest.NewRequest("GET", "/test", nil)
	req.AddCookie(&http.Cookie{Name: "atcr_session", Value: "invalid-session-id"})

	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)

	assert.True(t, reached, "handler should still be called")
	assert.Equal(t, http.StatusOK, rec.Code)
}
// TestMiddleware_ConcurrentAccess sends several authenticated requests through
// RequireAuth in parallel, each with its own user and session, and checks that
// most of them reach the handler with a user in the request context.
func TestMiddleware_ConcurrentAccess(t *testing.T) {
	const (
		numRequests = 10
		// Some failures are expected with in-memory SQLite under concurrency,
		// so the test passes if at least 70% of requests succeed.
		// NOTE(review): this tolerance can also hide real middleware races.
		minSuccesses = 7
	)

	// Use a shared-cache in-memory database: with a plain ":memory:" DSN,
	// SQLite gives each pooled connection its own separate database.
	database, err := db.InitDB("file::memory:?cache=shared", true)
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = database.Close()
	})
	store := db.NewSessionStore(database)

	// Create every user and session before any goroutine starts, so the
	// database is fully populated ahead of the concurrent phase.
	sessionIDs := make([]string, numRequests)
	for i := range sessionIDs {
		did := fmt.Sprintf("did:plc:user%d", i)
		handle := fmt.Sprintf("user%d.bsky.social", i)

		// The user row must exist first (required by foreign key).
		// Assign to the outer err rather than shadowing it with :=.
		_, err = database.Exec(
			"INSERT INTO users (did, handle, pds_endpoint, last_seen) VALUES (?, ?, ?, ?)",
			did, handle, "https://pds.example.com", time.Now(),
		)
		require.NoError(t, err)

		sessionIDs[i], err = store.Create(did, handle, "https://pds.example.com", 24*time.Hour)
		require.NoError(t, err)
	}

	// 200 when the middleware put a user in the context, 401 otherwise.
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if GetUser(r) != nil {
			w.WriteHeader(http.StatusOK)
			return
		}
		w.WriteHeader(http.StatusUnauthorized)
	})
	wrappedHandler := RequireAuth(store, database)(handler)

	// Each goroutine writes only its own slot of results, and wg.Wait orders
	// those writes before the reads below, so no mutex is needed.
	results := make([]int, numRequests)
	var wg sync.WaitGroup
	for i, sessionID := range sessionIDs {
		wg.Add(1)
		go func(index int, sessionID string) {
			defer wg.Done()
			req := httptest.NewRequest("GET", "/test", nil)
			req.AddCookie(&http.Cookie{
				Name:  "atcr_session",
				Value: sessionID,
			})
			w := httptest.NewRecorder()
			wrappedHandler.ServeHTTP(w, req)
			results[index] = w.Code
		}(i, sessionID)
	}
	wg.Wait()

	successCount := 0
	for _, code := range results {
		if code == http.StatusOK {
			successCount++
		}
	}
	assert.GreaterOrEqual(t, successCount, minSuccesses, "Most concurrent requests should succeed")
}