Files
at-container-registry/pkg/auth/oauth/server_test.go

349 lines
9.6 KiB
Go

package oauth
import (
"context"
"github.com/bluesky-social/indigo/atproto/auth/oauth"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestNewServer(t *testing.T) {
// Create a basic OAuth app for testing
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
server := NewServer(clientApp)
if server == nil {
t.Fatal("Expected non-nil server")
}
if server.clientApp == nil {
t.Error("Expected clientApp to be set")
}
}
func TestServer_SetRefresher(t *testing.T) {
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
server := NewServer(clientApp)
refresher := NewRefresher(clientApp)
server.SetRefresher(refresher)
if server.refresher == nil {
t.Error("Expected refresher to be set")
}
}
func TestServer_SetPostAuthCallback(t *testing.T) {
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
server := NewServer(clientApp)
// Set callback with correct signature
server.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, sessionID string) error {
return nil
})
if server.postAuthCallback == nil {
t.Error("Expected post-auth callback to be set")
}
}
func TestServer_SetUISessionStore(t *testing.T) {
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
server := NewServer(clientApp)
mockStore := &mockUISessionStore{}
server.SetUISessionStore(mockStore)
if server.uiSessionStore == nil {
t.Error("Expected UI session store to be set")
}
}
// Mock implementations for testing

// mockUISessionStore is a test double for the server's UI session store.
// Each method forwards to its matching *Func hook when one is set, and
// otherwise returns a fixed placeholder (or does nothing).
type mockUISessionStore struct {
	createFunc          func(did, handle, pdsEndpoint string, duration time.Duration) (string, error)
	createWithOAuthFunc func(did, handle, pdsEndpoint, oauthSessionID string, duration time.Duration) (string, error)
	deleteByDIDFunc     func(did string)
}

// Create returns createFunc's result, or "mock-session-id" when no hook is set.
func (s *mockUISessionStore) Create(did, handle, pdsEndpoint string, duration time.Duration) (string, error) {
	if s.createFunc == nil {
		return "mock-session-id", nil
	}
	return s.createFunc(did, handle, pdsEndpoint, duration)
}

// CreateWithOAuth returns createWithOAuthFunc's result, or
// "mock-session-id-with-oauth" when no hook is set.
func (s *mockUISessionStore) CreateWithOAuth(did, handle, pdsEndpoint, oauthSessionID string, duration time.Duration) (string, error) {
	if s.createWithOAuthFunc == nil {
		return "mock-session-id-with-oauth", nil
	}
	return s.createWithOAuthFunc(did, handle, pdsEndpoint, oauthSessionID, duration)
}

// DeleteByDID forwards did to deleteByDIDFunc; it is a no-op without a hook.
func (s *mockUISessionStore) DeleteByDID(did string) {
	if s.deleteByDIDFunc == nil {
		return
	}
	s.deleteByDIDFunc(did)
}
// mockRefresher is a test double with an InvalidateSession method that
// forwards to invalidateSessionFunc when it is set.
// NOTE(review): not referenced by the tests visible in this chunk — confirm
// it is still used elsewhere in the file.
type mockRefresher struct {
	invalidateSessionFunc func(did string)
}

