349 lines
9.6 KiB
Go
349 lines
9.6 KiB
Go
package oauth
|
|
|
|
import (
|
|
"context"
|
|
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestNewServer(t *testing.T) {
|
|
// Create a basic OAuth app for testing
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
if server == nil {
|
|
t.Fatal("Expected non-nil server")
|
|
}
|
|
|
|
if server.clientApp == nil {
|
|
t.Error("Expected clientApp to be set")
|
|
}
|
|
}
|
|
|
|
func TestServer_SetRefresher(t *testing.T) {
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
refresher := NewRefresher(clientApp)
|
|
|
|
server.SetRefresher(refresher)
|
|
if server.refresher == nil {
|
|
t.Error("Expected refresher to be set")
|
|
}
|
|
}
|
|
|
|
func TestServer_SetPostAuthCallback(t *testing.T) {
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
|
|
// Set callback with correct signature
|
|
server.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, sessionID string) error {
|
|
return nil
|
|
})
|
|
|
|
if server.postAuthCallback == nil {
|
|
t.Error("Expected post-auth callback to be set")
|
|
}
|
|
}
|
|
|
|
func TestServer_SetUISessionStore(t *testing.T) {
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
mockStore := &mockUISessionStore{}
|
|
|
|
server.SetUISessionStore(mockStore)
|
|
if server.uiSessionStore == nil {
|
|
t.Error("Expected UI session store to be set")
|
|
}
|
|
}
|
|
|
|
// Mock implementations for testing
|
|
|
|
// mockUISessionStore is a test double for the server's UI session store.
// Each method forwards to the matching hook field when that hook is non-nil
// and otherwise falls back to a fixed default.
type mockUISessionStore struct {
	createFunc          func(did, handle, pdsEndpoint string, duration time.Duration) (string, error)
	createWithOAuthFunc func(did, handle, pdsEndpoint, oauthSessionID string, duration time.Duration) (string, error)
	deleteByDIDFunc     func(did string)
}

// Create delegates to createFunc when one is configured; with no hook it
// returns the fixed ID "mock-session-id" and a nil error.
func (m *mockUISessionStore) Create(did, handle, pdsEndpoint string, duration time.Duration) (string, error) {
	if m.createFunc == nil {
		return "mock-session-id", nil
	}
	return m.createFunc(did, handle, pdsEndpoint, duration)
}

// CreateWithOAuth delegates to createWithOAuthFunc when one is configured;
// with no hook it returns the fixed ID "mock-session-id-with-oauth" and a nil
// error.
func (m *mockUISessionStore) CreateWithOAuth(did, handle, pdsEndpoint, oauthSessionID string, duration time.Duration) (string, error) {
	if m.createWithOAuthFunc == nil {
		return "mock-session-id-with-oauth", nil
	}
	return m.createWithOAuthFunc(did, handle, pdsEndpoint, oauthSessionID, duration)
}

// DeleteByDID delegates to deleteByDIDFunc when one is configured and is a
// no-op otherwise.
func (m *mockUISessionStore) DeleteByDID(did string) {
	if m.deleteByDIDFunc == nil {
		return
	}
	m.deleteByDIDFunc(did)
}
|
|
|
|
// mockRefresher is a test double whose InvalidateSession call is forwarded to
// an optional hook.
type mockRefresher struct {
	invalidateSessionFunc func(did string)
}

