mirror of
https://tangled.org/evan.jarrett.net/at-container-registry
synced 2026-04-27 03:35:10 +00:00
631 lines
16 KiB
Go
631 lines
16 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// setupTestDB creates an in-memory SQLite database for testing
|
|
func setupTestDB(t *testing.T) *DeviceStore {
|
|
t.Helper()
|
|
// Use a named in-memory DB unique to this test to ensure isolation between tests
|
|
safeName := strings.ReplaceAll(t.Name(), "/", "_")
|
|
db, err := InitDB(fmt.Sprintf("file:%s?mode=memory&cache=shared", safeName), LibsqlConfig{})
|
|
if err != nil {
|
|
t.Fatalf("Failed to initialize test database: %v", err)
|
|
}
|
|
|
|
// Limit to single connection to avoid race conditions in tests
|
|
db.SetMaxOpenConns(1)
|
|
|
|
t.Cleanup(func() {
|
|
db.Close()
|
|
})
|
|
return NewDeviceStore(db)
|
|
}
|
|
|
|
// createTestUser creates a test user in the database
|
|
func createTestUser(t *testing.T, store *DeviceStore, did, handle string) {
|
|
t.Helper()
|
|
_, err := store.db.Exec(`
|
|
INSERT OR IGNORE INTO users (did, handle, pds_endpoint, last_seen)
|
|
VALUES (?, ?, ?, datetime('now'))
|
|
`, did, handle, "https://pds.example.com")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test user: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDevice_Struct(t *testing.T) {
|
|
device := &Device{
|
|
DID: "did:plc:test",
|
|
Handle: "alice.bsky.social",
|
|
Name: "My Device",
|
|
CreatedAt: time.Now(),
|
|
}
|
|
|
|
if device.DID != "did:plc:test" {
|
|
t.Errorf("Expected DID, got %q", device.DID)
|
|
}
|
|
}
|
|
|
|
func TestGenerateUserCode(t *testing.T) {
|
|
// Generate multiple codes to test
|
|
codes := make(map[string]bool)
|
|
for range 100 {
|
|
code := generateUserCode()
|
|
|
|
// Test format: XXXX-XXXX
|
|
if len(code) != 9 {
|
|
t.Errorf("Expected code length 9, got %d for code %q", len(code), code)
|
|
}
|
|
|
|
if code[4] != '-' {
|
|
t.Errorf("Expected hyphen at position 4, got %q", string(code[4]))
|
|
}
|
|
|
|
// Test valid characters (A-Z, 2-9, no ambiguous chars)
|
|
validChars := "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
|
parts := strings.Split(code, "-")
|
|
if len(parts) != 2 {
|
|
t.Errorf("Expected 2 parts separated by hyphen, got %d", len(parts))
|
|
}
|
|
|
|
for _, part := range parts {
|
|
for _, ch := range part {
|
|
if !strings.ContainsRune(validChars, ch) {
|
|
t.Errorf("Invalid character %q in code %q", ch, code)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Test uniqueness (should be very rare to get duplicates)
|
|
if codes[code] {
|
|
t.Logf("Warning: duplicate code generated: %q (rare but possible)", code)
|
|
}
|
|
codes[code] = true
|
|
}
|
|
|
|
// Verify we got mostly unique codes (at least 95%)
|
|
if len(codes) < 95 {
|
|
t.Errorf("Expected at least 95 unique codes out of 100, got %d", len(codes))
|
|
}
|
|
}
|
|
|
|
func TestGenerateUserCode_Format(t *testing.T) {
|
|
code := generateUserCode()
|
|
|
|
// Test exact format
|
|
if len(code) != 9 {
|
|
t.Fatal("Code must be exactly 9 characters")
|
|
}
|
|
|
|
if code[4] != '-' {
|
|
t.Fatal("Character at index 4 must be hyphen")
|
|
}
|
|
|
|
// Test no ambiguous characters (O, 0, I, 1, L)
|
|
ambiguous := "O01IL"
|
|
for _, ch := range code {
|
|
if strings.ContainsRune(ambiguous, ch) {
|
|
t.Errorf("Code contains ambiguous character %q: %s", ch, code)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_CreatePendingAuth tests creating pending authorization
|
|
func TestDeviceStore_CreatePendingAuth(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
|
|
pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
|
|
if pending.DeviceCode == "" {
|
|
t.Error("DeviceCode should not be empty")
|
|
}
|
|
if pending.UserCode == "" {
|
|
t.Error("UserCode should not be empty")
|
|
}
|
|
if pending.DeviceName != "My Device" {
|
|
t.Errorf("DeviceName = %v, want My Device", pending.DeviceName)
|
|
}
|
|
if pending.IPAddress != "192.168.1.1" {
|
|
t.Errorf("IPAddress = %v, want 192.168.1.1", pending.IPAddress)
|
|
}
|
|
if pending.UserAgent != "Test Agent" {
|
|
t.Errorf("UserAgent = %v, want Test Agent", pending.UserAgent)
|
|
}
|
|
if pending.ExpiresAt.Before(time.Now()) {
|
|
t.Error("ExpiresAt should be in the future")
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_GetPendingByUserCode tests retrieving pending auth by user code
|
|
func TestDeviceStore_GetPendingByUserCode(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
|
|
// Create pending auth
|
|
created, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
userCode string
|
|
wantFound bool
|
|
}{
|
|
{
|
|
name: "existing user code",
|
|
userCode: created.UserCode,
|
|
wantFound: true,
|
|
},
|
|
{
|
|
name: "non-existent user code",
|
|
userCode: "AAAA-BBBB",
|
|
wantFound: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
pending, found := store.GetPendingByUserCode(tt.userCode)
|
|
if found != tt.wantFound {
|
|
t.Errorf("GetPendingByUserCode() found = %v, want %v", found, tt.wantFound)
|
|
}
|
|
if tt.wantFound && pending == nil {
|
|
t.Error("Expected pending auth, got nil")
|
|
}
|
|
if tt.wantFound && pending != nil {
|
|
if pending.DeviceName != "My Device" {
|
|
t.Errorf("DeviceName = %v, want My Device", pending.DeviceName)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_GetPendingByDeviceCode tests retrieving pending auth by device code
|
|
func TestDeviceStore_GetPendingByDeviceCode(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
|
|
// Create pending auth
|
|
created, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
deviceCode string
|
|
wantFound bool
|
|
}{
|
|
{
|
|
name: "existing device code",
|
|
deviceCode: created.DeviceCode,
|
|
wantFound: true,
|
|
},
|
|
{
|
|
name: "non-existent device code",
|
|
deviceCode: "invalidcode",
|
|
wantFound: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
pending, found := store.GetPendingByDeviceCode(tt.deviceCode)
|
|
if found != tt.wantFound {
|
|
t.Errorf("GetPendingByDeviceCode() found = %v, want %v", found, tt.wantFound)
|
|
}
|
|
if tt.wantFound && pending == nil {
|
|
t.Error("Expected pending auth, got nil")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_ApprovePending tests approving pending authorization
|
|
func TestDeviceStore_ApprovePending(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
|
|
// Create test users
|
|
createTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
createTestUser(t, store, "did:plc:bob123", "bob.bsky.social")
|
|
|
|
// Create pending auth
|
|
pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
userCode string
|
|
did string
|
|
handle string
|
|
wantErr bool
|
|
errString string
|
|
}{
|
|
{
|
|
name: "successful approval",
|
|
userCode: pending.UserCode,
|
|
did: "did:plc:alice123",
|
|
handle: "alice.bsky.social",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "non-existent user code",
|
|
userCode: "AAAA-BBBB",
|
|
did: "did:plc:bob123",
|
|
handle: "bob.bsky.social",
|
|
wantErr: true,
|
|
errString: "not found",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
secret, err := store.ApprovePending(tt.userCode, tt.did, tt.handle)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("ApprovePending() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !tt.wantErr {
|
|
if secret == "" {
|
|
t.Error("Expected device secret, got empty string")
|
|
}
|
|
if !strings.HasPrefix(secret, "atcr_device_") {
|
|
t.Errorf("Secret should start with atcr_device_, got %v", secret)
|
|
}
|
|
|
|
// Verify device was created
|
|
devices := store.ListDevices(tt.did)
|
|
if len(devices) != 1 {
|
|
t.Errorf("Expected 1 device, got %d", len(devices))
|
|
}
|
|
}
|
|
if tt.wantErr && tt.errString != "" && err != nil {
|
|
if !strings.Contains(err.Error(), tt.errString) {
|
|
t.Errorf("Error should contain %q, got %v", tt.errString, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_ApprovePending_AlreadyApproved tests double approval
|
|
func TestDeviceStore_ApprovePending_AlreadyApproved(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
createTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
|
|
// First approval
|
|
_, err = store.ApprovePending(pending.UserCode, "did:plc:alice123", "alice.bsky.social")
|
|
if err != nil {
|
|
t.Fatalf("First ApprovePending() error = %v", err)
|
|
}
|
|
|
|
// Second approval should fail
|
|
_, err = store.ApprovePending(pending.UserCode, "did:plc:alice123", "alice.bsky.social")
|
|
if err == nil {
|
|
t.Error("Expected error for double approval, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "already approved") {
|
|
t.Errorf("Error should contain 'already approved', got %v", err)
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_ValidateDeviceSecret tests device secret validation
|
|
func TestDeviceStore_ValidateDeviceSecret(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
createTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
// Create and approve a device
|
|
pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
|
|
secret, err := store.ApprovePending(pending.UserCode, "did:plc:alice123", "alice.bsky.social")
|
|
if err != nil {
|
|
t.Fatalf("ApprovePending() error = %v", err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
secret string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "valid secret",
|
|
secret: secret,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid secret",
|
|
secret: "atcr_device_invalid",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "empty secret",
|
|
secret: "",
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
device, err := store.ValidateDeviceSecret(tt.secret)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("ValidateDeviceSecret() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !tt.wantErr {
|
|
if device.DID != "did:plc:alice123" {
|
|
t.Errorf("DID = %v, want did:plc:alice123", device.DID)
|
|
}
|
|
if device.Name != "My Device" {
|
|
t.Errorf("Name = %v, want My Device", device.Name)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_ListDevices tests listing devices
|
|
func TestDeviceStore_ListDevices(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
did := "did:plc:alice123"
|
|
createTestUser(t, store, did, "alice.bsky.social")
|
|
|
|
// Initially empty
|
|
devices := store.ListDevices(did)
|
|
if len(devices) != 0 {
|
|
t.Errorf("Expected 0 devices initially, got %d", len(devices))
|
|
}
|
|
|
|
// Create 3 devices
|
|
for i := range 3 {
|
|
pending, err := store.CreatePendingAuth("Device "+string(rune('A'+i)), "192.168.1.1", "Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
_, err = store.ApprovePending(pending.UserCode, did, "alice.bsky.social")
|
|
if err != nil {
|
|
t.Fatalf("ApprovePending() error = %v", err)
|
|
}
|
|
}
|
|
|
|
// List devices
|
|
devices = store.ListDevices(did)
|
|
if len(devices) != 3 {
|
|
t.Errorf("Expected 3 devices, got %d", len(devices))
|
|
}
|
|
|
|
// Verify they're sorted by created_at DESC (newest first)
|
|
for i := range len(devices) - 1 {
|
|
if devices[i].CreatedAt.Before(devices[i+1].CreatedAt) {
|
|
t.Error("Devices should be sorted by created_at DESC")
|
|
}
|
|
}
|
|
|
|
// List devices for different DID
|
|
otherDevices := store.ListDevices("did:plc:bob123")
|
|
if len(otherDevices) != 0 {
|
|
t.Errorf("Expected 0 devices for different DID, got %d", len(otherDevices))
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_RevokeDevice tests revoking a device
|
|
func TestDeviceStore_RevokeDevice(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
did := "did:plc:alice123"
|
|
createTestUser(t, store, did, "alice.bsky.social")
|
|
|
|
// Create device
|
|
pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
_, err = store.ApprovePending(pending.UserCode, did, "alice.bsky.social")
|
|
if err != nil {
|
|
t.Fatalf("ApprovePending() error = %v", err)
|
|
}
|
|
|
|
devices := store.ListDevices(did)
|
|
if len(devices) != 1 {
|
|
t.Fatalf("Expected 1 device, got %d", len(devices))
|
|
}
|
|
deviceID := devices[0].ID
|
|
|
|
tests := []struct {
|
|
name string
|
|
did string
|
|
deviceID string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "successful revocation",
|
|
did: did,
|
|
deviceID: deviceID,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "non-existent device",
|
|
did: did,
|
|
deviceID: "non-existent-id",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "wrong DID",
|
|
did: "did:plc:bob123",
|
|
deviceID: deviceID,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := store.RevokeDevice(tt.did, tt.deviceID)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("RevokeDevice() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Verify device was removed (after first successful test)
|
|
devices = store.ListDevices(did)
|
|
if len(devices) != 0 {
|
|
t.Errorf("Expected 0 devices after revocation, got %d", len(devices))
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_UpdateLastUsed tests updating last used timestamp
|
|
func TestDeviceStore_UpdateLastUsed(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
createTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
// Create device
|
|
pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
secret, err := store.ApprovePending(pending.UserCode, "did:plc:alice123", "alice.bsky.social")
|
|
if err != nil {
|
|
t.Fatalf("ApprovePending() error = %v", err)
|
|
}
|
|
|
|
// Get device to get secret hash
|
|
device, err := store.ValidateDeviceSecret(secret)
|
|
if err != nil {
|
|
t.Fatalf("ValidateDeviceSecret() error = %v", err)
|
|
}
|
|
|
|
initialLastUsed := device.LastUsed
|
|
|
|
// Wait a bit to ensure timestamp difference
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
// Update last used
|
|
store.UpdateLastUsed(device.SecretHash)
|
|
|
|
// Verify it was updated
|
|
device2, err := store.ValidateDeviceSecret(secret)
|
|
if err != nil {
|
|
t.Fatalf("ValidateDeviceSecret() error = %v", err)
|
|
}
|
|
|
|
if !device2.LastUsed.After(initialLastUsed) {
|
|
t.Error("LastUsed should be updated to later time")
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_CleanupExpired tests cleanup of expired pending auths
|
|
func TestDeviceStore_CleanupExpired(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
|
|
// Create pending auth with manual expiration time
|
|
pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
|
|
// Manually update expiration to the past
|
|
_, err = store.db.Exec(`
|
|
UPDATE pending_device_auth
|
|
SET expires_at = datetime('now', '-1 hour')
|
|
WHERE device_code = ?
|
|
`, pending.DeviceCode)
|
|
if err != nil {
|
|
t.Fatalf("Failed to update expiration: %v", err)
|
|
}
|
|
|
|
// Run cleanup
|
|
store.CleanupExpired()
|
|
|
|
// Verify it was deleted
|
|
_, found := store.GetPendingByDeviceCode(pending.DeviceCode)
|
|
if found {
|
|
t.Error("Expired pending auth should have been cleaned up")
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_CleanupExpiredContext tests context-aware cleanup
|
|
func TestDeviceStore_CleanupExpiredContext(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
|
|
// Create and expire pending auth
|
|
pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
|
|
_, err = store.db.Exec(`
|
|
UPDATE pending_device_auth
|
|
SET expires_at = datetime('now', '-1 hour')
|
|
WHERE device_code = ?
|
|
`, pending.DeviceCode)
|
|
if err != nil {
|
|
t.Fatalf("Failed to update expiration: %v", err)
|
|
}
|
|
|
|
// Run context-aware cleanup
|
|
ctx := context.Background()
|
|
err = store.CleanupExpiredContext(ctx)
|
|
if err != nil {
|
|
t.Errorf("CleanupExpiredContext() error = %v", err)
|
|
}
|
|
|
|
// Verify it was deleted
|
|
_, found := store.GetPendingByDeviceCode(pending.DeviceCode)
|
|
if found {
|
|
t.Error("Expired pending auth should have been cleaned up")
|
|
}
|
|
}
|
|
|
|
// TestDeviceStore_SecretHashing tests bcrypt hashing
|
|
func TestDeviceStore_SecretHashing(t *testing.T) {
|
|
store := setupTestDB(t)
|
|
createTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
|
|
|
|
pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
|
|
if err != nil {
|
|
t.Fatalf("CreatePendingAuth() error = %v", err)
|
|
}
|
|
|
|
secret, err := store.ApprovePending(pending.UserCode, "did:plc:alice123", "alice.bsky.social")
|
|
if err != nil {
|
|
t.Fatalf("ApprovePending() error = %v", err)
|
|
}
|
|
|
|
// Get device via ValidateDeviceSecret to access secret hash
|
|
device, err := store.ValidateDeviceSecret(secret)
|
|
if err != nil {
|
|
t.Fatalf("ValidateDeviceSecret() error = %v", err)
|
|
}
|
|
|
|
// Verify bcrypt hash is valid
|
|
err = bcrypt.CompareHashAndPassword([]byte(device.SecretHash), []byte(secret))
|
|
if err != nil {
|
|
t.Error("Secret hash should match secret")
|
|
}
|
|
|
|
// Verify wrong secret doesn't match
|
|
err = bcrypt.CompareHashAndPassword([]byte(device.SecretHash), []byte("wrong_secret"))
|
|
if err == nil {
|
|
t.Error("Wrong secret should not match hash")
|
|
}
|
|
}
|