package oauth

import (
	"context"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/bluesky-social/indigo/atproto/auth/oauth"
)

// Fixture values shared by every test that needs an OAuth client app.
const (
	testBaseURL    = "http://localhost:5000"
	testClientName = "AT Container Registry"
)

// newTestServer builds a Server around a client app backed by an in-memory
// OAuth store and the default ("*") scopes. It fails the test immediately if
// the client app cannot be constructed.
func newTestServer(t *testing.T) *Server {
	t.Helper()
	store := oauth.NewMemStore()
	scopes := GetDefaultScopes("*")
	clientApp, err := NewClientApp(testBaseURL, store, scopes, "", testClientName)
	if err != nil {
		t.Fatalf("NewClientApp() error = %v", err)
	}
	return NewServer(clientApp)
}

// checkStatus reports a test error when the recorded response status differs
// from want. It closes the body returned by Result so linters that track
// response bodies stay quiet; the recorder's own Body buffer stays readable.
func checkStatus(t *testing.T, w *httptest.ResponseRecorder, want int) {
	t.Helper()
	resp := w.Result()
	defer resp.Body.Close()
	if resp.StatusCode != want {
		t.Errorf("Expected status %d, got %d", want, resp.StatusCode)
	}
}

func TestNewServer(t *testing.T) {
	server := newTestServer(t)
	if server == nil {
		t.Fatal("Expected non-nil server")
	}
	if server.clientApp == nil {
		t.Error("Expected clientApp to be set")
	}
}

func TestServer_SetRefresher(t *testing.T) {
	server := newTestServer(t)

	// Build the refresher from the same client app the server was given.
	refresher := NewRefresher(server.clientApp)
	server.SetRefresher(refresher)

	if server.refresher == nil {
		t.Error("Expected refresher to be set")
	}
}

func TestServer_SetPostAuthCallback(t *testing.T) {
	server := newTestServer(t)

	server.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, sessionID string) error {
		return nil
	})

	if server.postAuthCallback == nil {
		t.Error("Expected post-auth callback to be set")
	}
}

func TestServer_SetUISessionStore(t *testing.T) {
	server := newTestServer(t)

	server.SetUISessionStore(&mockUISessionStore{})

	if server.uiSessionStore == nil {
		t.Error("Expected UI session store to be set")
	}
}

// mockUISessionStore is a test double for the UI session store accepted by
// Server.SetUISessionStore. Each method delegates to its corresponding func
// field when set, and otherwise returns a fixed stub value.
// NOTE(review): consider a compile-time check such as
// `var _ <StoreInterface> = (*mockUISessionStore)(nil)` once the interface
// name is confirmed.
type mockUISessionStore struct {
	createFunc          func(did, handle, pdsEndpoint string, duration time.Duration) (string, error)
	createWithOAuthFunc func(did, handle, pdsEndpoint, oauthSessionID string, duration time.Duration) (string, error)
	deleteByDIDFunc     func(did string)
}

// Create delegates to createFunc, or returns "mock-session-id" when unset.
func (m *mockUISessionStore) Create(did, handle, pdsEndpoint string, duration time.Duration) (string, error) {
	if m.createFunc != nil {
		return m.createFunc(did, handle, pdsEndpoint, duration)
	}
	return "mock-session-id", nil
}

// CreateWithOAuth delegates to createWithOAuthFunc, or returns
// "mock-session-id-with-oauth" when unset.
func (m *mockUISessionStore) CreateWithOAuth(did, handle, pdsEndpoint, oauthSessionID string, duration time.Duration) (string, error) {
	if m.createWithOAuthFunc != nil {
		return m.createWithOAuthFunc(did, handle, pdsEndpoint, oauthSessionID, duration)
	}
	return "mock-session-id-with-oauth", nil
}

// DeleteByDID delegates to deleteByDIDFunc when set; otherwise it is a no-op.
func (m *mockUISessionStore) DeleteByDID(did string) {
	if m.deleteByDIDFunc != nil {
		m.deleteByDIDFunc(did)
	}
}

// ServeAuthorize tests

func TestServer_ServeAuthorize_MissingHandle(t *testing.T) {
	server := newTestServer(t)

	req := httptest.NewRequest(http.MethodGet, "/auth/oauth/authorize", nil)
	w := httptest.NewRecorder()
	server.ServeAuthorize(w, req)

	checkStatus(t, w, http.StatusBadRequest)
}