// InvalidateSession delegates to invalidateSessionFunc when one is configured
// and is a no-op otherwise.
func (m *mockRefresher) InvalidateSession(did string) {
	if m.invalidateSessionFunc == nil {
		return
	}
	m.invalidateSessionFunc(did)
}
|
|
|
|
// ServeAuthorize tests
|
|
|
|
func TestServer_ServeAuthorize_MissingHandle(t *testing.T) {
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/auth/oauth/authorize", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
server.ServeAuthorize(w, req)
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestServer_ServeAuthorize_InvalidMethod(t *testing.T) {
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/auth/oauth/authorize?handle=alice.bsky.social", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
server.ServeAuthorize(w, req)
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != http.StatusMethodNotAllowed {
|
|
t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
// ServeCallback tests
|
|
|
|
func TestServer_ServeCallback_InvalidMethod(t *testing.T) {
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/auth/oauth/callback", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
server.ServeCallback(w, req)
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != http.StatusMethodNotAllowed {
|
|
t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestServer_ServeCallback_OAuthError(t *testing.T) {
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/auth/oauth/callback?error=access_denied&error_description=User+denied+access", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
server.ServeCallback(w, req)
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
|
}
|
|
|
|
body := w.Body.String()
|
|
if !strings.Contains(body, "access_denied") {
|
|
t.Errorf("Expected error message to contain 'access_denied', got: %s", body)
|
|
}
|
|
}
|
|
|
|
func TestServer_ServeCallback_WithPostAuthCallback(t *testing.T) {
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
|
|
callbackInvoked := false
|
|
server.SetPostAuthCallback(func(ctx context.Context, d, h, pds, sessionID string) error {
|
|
callbackInvoked = true
|
|
// Note: We can't verify the exact DID here since we're not running a full OAuth flow
|
|
// This test verifies that the callback mechanism works
|
|
return nil
|
|
})
|
|
|
|
// Verify callback is set
|
|
if server.postAuthCallback == nil {
|
|
t.Error("Expected post-auth callback to be set")
|
|
}
|
|
|
|
// For this test, we're verifying the callback is configured correctly
|
|
// A full integration test would require mocking the entire OAuth flow
|
|
if callbackInvoked {
|
|
t.Error("Callback should not be invoked without OAuth completion")
|
|
}
|
|
}
|
|
|
|
func TestServer_ServeCallback_UIFlow_SessionCreationLogic(t *testing.T) {
|
|
sessionCreated := false
|
|
uiStore := &mockUISessionStore{
|
|
createWithOAuthFunc: func(d, h, pds, oauthSessionID string, duration time.Duration) (string, error) {
|
|
sessionCreated = true
|
|
return "ui-session-123", nil
|
|
},
|
|
}
|
|
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
server.SetUISessionStore(uiStore)
|
|
|
|
// Verify UI session store is set
|
|
if server.uiSessionStore == nil {
|
|
t.Error("Expected UI session store to be set")
|
|
}
|
|
|
|
// For this test, we're verifying the UI session store is configured correctly
|
|
// A full integration test would require mocking the entire OAuth flow with callback
|
|
if sessionCreated {
|
|
t.Error("Session should not be created without OAuth completion")
|
|
}
|
|
}
|
|
|
|
func TestServer_RenderError(t *testing.T) {
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
|
|
w := httptest.NewRecorder()
|
|
server.renderError(w, "Test error message")
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
|
}
|
|
|
|
body := w.Body.String()
|
|
if !strings.Contains(body, "Test error message") {
|
|
t.Errorf("Expected error message in body, got: %s", body)
|
|
}
|
|
|
|
if !strings.Contains(body, "Authorization Failed") {
|
|
t.Errorf("Expected 'Authorization Failed' title in body, got: %s", body)
|
|
}
|
|
}
|
|
|
|
func TestServer_RenderRedirectToSettings(t *testing.T) {
|
|
store := oauth.NewMemStore()
|
|
|
|
scopes := GetDefaultScopes("*")
|
|
clientApp, err := NewClientApp("http://localhost:5000", store, scopes, "", "AT Container Registry")
|
|
if err != nil {
|
|
t.Fatalf("NewClientApp() error = %v", err)
|
|
}
|
|
|
|
server := NewServer(clientApp)
|
|
|
|
w := httptest.NewRecorder()
|
|
server.renderRedirectToSettings(w, "alice.bsky.social")
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected status %d, got %d", http.StatusOK, resp.StatusCode)
|
|
}
|
|
|
|
body := w.Body.String()
|
|
if !strings.Contains(body, "alice.bsky.social") {
|
|
t.Errorf("Expected handle in body, got: %s", body)
|
|
}
|
|
|
|
if !strings.Contains(body, "Authorization Successful") {
|
|
t.Errorf("Expected 'Authorization Successful' title in body, got: %s", body)
|
|
}
|
|
|
|
if !strings.Contains(body, "/settings") {
|
|
t.Errorf("Expected redirect to /settings in body, got: %s", body)
|
|
}
|
|
}
|