// Files: at-container-registry/pkg/auth/token/handler_test.go
// (listing header: 645 lines, 18 KiB, Go)

package token
import (
"context"
"crypto/tls"
"database/sql"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"atcr.io/pkg/appview/db"
)
// Shared test key, generated once per test binary and reused by every test.
// Generating a 2048-bit RSA key takes ~0.15s, so reusing one key saves ~4.5s
// for 32 tests.
var (
	sharedTestKeyPath string
	sharedTestKeyOnce sync.Once
	sharedTestKeyDir  string
	// sharedTestKeyFailure holds a description of a one-time setup failure
	// (empty on success), so every caller can fail with the real cause.
	sharedTestKeyFailure string
)

// getSharedTestKey returns the file path of the RSA key shared by all tests
// in this package.
//
// On first use it creates a temporary directory and calls NewIssuer with a
// key path inside it, which generates the key (and certificate) there.
// Callers pass the returned path to their own NewIssuer call.
//
// Setup runs under sync.Once. A failure is recorded rather than reported with
// t.Fatalf inside Do: Once treats an aborted Do as done, so only the first
// test would see the real error and later tests would get a half-initialised
// path. Instead, every caller fails with the recorded cause.
//
// The temporary directory is not removed by this helper.
func getSharedTestKey(t *testing.T) string {
	t.Helper()
	sharedTestKeyOnce.Do(func() {
		// Create a persistent temp directory for the shared key.
		dir, err := os.MkdirTemp("", "atcr-test-keys-*")
		if err != nil {
			sharedTestKeyFailure = "Failed to create test key directory: " + err.Error()
			return
		}
		sharedTestKeyDir = dir
		sharedTestKeyPath = filepath.Join(dir, "test-key.pem")
		// Generate the key once (this is the expensive operation we want to
		// avoid repeating). This also generates the certificate via NewIssuer.
		if _, err := NewIssuer(sharedTestKeyPath, "atcr.io", "registry", 15*time.Minute); err != nil {
			sharedTestKeyFailure = "Failed to generate shared test key: " + err.Error()
		}
	})
	if sharedTestKeyFailure != "" {
		t.Fatal(sharedTestKeyFailure)
	}
	return sharedTestKeyPath
}
// setupTestDeviceStore creates a DeviceStore backed by an in-memory SQLite
// database and returns both the store and the underlying *sql.DB.
//
// The database is closed automatically when the calling test finishes
// (registered with t.Cleanup), so tests do not leak connections.
//
// NOTE(review): with database/sql, each pooled connection to ":memory:"
// normally opens its own empty database. This assumes db.InitDB keeps the
// schema visible to every connection (e.g. a single-connection pool); confirm.
func setupTestDeviceStore(t *testing.T) (*db.DeviceStore, *sql.DB) {
	t.Helper()
	testDB, err := db.InitDB(":memory:", db.LibsqlConfig{})
	if err != nil {
		t.Fatalf("Failed to initialize test database: %v", err)
	}
	t.Cleanup(func() {
		if err := testDB.Close(); err != nil {
			t.Errorf("Failed to close test database: %v", err)
		}
	})
	return db.NewDeviceStore(testDB), testDB
}
// createTestDevice registers an approved device for the given DID/handle in
// the test database and returns the device secret issued on approval.
//
// It first upserts a users row (the devices table references it through a
// foreign key), then creates a pending authorization and approves it.
// Any failure aborts the calling test; t.Helper makes the failure point at
// the caller instead of this helper.
func createTestDevice(t *testing.T, store *db.DeviceStore, testDB *sql.DB, did, handle string) string {
	t.Helper()

	// First create a user record (required by foreign key constraint).
	user := &db.User{
		DID:         did,
		Handle:      handle,
		PDSEndpoint: "https://pds.example.com",
	}
	if err := db.UpsertUser(testDB, user); err != nil {
		t.Fatalf("Failed to create user: %v", err)
	}

	// Create pending authorization.
	pending, err := store.CreatePendingAuth("Test Device", "127.0.0.1", "test-agent")
	if err != nil {
		t.Fatalf("Failed to create pending auth: %v", err)
	}

	// Approve the pending authorization; approval yields the device secret.
	secret, err := store.ApprovePending(pending.UserCode, did, handle)
	if err != nil {
		t.Fatalf("Failed to approve pending auth: %v", err)
	}
	return secret
}
func TestNewHandler(t *testing.T) {
keyPath := getSharedTestKey(t)
issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
if err != nil {
t.Fatalf("NewIssuer() error = %v", err)
}
handler := NewHandler(issuer, nil)
if handler == nil {
t.Fatal("Expected non-nil handler")
}
if handler.issuer == nil {
t.Error("Expected issuer to be set")
}
if handler.validator == nil {
t.Error("Expected validator to be initialized")
}
}
func TestHandler_SetPostAuthCallback(t *testing.T) {
keyPath := getSharedTestKey(t)
issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
if err != nil {
t.Fatalf("NewIssuer() error = %v", err)
}
handler := NewHandler(issuer, nil)
handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error {
return nil
})
if handler.postAuthCallback == nil {
t.Error("Expected post-auth callback to be set")
}
}
func TestHandler_ServeHTTP_NoAuth(t *testing.T) {
keyPath := getSharedTestKey(t)
issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
if err != nil {
t.Fatalf("NewIssuer() error = %v", err)
}
handler := NewHandler(issuer, nil)
req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
// Check for WWW-Authenticate header
if w.Header().Get("WWW-Authenticate") == "" {
t.Error("Expected WWW-Authenticate header")
}
}
func TestHandler_ServeHTTP_WrongMethod(t *testing.T) {
keyPath := getSharedTestKey(t)
issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
if err != nil {
t.Fatalf("NewIssuer() error = %v", err)
}
handler := NewHandler(issuer, nil)
// Try POST instead of GET
req := httptest.NewRequest(http.MethodPost, "/auth/token", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
}
}
// TestHandler_ServeHTTP_DeviceAuth_Valid checks that a registered device
// secret is exchanged for a token: the response is 200, token and
// access_token are non-empty and identical, and expires_in is non-zero.
func TestHandler_ServeHTTP_DeviceAuth_Valid(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}

	// Create real device store with in-memory database.
	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	// Create request with device secret.
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull,push", nil)
	req.SetBasicAuth("alice.bsky.social", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		// Stop here: decoding an error body would only add a second,
		// misleading "decode" failure on top of the real one.
		t.Fatalf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
	}

	var resp TokenResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if resp.Token == "" {
		t.Error("Expected non-empty token")
	}
	if resp.AccessToken == "" {
		t.Error("Expected non-empty access_token")
	}
	if resp.ExpiresIn == 0 {
		t.Error("Expected non-zero expires_in")
	}
	// Both fields carry the same JWT for compatibility with different clients.
	if resp.Token != resp.AccessToken {
		t.Error("Expected token and access_token to be the same")
	}
}
func TestHandler_ServeHTTP_DeviceAuth_Invalid(t *testing.T) {
keyPath := getSharedTestKey(t)
issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
if err != nil {
t.Fatalf("NewIssuer() error = %v", err)
}
// Create device store but don't add any devices
deviceStore, _ := setupTestDeviceStore(t)
handler := NewHandler(issuer, deviceStore)
req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
req.SetBasicAuth("alice", "atcr_device_invalid")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
}
}
func TestHandler_ServeHTTP_InvalidScope(t *testing.T) {
keyPath := getSharedTestKey(t)
issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
if err != nil {
t.Fatalf("NewIssuer() error = %v", err)
}
deviceStore, database := setupTestDeviceStore(t)
deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")
handler := NewHandler(issuer, deviceStore)
// Invalid scope format (missing colons)
req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=invalid", nil)
req.SetBasicAuth("alice", deviceSecret)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code)
}
body := w.Body.String()
if !strings.Contains(body, "invalid scope") {
t.Errorf("Expected error message to contain 'invalid scope', got: %s", body)
}
}
func TestHandler_ServeHTTP_AccessDenied(t *testing.T) {
keyPath := getSharedTestKey(t)
issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
if err != nil {
t.Fatalf("NewIssuer() error = %v", err)
}
deviceStore, database := setupTestDeviceStore(t)
deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
handler := NewHandler(issuer, deviceStore)
// Try to push to someone else's repository
req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:push", nil)
req.SetBasicAuth("alice", deviceSecret)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Expected status %d, got %d", http.StatusForbidden, w.Code)
}
body := w.Body.String()
if !strings.Contains(body, "access denied") {
t.Errorf("Expected error message to contain 'access denied', got: %s", body)
}
}
// TestHandler_ServeHTTP_WithCallback checks that the post-auth callback is NOT
// invoked when the client authenticates with a device secret. The callback is
// only meant for app-password auth.
//
// The request must succeed first. Otherwise "callback not called" would hold
// trivially for a rejected request and the test would prove nothing.
func TestHandler_ServeHTTP_WithCallback(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}
	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	// Track whether the callback fires; its arguments are irrelevant here.
	callbackCalled := false
	handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error {
		callbackCalled = true
		return nil
	})

	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
	}
	if callbackCalled {
		t.Error("Expected callback NOT to be called for device auth")
	}
}
func TestHandler_ServeHTTP_MultipleScopes(t *testing.T) {
keyPath := getSharedTestKey(t)
issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
if err != nil {
t.Fatalf("NewIssuer() error = %v", err)
}
deviceStore, database := setupTestDeviceStore(t)
deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
handler := NewHandler(issuer, deviceStore)
// Multiple scopes separated by space (URL encoded)
scopes := "repository%3Aalice.bsky.social%2Fapp1%3Apull+repository%3Aalice.bsky.social%2Fapp2%3Apush"
req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope="+scopes, nil)
req.SetBasicAuth("alice", deviceSecret)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
}
}
func TestHandler_ServeHTTP_WildcardScope(t *testing.T) {
keyPath := getSharedTestKey(t)
issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
if err != nil {
t.Fatalf("NewIssuer() error = %v", err)
}
deviceStore, database := setupTestDeviceStore(t)
deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
handler := NewHandler(issuer, deviceStore)
// Wildcard scope should be allowed
req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:*:pull,push", nil)
req.SetBasicAuth("alice", deviceSecret)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
}
}
// TestHandler_ServeHTTP_NoScope checks that a request without a scope
// parameter still succeeds and returns a non-empty token (with empty access).
func TestHandler_ServeHTTP_NoScope(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}
	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	// No scope parameter: should still work (empty access).
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		// Stop before decoding; an error body would only add a misleading
		// decode failure.
		t.Fatalf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
	}

	var resp TokenResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if resp.Token == "" {
		t.Error("Expected non-empty token even with no scope")
	}
}
// TestGetBaseURL checks how getBaseURL derives scheme and host from the
// request: the Host field, a TLS connection, and the X-Forwarded-Host and
// X-Forwarded-Proto headers.
//
// Each case states explicitly whether the request arrives over TLS. Earlier
// this was inferred from the expected URL plus an empty header map, which
// silently coupled the setup to the expectation.
func TestGetBaseURL(t *testing.T) {
	tests := []struct {
		name        string
		host        string
		useTLS      bool // give the request a non-nil TLS state (HTTPS)
		headers     map[string]string
		expectedURL string
	}{
		{
			name:        "simple host",
			host:        "registry.example.com",
			headers:     map[string]string{},
			expectedURL: "http://registry.example.com",
		},
		{
			name:        "with TLS",
			host:        "registry.example.com",
			useTLS:      true,
			headers:     map[string]string{},
			expectedURL: "https://registry.example.com",
		},
		{
			name: "with X-Forwarded-Host",
			host: "internal-host",
			headers: map[string]string{
				"X-Forwarded-Host": "registry.example.com",
			},
			expectedURL: "http://registry.example.com",
		},
		{
			name: "with X-Forwarded-Proto",
			host: "registry.example.com",
			headers: map[string]string{
				"X-Forwarded-Proto": "https",
			},
			expectedURL: "https://registry.example.com",
		},
		{
			name: "with both forwarded headers",
			host: "internal",
			headers: map[string]string{
				"X-Forwarded-Host":  "registry.example.com",
				"X-Forwarded-Proto": "https",
			},
			expectedURL: "https://registry.example.com",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.Host = tt.host
			for key, value := range tt.headers {
				req.Header.Set(key, value)
			}
			if tt.useTLS {
				req.TLS = &tls.ConnectionState{} // non-nil TLS indicates HTTPS
			}

			if baseURL := getBaseURL(req); baseURL != tt.expectedURL {
				t.Errorf("Expected URL %q, got %q", tt.expectedURL, baseURL)
			}
		})
	}
}
// TestTokenResponse_JSONFormat checks the wire names and values of
// TokenResponse's JSON encoding: token, access_token, expires_in, issued_at.
// Failures report both the expected and the actual decoded value.
func TestTokenResponse_JSONFormat(t *testing.T) {
	resp := TokenResponse{
		Token:       "jwt_token_here",
		AccessToken: "jwt_token_here",
		ExpiresIn:   900,
		IssuedAt:    "2025-01-01T00:00:00Z",
	}
	data, err := json.Marshal(resp)
	if err != nil {
		t.Fatalf("Failed to marshal response: %v", err)
	}

	// Decode generically so the check sees the actual JSON field names.
	var decoded map[string]any
	if err := json.Unmarshal(data, &decoded); err != nil {
		t.Fatalf("Failed to unmarshal JSON: %v", err)
	}

	// A slice (not a map) keeps the failure output in a stable order.
	wantFields := []struct {
		key  string
		want any
	}{
		{"token", "jwt_token_here"},
		{"access_token", "jwt_token_here"},
		{"expires_in", float64(900)}, // encoding/json decodes numbers into float64
		{"issued_at", "2025-01-01T00:00:00Z"},
	}
	for _, f := range wantFields {
		got, ok := decoded[f.key]
		if !ok {
			t.Errorf("Expected %s field in JSON, field missing (JSON: %s)", f.key, data)
			continue
		}
		if got != f.want {
			t.Errorf("Expected %s field in JSON to be %v, got %v", f.key, f.want, got)
		}
	}
}
// TestHandler_ServeHTTP_AuthHeader sends a hand-built Basic Authorization
// header with credentials that are not a registered device secret, and checks
// that no token is issued for them.
//
// The exact rejection status is not pinned, so a non-401 status is only
// logged. A 2xx status, however, would mean bogus credentials were accepted
// and fails the test. Previously nothing could make this test fail.
func TestHandler_ServeHTTP_AuthHeader(t *testing.T) {
	keyPath := getSharedTestKey(t)
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}
	handler := NewHandler(issuer, nil)

	// Manually constructed auth header (instead of req.SetBasicAuth).
	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
	auth := base64.StdEncoding.EncodeToString([]byte("username:password"))
	req.Header.Set("Authorization", "Basic "+auth)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code >= 200 && w.Code < 300 {
		t.Fatalf("Expected bogus credentials to be rejected, got status %d. Body: %s", w.Code, w.Body.String())
	}
	if w.Code != http.StatusUnauthorized {
		t.Logf("Got status %d (this is fine, we're just testing header parsing)", w.Code)
	}
}
func TestHandler_ServeHTTP_ContentType(t *testing.T) {
keyPath := getSharedTestKey(t)
issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
if err != nil {
t.Fatalf("NewIssuer() error = %v", err)
}
deviceStore, database := setupTestDeviceStore(t)
deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
handler := NewHandler(issuer, deviceStore)
req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
req.SetBasicAuth("alice", deviceSecret)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected status %d, got %d", http.StatusOK, w.Code)
}
contentType := w.Header().Get("Content-Type")
if contentType != "application/json" {
t.Errorf("Expected Content-Type 'application/json', got %q", contentType)
}
}
// TestHandler_ServeHTTP_ExpiresIn checks that expires_in in the token response
// equals the issuer's configured token lifetime, in whole seconds.
func TestHandler_ServeHTTP_ExpiresIn(t *testing.T) {
	keyPath := getSharedTestKey(t)
	// Create issuer with a specific, non-default expiration.
	expiration := 10 * time.Minute
	issuer, err := NewIssuer(keyPath, "atcr.io", "registry", expiration)
	if err != nil {
		t.Fatalf("NewIssuer() error = %v", err)
	}
	deviceStore, database := setupTestDeviceStore(t)
	deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
	handler := NewHandler(issuer, deviceStore)

	req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
	req.SetBasicAuth("alice", deviceSecret)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	// Without this check a rejected request surfaces only as a confusing
	// decode error or an expires_in of 0.
	if w.Code != http.StatusOK {
		t.Fatalf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
	}

	var resp TokenResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	expectedExpiresIn := int(expiration.Seconds())
	if resp.ExpiresIn != expectedExpiresIn {
		t.Errorf("Expected expires_in %d, got %d", expectedExpiresIn, resp.ExpiresIn)
	}
}
func TestHandler_ServeHTTP_PullOnlyAccess(t *testing.T) {
keyPath := getSharedTestKey(t)
issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
if err != nil {
t.Fatalf("NewIssuer() error = %v", err)
}
deviceStore, database := setupTestDeviceStore(t)
deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
handler := NewHandler(issuer, deviceStore)
// Pull from someone else's repo should be allowed
req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:pull", nil)
req.SetBasicAuth("alice", deviceSecret)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status %d for pull-only access, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
}
}