func TestServer_ServeAuthorize_InvalidMethod(t *testing.T) {
	server := newTestServer(t)

	req := httptest.NewRequest(http.MethodPost, "/auth/oauth/authorize?handle=alice.bsky.social", nil)
	w := httptest.NewRecorder()
	server.ServeAuthorize(w, req)

	checkStatus(t, w, http.StatusMethodNotAllowed)
}

// ServeCallback tests

// oauthErrorCallbackURL is a callback request carrying an authorization-server
// error instead of an authorization code.
const oauthErrorCallbackURL = "/auth/oauth/callback?error=access_denied&error_description=User+denied+access"

func TestServer_ServeCallback_InvalidMethod(t *testing.T) {
	server := newTestServer(t)

	req := httptest.NewRequest(http.MethodPost, "/auth/oauth/callback", nil)
	w := httptest.NewRecorder()
	server.ServeCallback(w, req)

	checkStatus(t, w, http.StatusMethodNotAllowed)
}

func TestServer_ServeCallback_OAuthError(t *testing.T) {
	server := newTestServer(t)

	req := httptest.NewRequest(http.MethodGet, oauthErrorCallbackURL, nil)
	w := httptest.NewRecorder()
	server.ServeCallback(w, req)

	checkStatus(t, w, http.StatusBadRequest)

	body := w.Body.String()
	if !strings.Contains(body, "access_denied") {
		t.Errorf("Expected error message to contain 'access_denied', got: %s", body)
	}
}

// TestServer_ServeCallback_WithPostAuthCallback verifies that the post-auth
// hook is wired up and is NOT fired when authorization fails. Exercising the
// success path would require mocking the full OAuth token exchange.
func TestServer_ServeCallback_WithPostAuthCallback(t *testing.T) {
	server := newTestServer(t)

	callbackInvoked := false
	server.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, sessionID string) error {
		callbackInvoked = true
		return nil
	})
	if server.postAuthCallback == nil {
		t.Fatal("Expected post-auth callback to be set")
	}

	// Drive the callback with an authorization error so the handler must
	// reject the request before any session is established.
	req := httptest.NewRequest(http.MethodGet, oauthErrorCallbackURL, nil)
	w := httptest.NewRecorder()
	server.ServeCallback(w, req)

	if callbackInvoked {
		t.Error("Callback should not be invoked without OAuth completion")
	}
}

// TestServer_ServeCallback_UIFlow_SessionCreationLogic verifies that the UI
// session store is wired up and that no UI session is created when
// authorization fails. Exercising the success path would require mocking the
// full OAuth token exchange.
func TestServer_ServeCallback_UIFlow_SessionCreationLogic(t *testing.T) {
	sessionCreated := false
	uiStore := &mockUISessionStore{
		createFunc: func(did, handle, pds string, duration time.Duration) (string, error) {
			sessionCreated = true
			return "ui-session-123", nil
		},
		createWithOAuthFunc: func(did, handle, pds, oauthSessionID string, duration time.Duration) (string, error) {
			sessionCreated = true
			return "ui-session-123", nil
		},
	}

	server := newTestServer(t)
	server.SetUISessionStore(uiStore)
	if server.uiSessionStore == nil {
		t.Fatal("Expected UI session store to be set")
	}

	req := httptest.NewRequest(http.MethodGet, oauthErrorCallbackURL, nil)
	w := httptest.NewRecorder()
	server.ServeCallback(w, req)

	if sessionCreated {
		t.Error("Session should not be created without OAuth completion")
	}
}

func TestServer_RenderError(t *testing.T) {
	server := newTestServer(t)

	w := httptest.NewRecorder()
	server.renderError(w, "Test error message")

	checkStatus(t, w, http.StatusBadRequest)

	body := w.Body.String()
	if !strings.Contains(body, "Test error message") {
		t.Errorf("Expected error message in body, got: %s", body)
	}
	if !strings.Contains(body, "Authorization Failed") {
		t.Errorf("Expected 'Authorization Failed' title in body, got: %s", body)
	}
}

func TestServer_RenderRedirectToSettings(t *testing.T) {
	server := newTestServer(t)

	w := httptest.NewRecorder()
	server.renderRedirectToSettings(w, "alice.bsky.social")

	checkStatus(t, w, http.StatusOK)

	body := w.Body.String()
	if !strings.Contains(body, "alice.bsky.social") {
		t.Errorf("Expected handle in body, got: %s", body)
	}
	if !strings.Contains(body, "Authorization Successful") {
		t.Errorf("Expected 'Authorization Successful' title in body, got: %s", body)
	}
	if !strings.Contains(body, "/settings") {
		t.Errorf("Expected redirect to /settings in body, got: %s", body)
	}
}