// InvalidateSession passes did to the hook; without a hook it does nothing.
func (r *mockRefresher) InvalidateSession(did string) {
	if r.invalidateSessionFunc == nil {
		return
	}
	r.invalidateSessionFunc(did)
}
// ServeAuthorize tests
func TestServer_ServeAuthorize_MissingHandle(t *testing.T) {
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
server := NewServer(clientApp)
req := httptest.NewRequest(http.MethodGet, "/auth/oauth/authorize", nil)
w := httptest.NewRecorder()
server.ServeAuthorize(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
}
}
func TestServer_ServeAuthorize_InvalidMethod(t *testing.T) {
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
server := NewServer(clientApp)
req := httptest.NewRequest(http.MethodPost, "/auth/oauth/authorize?handle=alice.bsky.social", nil)
w := httptest.NewRecorder()
server.ServeAuthorize(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, resp.StatusCode)
}
}
// ServeCallback tests
func TestServer_ServeCallback_InvalidMethod(t *testing.T) {
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
server := NewServer(clientApp)
req := httptest.NewRequest(http.MethodPost, "/auth/oauth/callback", nil)
w := httptest.NewRecorder()
server.ServeCallback(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, resp.StatusCode)
}
}
func TestServer_ServeCallback_OAuthError(t *testing.T) {
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
server := NewServer(clientApp)
req := httptest.NewRequest(http.MethodGet, "/auth/oauth/callback?error=access_denied&error_description=User+denied+access", nil)
w := httptest.NewRecorder()
server.ServeCallback(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
}
body := w.Body.String()
if !strings.Contains(body, "access_denied") {
t.Errorf("Expected error message to contain 'access_denied', got: %s", body)
}
}
// TestServer_ServeCallback_WithPostAuthCallback verifies that a configured
// post-auth callback is stored on the server, and that ServeCallback does not
// invoke it when the callback request carries an OAuth error instead of a
// completed authorization.
//
// A success-path test would require mocking the full OAuth exchange; the
// error path is the part that can be driven with httptest alone.
func TestServer_ServeCallback_WithPostAuthCallback(t *testing.T) {
	store := oauth.NewMemStore()
	scopes := GetDefaultScopes("*")
	clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
	if err != nil {
		t.Fatalf("NewClientApp() error = %v", err)
	}
	server := NewServer(clientApp)

	callbackInvoked := false
	server.SetPostAuthCallback(func(ctx context.Context, d, h, pds, sessionID string) error {
		callbackInvoked = true
		return nil
	})
	if server.postAuthCallback == nil {
		t.Error("Expected post-auth callback to be set")
	}

	// Actually run the handler so the "not invoked" assertion below can fail;
	// checking the flag without serving a request would always pass.
	req := httptest.NewRequest(http.MethodGet, "/auth/oauth/callback?error=access_denied&error_description=User+denied+access", nil)
	w := httptest.NewRecorder()
	server.ServeCallback(w, req)

	// Confirm the error path was the one taken.
	if got := w.Result().StatusCode; got != http.StatusBadRequest {
		t.Errorf("Expected status %d, got %d", http.StatusBadRequest, got)
	}
	if callbackInvoked {
		t.Error("Callback should not be invoked without OAuth completion")
	}
}
// TestServer_ServeCallback_UIFlow_SessionCreationLogic verifies that a UI
// session store is stored on the server, and that ServeCallback does not
// create a UI session when the callback carries an OAuth error.
//
// Both Create and CreateWithOAuth are hooked so that either creation path
// is detected.
func TestServer_ServeCallback_UIFlow_SessionCreationLogic(t *testing.T) {
	sessionCreated := false
	uiStore := &mockUISessionStore{
		createFunc: func(d, h, pds string, duration time.Duration) (string, error) {
			sessionCreated = true
			return "ui-session-123", nil
		},
		createWithOAuthFunc: func(d, h, pds, oauthSessionID string, duration time.Duration) (string, error) {
			sessionCreated = true
			return "ui-session-123", nil
		},
	}

	store := oauth.NewMemStore()
	scopes := GetDefaultScopes("*")
	clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
	if err != nil {
		t.Fatalf("NewClientApp() error = %v", err)
	}
	server := NewServer(clientApp)
	server.SetUISessionStore(uiStore)
	if server.uiSessionStore == nil {
		t.Error("Expected UI session store to be set")
	}

	// Serve a real request so the "not created" assertion can actually fail.
	req := httptest.NewRequest(http.MethodGet, "/auth/oauth/callback?error=access_denied&error_description=User+denied+access", nil)
	w := httptest.NewRecorder()
	server.ServeCallback(w, req)

	// Confirm the error path was the one taken.
	if got := w.Result().StatusCode; got != http.StatusBadRequest {
		t.Errorf("Expected status %d, got %d", http.StatusBadRequest, got)
	}
	if sessionCreated {
		t.Error("Session should not be created without OAuth completion")
	}
}
func TestServer_RenderError(t *testing.T) {
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
server := NewServer(clientApp)
w := httptest.NewRecorder()
server.renderError(w, "Test error message")
resp := w.Result()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
}
body := w.Body.String()
if !strings.Contains(body, "Test error message") {
t.Errorf("Expected error message in body, got: %s", body)
}
if !strings.Contains(body, "Authorization Failed") {
t.Errorf("Expected 'Authorization Failed' title in body, got: %s", body)
}
}
func TestServer_RenderRedirectToSettings(t *testing.T) {
store := oauth.NewMemStore()
scopes := GetDefaultScopes("*")
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
if err != nil {
t.Fatalf("NewClientApp() error = %v", err)
}
server := NewServer(clientApp)
w := httptest.NewRecorder()
server.renderRedirectToSettings(w, "alice.bsky.social")
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, resp.StatusCode)
}
body := w.Body.String()
if !strings.Contains(body, "alice.bsky.social") {
t.Errorf("Expected handle in body, got: %s", body)
}
if !strings.Contains(body, "Authorization Successful") {
t.Errorf("Expected 'Authorization Successful' title in body, got: %s", body)
}
if !strings.Contains(body, "/settings") {
t.Errorf("Expected redirect to /settings in body, got: %s", body)
}
}