Files
at-container-registry/pkg/appview/db/device_store.go
2026-01-04 22:02:01 -06:00

435 lines
11 KiB
Go

package db
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
// Device represents an authorized device: a credential issued to a DID by
// ApprovePending, persisted in the devices table with only a bcrypt hash of
// its secret.
type Device struct {
	// ID is a random UUID assigned when the device is approved.
	ID string `json:"id"`
	// DID and Handle identify the account that approved the device.
	DID    string `json:"did"`
	Handle string `json:"handle"`
	// Name is the device name supplied with the pending authorization request.
	Name string `json:"name"`
	// SecretHash is the bcrypt hash of the device secret. Among the queries in
	// this store only ValidateDeviceSecret selects it; ListDevices leaves it empty.
	// NOTE(review): this is serialized as "secret_hash"; if Device values are ever
	// written to API responses consider `json:"-"` — confirm with callers first.
	SecretHash string `json:"secret_hash"`
	// IPAddress and UserAgent are copied from the pending authorization.
	IPAddress string `json:"ip_address"`
	// Location is empty when the database column is NULL.
	Location  string    `json:"location"`
	UserAgent string    `json:"user_agent"`
	CreatedAt time.Time `json:"created_at"`
	// LastUsed is the zero time when last_used is NULL (device never validated).
	LastUsed time.Time `json:"last_used"`
}
// PendingAuthorization represents a device awaiting user approval, stored in the
// pending_device_auth table. It is addressable by two codes: DeviceCode (32
// random bytes, unpadded base64url) and UserCode (short, in the form XXXX-XXXX).
type PendingAuthorization struct {
	DeviceCode string `json:"device_code"`
	UserCode   string `json:"user_code"`
	DeviceName string `json:"device_name"`
	IPAddress  string `json:"ip_address"`
	UserAgent  string `json:"user_agent"`
	// ExpiresAt is 10 minutes after creation; lookups treat later times as not found.
	ExpiresAt time.Time `json:"expires_at"`
	// ApprovedDID, ApprovedAt and DeviceSecret are nil (NULL columns) until
	// ApprovePending marks the row approved.
	ApprovedDID *string    `json:"approved_did"`
	ApprovedAt  *time.Time `json:"approved_at"`
	// DeviceSecret is the plaintext secret written by ApprovePending, presumably
	// so the polling device can collect it — confirm against the handler.
	DeviceSecret *string `json:"device_secret"`
}
// DeviceStore manages devices and pending authorizations with SQLite persistence.
// Its methods read and write the devices and pending_device_auth tables.
type DeviceStore struct {
	// db is the shared database handle used by every method.
	db *sql.DB
}
// NewDeviceStore returns a DeviceStore that issues all of its queries through
// the supplied database handle, which is stored as-is.
func NewDeviceStore(db *sql.DB) *DeviceStore {
	store := new(DeviceStore)
	store.db = db
	return store
}
// CreatePendingAuth creates a new pending device authorization and stores it in
// pending_device_auth.
//
// It generates a long random device code (32 bytes, unpadded base64url) and a
// short human-readable user code. The authorization expires 10 minutes after
// creation.
//
// expires_at is written in UTC. The cleanup queries compare the stored value
// against SQLite's datetime('now'), which is UTC text. Drivers typically persist
// time.Time as text carrying the value's own offset, so a local-time value would
// be judged by its local wall clock and expire early or late on non-UTC hosts.
//
// NOTE(review): user-code collisions are not retried here; whether a duplicate
// user code fails the INSERT depends on the table's constraints — confirm.
func (s *DeviceStore) CreatePendingAuth(deviceName, ip, userAgent string) (*PendingAuthorization, error) {
	const ttl = 10 * time.Minute

	// Device code: long and random, used by the device to poll for approval.
	deviceCodeBytes := make([]byte, 32)
	if _, err := rand.Read(deviceCodeBytes); err != nil {
		return nil, fmt.Errorf("failed to generate device code: %w", err)
	}
	deviceCode := base64.RawURLEncoding.EncodeToString(deviceCodeBytes)

	// User code: short, for a human to type.
	userCode := generateUserCode()

	expiresAt := time.Now().UTC().Add(ttl)

	_, err := s.db.Exec(`
INSERT INTO pending_device_auth (device_code, user_code, device_name, ip_address, user_agent, expires_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, datetime('now'))
`, deviceCode, userCode, deviceName, ip, userAgent, expiresAt)
	if err != nil {
		return nil, fmt.Errorf("failed to create pending auth: %w", err)
	}

	return &PendingAuthorization{
		DeviceCode: deviceCode,
		UserCode:   userCode,
		DeviceName: deviceName,
		IPAddress:  ip,
		UserAgent:  userAgent,
		ExpiresAt:  expiresAt,
	}, nil
}
// GetPendingByUserCode looks up a pending authorization by the short code the
// user enters.
//
// The boolean is false when no row matches, when the row's expires_at has
// passed, or when the query fails (the failure is logged). Approved rows are not
// filtered out: an approved but unexpired authorization is returned with
// ApprovedDID, ApprovedAt and DeviceSecret populated.
func (s *DeviceStore) GetPendingByUserCode(userCode string) (*PendingAuthorization, bool) {
	p := new(PendingAuthorization)
	row := s.db.QueryRow(`
SELECT device_code, user_code, device_name, ip_address, user_agent, expires_at, approved_did, approved_at, device_secret
FROM pending_device_auth
WHERE user_code = ?
`, userCode)

	switch err := row.Scan(
		&p.DeviceCode,
		&p.UserCode,
		&p.DeviceName,
		&p.IPAddress,
		&p.UserAgent,
		&p.ExpiresAt,
		&p.ApprovedDID,
		&p.ApprovedAt,
		&p.DeviceSecret,
	); {
	case err == sql.ErrNoRows:
		return nil, false
	case err != nil:
		slog.Warn("Failed to query pending auth", "component", "device_store", "error", err)
		return nil, false
	case p.ExpiresAt.Before(time.Now()):
		// Expired rows are reported as absent even before cleanup deletes them.
		return nil, false
	default:
		return p, true
	}
}
// GetPendingByDeviceCode looks up a pending authorization by the long random
// code held by the device.
//
// It reports false for a missing row, an expired row, or a query failure, and
// logs only the last case. Approval columns (ApprovedDID, ApprovedAt,
// DeviceSecret) are returned as stored, so a caller can see whether the user
// has approved the request yet.
func (s *DeviceStore) GetPendingByDeviceCode(deviceCode string) (*PendingAuthorization, bool) {
	row := s.db.QueryRow(`
SELECT device_code, user_code, device_name, ip_address, user_agent, expires_at, approved_did, approved_at, device_secret
FROM pending_device_auth
WHERE device_code = ?
`, deviceCode)

	result := &PendingAuthorization{}
	dest := []any{
		&result.DeviceCode,
		&result.UserCode,
		&result.DeviceName,
		&result.IPAddress,
		&result.UserAgent,
		&result.ExpiresAt,
		&result.ApprovedDID,
		&result.ApprovedAt,
		&result.DeviceSecret,
	}
	if scanErr := row.Scan(dest...); scanErr != nil {
		// A missing row is an expected outcome; anything else is worth a warning.
		if scanErr != sql.ErrNoRows {
			slog.Warn("Failed to query pending auth", "component", "device_store", "error", scanErr)
		}
		return nil, false
	}

	if expired := time.Now().After(result.ExpiresAt); expired {
		return nil, false
	}
	return result, true
}
// ApprovePending approves the pending authorization identified by userCode on
// behalf of did/handle. It creates the device record and returns the newly
// generated plaintext device secret.
//
// All writes happen in one transaction. Either the device row is inserted and
// the pending row is marked approved (approved_did, approved_at, device_secret),
// or neither change is kept. The devices table stores only a bcrypt hash of the
// secret; the plaintext is written to the pending row.
//
// Returned errors are "pending authorization not found", "authorization
// expired", "already approved", or a wrapped database/crypto error.
func (s *DeviceStore) ApprovePending(userCode, did, handle string) (deviceSecret string, err error) {
	tx, err := s.db.Begin()
	if err != nil {
		return "", fmt.Errorf("failed to start transaction: %w", err)
	}
	// Rollback is a harmless no-op once Commit has succeeded.
	defer tx.Rollback()

	var pending PendingAuthorization
	err = tx.QueryRow(`
SELECT device_code, user_code, device_name, ip_address, user_agent, expires_at, approved_did
FROM pending_device_auth
WHERE user_code = ?
`, userCode).Scan(
		&pending.DeviceCode,
		&pending.UserCode,
		&pending.DeviceName,
		&pending.IPAddress,
		&pending.UserAgent,
		&pending.ExpiresAt,
		&pending.ApprovedDID,
	)
	if err == sql.ErrNoRows {
		return "", fmt.Errorf("pending authorization not found")
	}
	if err != nil {
		return "", fmt.Errorf("failed to query pending auth: %w", err)
	}

	if time.Now().After(pending.ExpiresAt) {
		return "", fmt.Errorf("authorization expired")
	}
	if pending.ApprovedDID != nil && *pending.ApprovedDID != "" {
		return "", fmt.Errorf("already approved")
	}

	// Device secret: a fixed prefix plus 32 random bytes (43 base64url chars),
	// 55 bytes in total, which is within bcrypt's 72-byte input limit.
	secretBytes := make([]byte, 32)
	if _, err := rand.Read(secretBytes); err != nil {
		return "", fmt.Errorf("failed to generate device secret: %w", err)
	}
	deviceSecret = "atcr_device_" + base64.RawURLEncoding.EncodeToString(secretBytes)

	secretHashBytes, err := bcrypt.GenerateFromPassword([]byte(deviceSecret), bcrypt.DefaultCost)
	if err != nil {
		return "", fmt.Errorf("failed to hash device secret: %w", err)
	}
	secretHash := string(secretHashBytes)

	deviceID := uuid.New().String()
	// Timestamps are written in UTC, matching SQLite's datetime('now').
	now := time.Now().UTC()
	_, err = tx.Exec(`
INSERT INTO devices (id, did, handle, name, secret_hash, ip_address, user_agent, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`, deviceID, did, handle, pending.DeviceName, secretHash, pending.IPAddress, pending.UserAgent, now)
	if err != nil {
		return "", fmt.Errorf("failed to create device: %w", err)
	}

	// The approval check above is a separate read. Guarding the write on the
	// same condition keeps a concurrent approval from being overwritten with a
	// second secret, even if the isolation level does not serialize the two.
	result, err := tx.Exec(`
UPDATE pending_device_auth
SET approved_did = ?, approved_at = ?, device_secret = ?
WHERE user_code = ? AND (approved_did IS NULL OR approved_did = '')
`, did, now, deviceSecret, userCode)
	if err != nil {
		return "", fmt.Errorf("failed to update pending auth: %w", err)
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return "", fmt.Errorf("failed to update pending auth: %w", err)
	}
	if affected == 0 {
		// Someone approved it between our read and write; the deferred
		// rollback discards the device row inserted above.
		return "", fmt.Errorf("already approved")
	}

	if err := tx.Commit(); err != nil {
		return "", fmt.Errorf("failed to commit transaction: %w", err)
	}
	return deviceSecret, nil
}
// ValidateDeviceSecret returns the device whose stored bcrypt hash matches
// secret. The error is "invalid device secret" when no device matches.
//
// There is no indexed lookup: every device row is loaded and compared with
// bcrypt until one matches. A rejected secret therefore costs one bcrypt
// comparison per stored device.
// NOTE(review): bcrypt is deliberately slow, so invalid-secret requests get more
// expensive as devices accumulate. Consider storing a lookup identifier next to
// the hash.
//
// On a match, last_used is refreshed in a background goroutine (best effort).
// A failure while iterating rows is returned as a wrapped error rather than
// being reported as an invalid secret.
func (s *DeviceStore) ValidateDeviceSecret(secret string) (*Device, error) {
	// Issued secrets are never empty, so an empty one cannot match any hash.
	// Reject it before paying for a full bcrypt scan.
	if secret == "" {
		return nil, fmt.Errorf("invalid device secret")
	}

	rows, err := s.db.Query(`
SELECT id, did, handle, name, secret_hash, ip_address, location, user_agent, created_at, last_used
FROM devices
`)
	if err != nil {
		return nil, fmt.Errorf("failed to query devices: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var device Device
		var lastUsed sql.NullTime
		var location sql.NullString
		if err := rows.Scan(
			&device.ID,
			&device.DID,
			&device.Handle,
			&device.Name,
			&device.SecretHash,
			&device.IPAddress,
			&location,
			&device.UserAgent,
			&device.CreatedAt,
			&lastUsed,
		); err != nil {
			// One unreadable row must not block other devices, but it should be visible.
			slog.Warn("Failed to scan device row", "component", "device_store", "error", err)
			continue
		}
		if lastUsed.Valid {
			device.LastUsed = lastUsed.Time
		}
		if location.Valid {
			device.Location = location.String
		}

		if err := bcrypt.CompareHashAndPassword([]byte(device.SecretHash), []byte(secret)); err == nil {
			go s.UpdateLastUsed(device.SecretHash)
			return &device, nil
		}
	}
	// Without this check, a mid-iteration DB failure would masquerade as a wrong secret.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate devices: %w", err)
	}
	return nil, fmt.Errorf("invalid device secret")
}
// ListDevices returns every device registered to did, newest first
// (ORDER BY created_at DESC). secret_hash is not selected, so SecretHash is
// always empty in the results.
//
// The result is never nil. An empty slice is returned when the DID has no
// devices and when the query fails, so callers that JSON-encode it always get
// an array. Query, scan and iteration failures are logged; rows that fail to
// scan are skipped.
func (s *DeviceStore) ListDevices(did string) []*Device {
	devices := []*Device{}

	rows, err := s.db.Query(`
SELECT id, did, handle, name, ip_address, location, user_agent, created_at, last_used
FROM devices
WHERE did = ?
ORDER BY created_at DESC
`, did)
	if err != nil {
		slog.Warn("Failed to list devices", "component", "device_store", "error", err)
		return devices
	}
	defer rows.Close()

	for rows.Next() {
		var device Device
		var lastUsed sql.NullTime
		var location sql.NullString
		if err := rows.Scan(
			&device.ID,
			&device.DID,
			&device.Handle,
			&device.Name,
			&device.IPAddress,
			&location,
			&device.UserAgent,
			&device.CreatedAt,
			&lastUsed,
		); err != nil {
			slog.Warn("Failed to scan device row", "component", "device_store", "error", err)
			continue
		}
		if lastUsed.Valid {
			device.LastUsed = lastUsed.Time
		}
		if location.Valid {
			device.Location = location.String
		}
		devices = append(devices, &device)
	}
	// Return whatever was read, but do not let a truncated listing go unnoticed.
	if err := rows.Err(); err != nil {
		slog.Warn("Failed to iterate devices", "component", "device_store", "error", err)
	}
	return devices
}
// RevokeDevice deletes the device with ID deviceID. The delete is scoped to did,
// so a device belonging to another account is treated as missing.
//
// It returns an error with the message "device not found" when no row matched,
// and a wrapped error if the delete or the affected-row count fails.
func (s *DeviceStore) RevokeDevice(did, deviceID string) error {
	result, err := s.db.Exec(`
DELETE FROM devices
WHERE did = ? AND id = ?
`, did, deviceID)
	if err != nil {
		return fmt.Errorf("failed to revoke device: %w", err)
	}

	// An unknown row count must not be reported as "device not found".
	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to confirm device revocation: %w", err)
	}
	if rows == 0 {
		return fmt.Errorf("device not found")
	}
	return nil
}
// UpdateLastUsed sets last_used to the current time on every device row whose
// secret_hash equals secretHash. It is best effort: a failed update is logged,
// not returned.
func (s *DeviceStore) UpdateLastUsed(secretHash string) {
	now := time.Now()
	if _, err := s.db.Exec(`
UPDATE devices
SET last_used = ?
WHERE secret_hash = ?
`, now, secretHash); err != nil {
		slog.Warn("Failed to update device last used timestamp", "component", "device_store", "error", err)
	}
}
// CleanupExpired removes expired pending authorizations. It is the
// fire-and-forget form of CleanupExpiredContext: it runs without a deadline and
// logs, rather than returns, any failure.
//
// It delegates to CleanupExpiredContext so there is one copy of the DELETE
// statement, and the two entry points cannot drift apart.
func (s *DeviceStore) CleanupExpired() {
	if err := s.CleanupExpiredContext(context.Background()); err != nil {
		slog.Warn("Failed to cleanup expired pending auths", "component", "device_store", "error", err)
	}
}
// CleanupExpiredContext deletes pending authorizations whose expires_at is
// earlier than SQLite's datetime('now'). ctx is passed to ExecContext, so a
// background worker can cancel or time-bound the delete.
//
// It returns a wrapped error if the delete fails. When rows were removed, it
// logs how many.
func (s *DeviceStore) CleanupExpiredContext(ctx context.Context) error {
	res, execErr := s.db.ExecContext(ctx, `
DELETE FROM pending_device_auth
WHERE expires_at < datetime('now')
`)
	if execErr != nil {
		return fmt.Errorf("failed to cleanup expired pending auths: %w", execErr)
	}

	// The count only feeds the log line, so an error fetching it is ignored.
	if n, _ := res.RowsAffected(); n > 0 {
		slog.Info("Cleaned up expired pending device auths", "count", n)
	}
	return nil
}
// generateUserCode creates a short, human-readable code in the form XXXX-XXXX
// (e.g. "WDJB-MJHT").
//
// The 31-character alphabet is the uppercase letters plus the digits 2-9, minus
// characters that are easily confused: 0, O, 1, I and L. Characters are drawn
// uniformly using rejection sampling.
func generateUserCode() string {
	const chars = "ABCDEFGHJKMNPQRSTUVWXYZ23456789"
	const codeLen = 8
	// limit is the largest multiple of len(chars) that fits in a byte (248).
	// Bytes at or above it are discarded. A plain b%len(chars) over all 256
	// values would favour the first 256%31 = 8 characters.
	const limit = 256 - 256%len(chars)

	code := make([]byte, 0, codeLen)
	buf := make([]byte, codeLen)
	for len(code) < codeLen {
		if _, err := rand.Read(buf); err != nil {
			// Fallback kept from the original implementation. It is predictable,
			// but it only applies on toolchains where crypto/rand.Read can return
			// an error at all.
			now := time.Now().UnixNano()
			for i := len(code); i < codeLen; i++ {
				code = append(code, chars[int(byte(now>>(i*8)))%len(chars)])
			}
			break
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			code = append(code, chars[int(b)%len(chars)])
			if len(code) == codeLen {
				break
			}
		}
	}
	return string(code[:4]) + "-" + string(code[4:])
}