396 lines
11 KiB
Go
396 lines
11 KiB
Go
package middleware
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"atcr.io/pkg/appview/db"
|
|
)
|
|
|
|
func TestGetUser_NoContext(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
user := GetUser(req)
|
|
if user != nil {
|
|
t.Error("Expected nil user when no context is set")
|
|
}
|
|
}
|
|
|
|
// setupTestDB creates an in-memory SQLite database for testing
|
|
func setupTestDB(t *testing.T) *sql.DB {
|
|
database, err := db.InitDB(":memory:", true)
|
|
require.NoError(t, err)
|
|
|
|
t.Cleanup(func() {
|
|
database.Close()
|
|
})
|
|
|
|
return database
|
|
}
|
|
|
|
// TestRequireAuth_ValidSession tests RequireAuth with a valid session
|
|
func TestRequireAuth_ValidSession(t *testing.T) {
|
|
database := setupTestDB(t)
|
|
store := db.NewSessionStore(database)
|
|
|
|
// Create a user first (required by foreign key)
|
|
_, err := database.Exec(
|
|
"INSERT INTO users (did, handle, pds_endpoint, last_seen) VALUES (?, ?, ?, ?)",
|
|
"did:plc:test123", "alice.bsky.social", "https://pds.example.com", time.Now(),
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
// Create a session
|
|
sessionID, err := store.Create("did:plc:test123", "alice.bsky.social", "https://pds.example.com", 24*time.Hour)
|
|
require.NoError(t, err)
|
|
|
|
// Create a test handler that checks user context
|
|
handlerCalled := false
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
user := GetUser(r)
|
|
assert.NotNil(t, user)
|
|
assert.Equal(t, "did:plc:test123", user.DID)
|
|
assert.Equal(t, "alice.bsky.social", user.Handle)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
// Wrap with RequireAuth middleware
|
|
middleware := RequireAuth(store, database)
|
|
wrappedHandler := middleware(handler)
|
|
|
|
// Create request with session cookie
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "atcr_session",
|
|
Value: sessionID,
|
|
})
|
|
w := httptest.NewRecorder()
|
|
|
|
wrappedHandler.ServeHTTP(w, req)
|
|
|
|
assert.True(t, handlerCalled, "handler should have been called")
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
// TestRequireAuth_MissingSession tests RequireAuth redirects when no session
|
|
func TestRequireAuth_MissingSession(t *testing.T) {
|
|
database := setupTestDB(t)
|
|
store := db.NewSessionStore(database)
|
|
|
|
handlerCalled := false
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := RequireAuth(store, database)
|
|
wrappedHandler := middleware(handler)
|
|
|
|
// Request without session cookie
|
|
req := httptest.NewRequest("GET", "/protected", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
wrappedHandler.ServeHTTP(w, req)
|
|
|
|
assert.False(t, handlerCalled, "handler should not have been called")
|
|
assert.Equal(t, http.StatusFound, w.Code)
|
|
assert.Contains(t, w.Header().Get("Location"), "/auth/oauth/login")
|
|
assert.Contains(t, w.Header().Get("Location"), "return_to=%2Fprotected")
|
|
}
|
|
|
|
// TestRequireAuth_InvalidSession tests RequireAuth redirects when session is invalid
|
|
func TestRequireAuth_InvalidSession(t *testing.T) {
|
|
database := setupTestDB(t)
|
|
store := db.NewSessionStore(database)
|
|
|
|
handlerCalled := false
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := RequireAuth(store, database)
|
|
wrappedHandler := middleware(handler)
|
|
|
|
// Request with invalid session ID
|
|
req := httptest.NewRequest("GET", "/protected", nil)
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "atcr_session",
|
|
Value: "invalid-session-id",
|
|
})
|
|
w := httptest.NewRecorder()
|
|
|
|
wrappedHandler.ServeHTTP(w, req)
|
|
|
|
assert.False(t, handlerCalled, "handler should not have been called")
|
|
assert.Equal(t, http.StatusFound, w.Code)
|
|
assert.Contains(t, w.Header().Get("Location"), "/auth/oauth/login")
|
|
}
|
|
|
|
// TestRequireAuth_WithQueryParams tests RequireAuth preserves query parameters in return_to
|
|
func TestRequireAuth_WithQueryParams(t *testing.T) {
|
|
database := setupTestDB(t)
|
|
store := db.NewSessionStore(database)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := RequireAuth(store, database)
|
|
wrappedHandler := middleware(handler)
|
|
|
|
// Request without session but with query parameters
|
|
req := httptest.NewRequest("GET", "/protected?foo=bar&baz=qux", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
wrappedHandler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusFound, w.Code)
|
|
location := w.Header().Get("Location")
|
|
assert.Contains(t, location, "/auth/oauth/login")
|
|
assert.Contains(t, location, "return_to=")
|
|
// Query parameters should be preserved in return_to
|
|
assert.Contains(t, location, "foo%3Dbar")
|
|
}
|
|
|
|
// TestRequireAuth_DatabaseFallback tests fallback to session data when DB lookup has no avatar
|
|
func TestRequireAuth_DatabaseFallback(t *testing.T) {
|
|
database := setupTestDB(t)
|
|
store := db.NewSessionStore(database)
|
|
|
|
// Create a user without avatar (required by foreign key)
|
|
_, err := database.Exec(
|
|
"INSERT INTO users (did, handle, pds_endpoint, last_seen, avatar) VALUES (?, ?, ?, ?, ?)",
|
|
"did:plc:test123", "alice.bsky.social", "https://pds.example.com", time.Now(), "",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
// Create a session
|
|
sessionID, err := store.Create("did:plc:test123", "alice.bsky.social", "https://pds.example.com", 24*time.Hour)
|
|
require.NoError(t, err)
|
|
|
|
handlerCalled := false
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
user := GetUser(r)
|
|
assert.NotNil(t, user)
|
|
assert.Equal(t, "did:plc:test123", user.DID)
|
|
assert.Equal(t, "alice.bsky.social", user.Handle)
|
|
// User exists in DB but has no avatar - should use DB version
|
|
assert.Empty(t, user.Avatar, "avatar should be empty when not set in DB")
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := RequireAuth(store, database)
|
|
wrappedHandler := middleware(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "atcr_session",
|
|
Value: sessionID,
|
|
})
|
|
w := httptest.NewRecorder()
|
|
|
|
wrappedHandler.ServeHTTP(w, req)
|
|
|
|
assert.True(t, handlerCalled)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
// TestOptionalAuth_ValidSession tests OptionalAuth with valid session
|
|
func TestOptionalAuth_ValidSession(t *testing.T) {
|
|
database := setupTestDB(t)
|
|
store := db.NewSessionStore(database)
|
|
|
|
// Create a user first (required by foreign key)
|
|
_, err := database.Exec(
|
|
"INSERT INTO users (did, handle, pds_endpoint, last_seen) VALUES (?, ?, ?, ?)",
|
|
"did:plc:test123", "alice.bsky.social", "https://pds.example.com", time.Now(),
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
// Create a session
|
|
sessionID, err := store.Create("did:plc:test123", "alice.bsky.social", "https://pds.example.com", 24*time.Hour)
|
|
require.NoError(t, err)
|
|
|
|
handlerCalled := false
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
user := GetUser(r)
|
|
assert.NotNil(t, user, "user should be set when session is valid")
|
|
assert.Equal(t, "did:plc:test123", user.DID)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := OptionalAuth(store, database)
|
|
wrappedHandler := middleware(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "atcr_session",
|
|
Value: sessionID,
|
|
})
|
|
w := httptest.NewRecorder()
|
|
|
|
wrappedHandler.ServeHTTP(w, req)
|
|
|
|
assert.True(t, handlerCalled)
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
// TestOptionalAuth_NoSession tests OptionalAuth continues without user when no session
|
|
func TestOptionalAuth_NoSession(t *testing.T) {
|
|
database := setupTestDB(t)
|
|
store := db.NewSessionStore(database)
|
|
|
|
handlerCalled := false
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
user := GetUser(r)
|
|
assert.Nil(t, user, "user should be nil when no session")
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := OptionalAuth(store, database)
|
|
wrappedHandler := middleware(handler)
|
|
|
|
// Request without session cookie
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
wrappedHandler.ServeHTTP(w, req)
|
|
|
|
assert.True(t, handlerCalled, "handler should still be called")
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
// TestOptionalAuth_InvalidSession tests OptionalAuth continues without user when session invalid
|
|
func TestOptionalAuth_InvalidSession(t *testing.T) {
|
|
database := setupTestDB(t)
|
|
store := db.NewSessionStore(database)
|
|
|
|
handlerCalled := false
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
handlerCalled = true
|
|
user := GetUser(r)
|
|
assert.Nil(t, user, "user should be nil when session is invalid")
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
middleware := OptionalAuth(store, database)
|
|
wrappedHandler := middleware(handler)
|
|
|
|
// Request with invalid session ID
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "atcr_session",
|
|
Value: "invalid-session-id",
|
|
})
|
|
w := httptest.NewRecorder()
|
|
|
|
wrappedHandler.ServeHTTP(w, req)
|
|
|
|
assert.True(t, handlerCalled, "handler should still be called")
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
// TestMiddleware_ConcurrentAccess tests concurrent requests through middleware
|
|
func TestMiddleware_ConcurrentAccess(t *testing.T) {
|
|
// Use a shared in-memory database for concurrent access
|
|
// (SQLite's default :memory: creates separate DBs per connection)
|
|
database, err := db.InitDB("file::memory:?cache=shared", true)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
database.Close()
|
|
})
|
|
|
|
store := db.NewSessionStore(database)
|
|
|
|
// Pre-create all users and sessions before concurrent access
|
|
// This ensures database is fully initialized before goroutines start
|
|
sessionIDs := make([]string, 10)
|
|
for i := 0; i < 10; i++ {
|
|
did := fmt.Sprintf("did:plc:user%d", i)
|
|
handle := fmt.Sprintf("user%d.bsky.social", i)
|
|
|
|
// Create user first
|
|
_, err := database.Exec(
|
|
"INSERT INTO users (did, handle, pds_endpoint, last_seen) VALUES (?, ?, ?, ?)",
|
|
did, handle, "https://pds.example.com", time.Now(),
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
// Create session
|
|
sessionID, err := store.Create(
|
|
did,
|
|
handle,
|
|
"https://pds.example.com",
|
|
24*time.Hour,
|
|
)
|
|
require.NoError(t, err)
|
|
sessionIDs[i] = sessionID
|
|
}
|
|
|
|
// All setup complete - now test concurrent access
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
user := GetUser(r)
|
|
if user != nil {
|
|
w.WriteHeader(http.StatusOK)
|
|
} else {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
}
|
|
})
|
|
|
|
middleware := RequireAuth(store, database)
|
|
wrappedHandler := middleware(handler)
|
|
|
|
// Collect results from all goroutines
|
|
results := make([]int, 10)
|
|
var wg sync.WaitGroup
|
|
var mu sync.Mutex // Protect results map
|
|
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func(index int, sessionID string) {
|
|
defer wg.Done()
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "atcr_session",
|
|
Value: sessionID,
|
|
})
|
|
w := httptest.NewRecorder()
|
|
|
|
wrappedHandler.ServeHTTP(w, req)
|
|
|
|
mu.Lock()
|
|
results[index] = w.Code
|
|
mu.Unlock()
|
|
}(i, sessionIDs[i])
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
// Check all results after concurrent execution
|
|
// Note: Some failures are expected with in-memory SQLite under high concurrency
|
|
// We consider the test successful if most requests succeed
|
|
successCount := 0
|
|
for _, code := range results {
|
|
if code == http.StatusOK {
|
|
successCount++
|
|
}
|
|
}
|
|
|
|
// At least 7 out of 10 should succeed (70%)
|
|
assert.GreaterOrEqual(t, successCount, 7, "Most concurrent requests should succeed")
|
|
}
|