Files
at-container-registry/pkg/appview/db/device_store_test.go
2026-02-09 23:19:01 -06:00

631 lines
16 KiB
Go

package db
import (
"context"
"fmt"
"strings"
"testing"
"time"
"golang.org/x/crypto/bcrypt"
)
// setupTestDB returns a DeviceStore backed by an in-memory SQLite database
// whose name is derived from the running test's name, so separate tests (and
// subtests) do not share a database. The connection pool is closed through
// t.Cleanup when the test finishes.
func setupTestDB(t *testing.T) *DeviceStore {
	t.Helper()

	// Subtest names contain "/", which must not appear in the DSN file name.
	dbName := strings.ReplaceAll(t.Name(), "/", "_")
	dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", dbName)

	conn, err := InitDB(dsn, LibsqlConfig{})
	if err != nil {
		t.Fatalf("Failed to initialize test database: %v", err)
	}
	t.Cleanup(func() {
		conn.Close()
	})

	// A single open connection serializes access and avoids races in tests.
	conn.SetMaxOpenConns(1)

	return NewDeviceStore(conn)
}
// createTestUser inserts a users row for did/handle with a fixed example PDS
// endpoint and last_seen set to the database's current time. Because the
// statement is INSERT OR IGNORE, an already-existing row is left untouched.
func createTestUser(t *testing.T, store *DeviceStore, did, handle string) {
	t.Helper()

	const insertUserSQL = `
INSERT OR IGNORE INTO users (did, handle, pds_endpoint, last_seen)
VALUES (?, ?, ?, datetime('now'))
`
	const pdsEndpoint = "https://pds.example.com"

	if _, err := store.db.Exec(insertUserSQL, did, handle, pdsEndpoint); err != nil {
		t.Fatalf("Failed to create test user: %v", err)
	}
}
// TestDevice_Struct checks that a Device literal keeps the DID it was built with.
func TestDevice_Struct(t *testing.T) {
	const wantDID = "did:plc:test"

	d := Device{
		DID:       wantDID,
		Handle:    "alice.bsky.social",
		Name:      "My Device",
		CreatedAt: time.Now(),
	}

	if got := d.DID; got != wantDID {
		t.Errorf("Expected DID, got %q", got)
	}
}
func TestGenerateUserCode(t *testing.T) {
// Generate multiple codes to test
codes := make(map[string]bool)
for range 100 {
code := generateUserCode()
// Test format: XXXX-XXXX
if len(code) != 9 {
t.Errorf("Expected code length 9, got %d for code %q", len(code), code)
}
if code[4] != '-' {
t.Errorf("Expected hyphen at position 4, got %q", string(code[4]))
}
// Test valid characters (A-Z, 2-9, no ambiguous chars)
validChars := "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
parts := strings.Split(code, "-")
if len(parts) != 2 {
t.Errorf("Expected 2 parts separated by hyphen, got %d", len(parts))
}
for _, part := range parts {
for _, ch := range part {
if !strings.ContainsRune(validChars, ch) {
t.Errorf("Invalid character %q in code %q", ch, code)
}
}
}
// Test uniqueness (should be very rare to get duplicates)
if codes[code] {
t.Logf("Warning: duplicate code generated: %q (rare but possible)", code)
}
codes[code] = true
}
// Verify we got mostly unique codes (at least 95%)
if len(codes) < 95 {
t.Errorf("Expected at least 95 unique codes out of 100, got %d", len(codes))
}
}
// TestGenerateUserCode_Format checks a single generated code. It must be
// exactly 9 bytes long, have a hyphen at index 4, and contain none of the
// characters O, 0, 1, I and L.
// NOTE(review): TestGenerateUserCode's allowed alphabet includes 'L'. If
// generateUserCode can emit 'L', this test fails intermittently — confirm.
func TestGenerateUserCode_Format(t *testing.T) {
	code := generateUserCode()

	// Fatal stops the test, so code[4] is only read when the length is 9.
	switch {
	case len(code) != 9:
		t.Fatal("Code must be exactly 9 characters")
	case code[4] != '-':
		t.Fatal("Character at index 4 must be hyphen")
	}

	const ambiguous = "O01IL"
	for _, r := range code {
		if strings.ContainsRune(ambiguous, r) {
			t.Errorf("Code contains ambiguous character %q: %s", r, code)
		}
	}
}
// TestDeviceStore_CreatePendingAuth checks that CreatePendingAuth returns a
// record with non-empty device and user codes. The record must also echo back
// the supplied device name, IP address and user agent, and must not already
// be expired.
func TestDeviceStore_CreatePendingAuth(t *testing.T) {
	const (
		deviceName = "My Device"
		ipAddr     = "192.168.1.1"
		userAgent  = "Test Agent"
	)
	store := setupTestDB(t)

	p, err := store.CreatePendingAuth(deviceName, ipAddr, userAgent)
	if err != nil {
		t.Fatalf("CreatePendingAuth() error = %v", err)
	}

	if p.DeviceCode == "" {
		t.Error("DeviceCode should not be empty")
	}
	if p.UserCode == "" {
		t.Error("UserCode should not be empty")
	}
	if p.DeviceName != deviceName {
		t.Errorf("DeviceName = %v, want My Device", p.DeviceName)
	}
	if p.IPAddress != ipAddr {
		t.Errorf("IPAddress = %v, want 192.168.1.1", p.IPAddress)
	}
	if p.UserAgent != userAgent {
		t.Errorf("UserAgent = %v, want Test Agent", p.UserAgent)
	}

	now := time.Now()
	if p.ExpiresAt.Before(now) {
		t.Error("ExpiresAt should be in the future")
	}
}
// TestDeviceStore_GetPendingByUserCode checks that GetPendingByUserCode finds
// a freshly created pending authorization by its user code, with the original
// device name, and reports "not found" for an unknown code.
func TestDeviceStore_GetPendingByUserCode(t *testing.T) {
	store := setupTestDB(t)

	created, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
	if err != nil {
		t.Fatalf("CreatePendingAuth() error = %v", err)
	}

	cases := []struct {
		name      string
		userCode  string
		wantFound bool
	}{
		{"existing user code", created.UserCode, true},
		{"non-existent user code", "AAAA-BBBB", false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, ok := store.GetPendingByUserCode(tc.userCode)
			if ok != tc.wantFound {
				t.Errorf("GetPendingByUserCode() found = %v, want %v", ok, tc.wantFound)
			}
			if !tc.wantFound {
				return
			}
			if got == nil {
				t.Error("Expected pending auth, got nil")
				return
			}
			if got.DeviceName != "My Device" {
				t.Errorf("DeviceName = %v, want My Device", got.DeviceName)
			}
		})
	}
}
// TestDeviceStore_GetPendingByDeviceCode checks that GetPendingByDeviceCode
// finds a freshly created pending authorization by its device code and
// reports "not found" for an unknown code.
func TestDeviceStore_GetPendingByDeviceCode(t *testing.T) {
	store := setupTestDB(t)

	created, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
	if err != nil {
		t.Fatalf("CreatePendingAuth() error = %v", err)
	}

	cases := []struct {
		name       string
		deviceCode string
		wantFound  bool
	}{
		{"existing device code", created.DeviceCode, true},
		{"non-existent device code", "invalidcode", false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, ok := store.GetPendingByDeviceCode(tc.deviceCode)
			if ok != tc.wantFound {
				t.Errorf("GetPendingByDeviceCode() found = %v, want %v", ok, tc.wantFound)
			}
			if tc.wantFound && got == nil {
				t.Error("Expected pending auth, got nil")
			}
		})
	}
}
// TestDeviceStore_ApprovePending checks two things:
//   - Approving a valid user code returns an "atcr_device_"-prefixed secret
//     and leaves the approving DID with exactly one device.
//   - Approving an unknown user code fails with an error mentioning
//     "not found".
func TestDeviceStore_ApprovePending(t *testing.T) {
	store := setupTestDB(t)

	for _, u := range []struct{ did, handle string }{
		{"did:plc:alice123", "alice.bsky.social"},
		{"did:plc:bob123", "bob.bsky.social"},
	} {
		createTestUser(t, store, u.did, u.handle)
	}

	pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
	if err != nil {
		t.Fatalf("CreatePendingAuth() error = %v", err)
	}

	cases := []struct {
		name      string
		userCode  string
		did       string
		handle    string
		wantErr   bool
		errString string
	}{
		{
			name:     "successful approval",
			userCode: pending.UserCode,
			did:      "did:plc:alice123",
			handle:   "alice.bsky.social",
		},
		{
			name:      "non-existent user code",
			userCode:  "AAAA-BBBB",
			did:       "did:plc:bob123",
			handle:    "bob.bsky.social",
			wantErr:   true,
			errString: "not found",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			secret, err := store.ApprovePending(tc.userCode, tc.did, tc.handle)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ApprovePending() error = %v, wantErr %v", err, tc.wantErr)
				return
			}

			if tc.wantErr {
				if tc.errString != "" && err != nil && !strings.Contains(err.Error(), tc.errString) {
					t.Errorf("Error should contain %q, got %v", tc.errString, err)
				}
				return
			}

			if secret == "" {
				t.Error("Expected device secret, got empty string")
			}
			if !strings.HasPrefix(secret, "atcr_device_") {
				t.Errorf("Secret should start with atcr_device_, got %v", secret)
			}
			// The approval must have produced exactly one device for this DID.
			if n := len(store.ListDevices(tc.did)); n != 1 {
				t.Errorf("Expected 1 device, got %d", n)
			}
		})
	}
}
// TestDeviceStore_ApprovePending_AlreadyApproved checks that approving the same
// user code twice fails the second time with an "already approved" error.
func TestDeviceStore_ApprovePending_AlreadyApproved(t *testing.T) {
	const (
		did    = "did:plc:alice123"
		handle = "alice.bsky.social"
	)
	store := setupTestDB(t)
	createTestUser(t, store, did, handle)

	pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
	if err != nil {
		t.Fatalf("CreatePendingAuth() error = %v", err)
	}

	// First approval
	if _, err := store.ApprovePending(pending.UserCode, did, handle); err != nil {
		t.Fatalf("First ApprovePending() error = %v", err)
	}

	// Second approval should fail
	_, err = store.ApprovePending(pending.UserCode, did, handle)
	if err == nil {
		// Fatal rather than Error: the check below calls err.Error(), which
		// would panic on a nil error.
		t.Fatal("Expected error for double approval, got nil")
	}
	if !strings.Contains(err.Error(), "already approved") {
		t.Errorf("Error should contain 'already approved', got %v", err)
	}
}
// TestDeviceStore_ValidateDeviceSecret checks that the secret returned by
// ApprovePending resolves to the approved device (correct DID and name). It
// also checks that a made-up secret and an empty secret are both rejected.
func TestDeviceStore_ValidateDeviceSecret(t *testing.T) {
	const did = "did:plc:alice123"
	store := setupTestDB(t)
	createTestUser(t, store, did, "alice.bsky.social")

	// Create and approve a device to obtain a real secret.
	pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
	if err != nil {
		t.Fatalf("CreatePendingAuth() error = %v", err)
	}
	issued, err := store.ApprovePending(pending.UserCode, did, "alice.bsky.social")
	if err != nil {
		t.Fatalf("ApprovePending() error = %v", err)
	}

	cases := []struct {
		name    string
		secret  string
		wantErr bool
	}{
		{"valid secret", issued, false},
		{"invalid secret", "atcr_device_invalid", true},
		{"empty secret", "", true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			dev, err := store.ValidateDeviceSecret(tc.secret)
			if (err != nil) != tc.wantErr {
				t.Errorf("ValidateDeviceSecret() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if tc.wantErr {
				return
			}
			if dev.DID != did {
				t.Errorf("DID = %v, want did:plc:alice123", dev.DID)
			}
			if dev.Name != "My Device" {
				t.Errorf("Name = %v, want My Device", dev.Name)
			}
		})
	}
}
// TestDeviceStore_ListDevices checks three things:
//   - ListDevices returns nothing for a user with no devices.
//   - It returns all three devices approved for that user, with no device
//     older than the one that follows it (created_at DESC).
//   - It returns nothing for an unrelated DID.
func TestDeviceStore_ListDevices(t *testing.T) {
	const did = "did:plc:alice123"
	store := setupTestDB(t)
	createTestUser(t, store, did, "alice.bsky.social")

	if n := len(store.ListDevices(did)); n != 0 {
		t.Errorf("Expected 0 devices initially, got %d", n)
	}

	// Approve "Device A", "Device B" and "Device C" for the same user.
	for _, letter := range "ABC" {
		pending, err := store.CreatePendingAuth("Device "+string(letter), "192.168.1.1", "Agent")
		if err != nil {
			t.Fatalf("CreatePendingAuth() error = %v", err)
		}
		if _, err := store.ApprovePending(pending.UserCode, did, "alice.bsky.social"); err != nil {
			t.Fatalf("ApprovePending() error = %v", err)
		}
	}

	devices := store.ListDevices(did)
	if len(devices) != 3 {
		t.Errorf("Expected 3 devices, got %d", len(devices))
	}

	// Newest first: each device must not be older than the one after it.
	for i := 1; i < len(devices); i++ {
		if devices[i-1].CreatedAt.Before(devices[i].CreatedAt) {
			t.Error("Devices should be sorted by created_at DESC")
		}
	}

	if n := len(store.ListDevices("did:plc:bob123")); n != 0 {
		t.Errorf("Expected 0 devices for different DID, got %d", n)
	}
}
// TestDeviceStore_RevokeDevice checks that RevokeDevice rejects a DID that
// does not own the device and rejects an unknown device ID. It then checks
// that the owning DID can revoke the device, after which ListDevices no
// longer returns it.
func TestDeviceStore_RevokeDevice(t *testing.T) {
	store := setupTestDB(t)
	did := "did:plc:alice123"
	createTestUser(t, store, did, "alice.bsky.social")

	// Create device
	pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
	if err != nil {
		t.Fatalf("CreatePendingAuth() error = %v", err)
	}
	_, err = store.ApprovePending(pending.UserCode, did, "alice.bsky.social")
	if err != nil {
		t.Fatalf("ApprovePending() error = %v", err)
	}

	devices := store.ListDevices(did)
	if len(devices) != 1 {
		t.Fatalf("Expected 1 device, got %d", len(devices))
	}
	deviceID := devices[0].ID

	// The subtests run in order against the same store. The failure cases
	// come first, while the device still exists. That way "wrong DID"
	// exercises the ownership check, instead of failing only because the
	// device was already deleted. wantRemaining is alice's device count
	// expected after each call.
	tests := []struct {
		name          string
		did           string
		deviceID      string
		wantErr       bool
		wantRemaining int
	}{
		{
			name:          "wrong DID",
			did:           "did:plc:bob123",
			deviceID:      deviceID,
			wantErr:       true,
			wantRemaining: 1,
		},
		{
			name:          "non-existent device",
			did:           did,
			deviceID:      "non-existent-id",
			wantErr:       true,
			wantRemaining: 1,
		},
		{
			name:          "successful revocation",
			did:           did,
			deviceID:      deviceID,
			wantErr:       false,
			wantRemaining: 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := store.RevokeDevice(tt.did, tt.deviceID)
			if (err != nil) != tt.wantErr {
				t.Errorf("RevokeDevice() error = %v, wantErr %v", err, tt.wantErr)
			}
			if got := len(store.ListDevices(did)); got != tt.wantRemaining {
				t.Errorf("Expected %d devices after RevokeDevice(%q, %q), got %d",
					tt.wantRemaining, tt.did, tt.deviceID, got)
			}
		})
	}
}
// TestDeviceStore_UpdateLastUsed approves a device and records its LastUsed
// value. It then calls UpdateLastUsed with the device's secret hash and expects
// a later lookup to report a strictly later LastUsed.
func TestDeviceStore_UpdateLastUsed(t *testing.T) {
	const did, handle = "did:plc:alice123", "alice.bsky.social"
	store := setupTestDB(t)
	createTestUser(t, store, did, handle)

	pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
	if err != nil {
		t.Fatalf("CreatePendingAuth() error = %v", err)
	}
	secret, err := store.ApprovePending(pending.UserCode, did, handle)
	if err != nil {
		t.Fatalf("ApprovePending() error = %v", err)
	}

	// ValidateDeviceSecret is used here to obtain the stored secret hash.
	before, err := store.ValidateDeviceSecret(secret)
	if err != nil {
		t.Fatalf("ValidateDeviceSecret() error = %v", err)
	}
	baseline := before.LastUsed

	// Give the clock time to advance so the update is distinguishable.
	// NOTE(review): if last_used is stored with only second precision, 10ms
	// may not be enough — confirm against the UpdateLastUsed implementation.
	time.Sleep(10 * time.Millisecond)

	store.UpdateLastUsed(before.SecretHash)

	after, err := store.ValidateDeviceSecret(secret)
	if err != nil {
		t.Fatalf("ValidateDeviceSecret() error = %v", err)
	}
	if !after.LastUsed.After(baseline) {
		t.Error("LastUsed should be updated to later time")
	}
}
// TestDeviceStore_CleanupExpired moves a pending authorization's expires_at
// one hour into the past and checks that CleanupExpired then removes it.
func TestDeviceStore_CleanupExpired(t *testing.T) {
	const expireSQL = `
UPDATE pending_device_auth
SET expires_at = datetime('now', '-1 hour')
WHERE device_code = ?
`
	store := setupTestDB(t)

	pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
	if err != nil {
		t.Fatalf("CreatePendingAuth() error = %v", err)
	}

	// Backdate the expiration directly in the table.
	if _, err := store.db.Exec(expireSQL, pending.DeviceCode); err != nil {
		t.Fatalf("Failed to update expiration: %v", err)
	}

	store.CleanupExpired()

	if _, found := store.GetPendingByDeviceCode(pending.DeviceCode); found {
		t.Error("Expired pending auth should have been cleaned up")
	}
}
// TestDeviceStore_CleanupExpiredContext moves a pending authorization's
// expires_at one hour into the past. It then checks that CleanupExpiredContext
// succeeds and removes the record.
func TestDeviceStore_CleanupExpiredContext(t *testing.T) {
	const expireSQL = `
UPDATE pending_device_auth
SET expires_at = datetime('now', '-1 hour')
WHERE device_code = ?
`
	store := setupTestDB(t)

	pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
	if err != nil {
		t.Fatalf("CreatePendingAuth() error = %v", err)
	}
	if _, err := store.db.Exec(expireSQL, pending.DeviceCode); err != nil {
		t.Fatalf("Failed to update expiration: %v", err)
	}

	if err := store.CleanupExpiredContext(context.Background()); err != nil {
		t.Errorf("CleanupExpiredContext() error = %v", err)
	}

	if _, found := store.GetPendingByDeviceCode(pending.DeviceCode); found {
		t.Error("Expired pending auth should have been cleaned up")
	}
}
// TestDeviceStore_SecretHashing checks that the SecretHash stored for an
// approved device verifies against the issued secret with bcrypt and does not
// verify against an unrelated secret.
func TestDeviceStore_SecretHashing(t *testing.T) {
	const did, handle = "did:plc:alice123", "alice.bsky.social"
	store := setupTestDB(t)
	createTestUser(t, store, did, handle)

	pending, err := store.CreatePendingAuth("My Device", "192.168.1.1", "Test Agent")
	if err != nil {
		t.Fatalf("CreatePendingAuth() error = %v", err)
	}
	secret, err := store.ApprovePending(pending.UserCode, did, handle)
	if err != nil {
		t.Fatalf("ApprovePending() error = %v", err)
	}

	// ValidateDeviceSecret is used here to read back the stored secret hash.
	device, err := store.ValidateDeviceSecret(secret)
	if err != nil {
		t.Fatalf("ValidateDeviceSecret() error = %v", err)
	}

	hash := []byte(device.SecretHash)
	if bcrypt.CompareHashAndPassword(hash, []byte(secret)) != nil {
		t.Error("Secret hash should match secret")
	}
	if bcrypt.CompareHashAndPassword(hash, []byte("wrong_secret")) == nil {
		t.Error("Wrong secret should not match hash")
	}
}