mirror of
https://tangled.org/evan.jarrett.net/at-container-registry
synced 2026-04-20 16:40:29 +00:00
435 lines
11 KiB
Go
435 lines
11 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"log/slog"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// Device represents an authorized device: one row of the devices table,
// created by ApprovePending once a user approves a pending authorization.
type Device struct {
	// ID is the row's primary key; ApprovePending assigns a random UUID string.
	ID         string    `json:"id"`
	// DID is the account identifier the device was authorized for.
	DID        string    `json:"did"`
	// Handle is the account handle recorded at approval time.
	Handle     string    `json:"handle"`
	// Name is the device name supplied when the pending authorization was created.
	Name       string    `json:"name"`
	// SecretHash is the bcrypt hash of the device secret (the plaintext secret
	// is never stored in devices). ListDevices does not select it, so it is
	// empty in that method's results.
	// NOTE(review): this field carries a JSON tag, so it is serialized if a
	// Device returned by ValidateDeviceSecret is ever marshaled; consider
	// `json:"-"` if no caller depends on it.
	SecretHash string    `json:"secret_hash"`
	// IPAddress is copied from the pending authorization that created the device.
	IPAddress  string    `json:"ip_address"`
	// Location is a nullable column; it is left empty when NULL. Nothing in
	// this file writes it.
	Location   string    `json:"location"`
	// UserAgent is copied from the pending authorization that created the device.
	UserAgent  string    `json:"user_agent"`
	// CreatedAt is the approval time set by ApprovePending.
	CreatedAt  time.Time `json:"created_at"`
	// LastUsed is the zero time until UpdateLastUsed first stamps the row
	// (the column is scanned as nullable).
	LastUsed   time.Time `json:"last_used"`
}
|
|
|
|
// PendingAuthorization represents a device awaiting user approval: one row of
// the pending_device_auth table, created by CreatePendingAuth and completed by
// ApprovePending.
type PendingAuthorization struct {
	// DeviceCode is the long random code (base64url of 32 random bytes) held
	// by the device; presumably what the device polls with via
	// GetPendingByDeviceCode — confirm against the HTTP handlers.
	DeviceCode   string     `json:"device_code"`
	// UserCode is the short XXXX-XXXX code the user enters to approve.
	UserCode     string     `json:"user_code"`
	// DeviceName is the name the device supplied; it becomes Device.Name.
	DeviceName   string     `json:"device_name"`
	// IPAddress of the request that started the flow.
	IPAddress    string     `json:"ip_address"`
	// UserAgent of the request that started the flow.
	UserAgent    string     `json:"user_agent"`
	// ExpiresAt is 10 minutes after creation; lookups treat later times as not found.
	ExpiresAt    time.Time  `json:"expires_at"`
	// ApprovedDID is nil (or empty) until ApprovePending runs.
	ApprovedDID  *string    `json:"approved_did"`
	// ApprovedAt is nil until ApprovePending runs.
	ApprovedAt   *time.Time `json:"approved_at"`
	// DeviceSecret is the plaintext device secret written by ApprovePending
	// (nil before approval). Note it is stored unhashed in pending_device_auth
	// until the row is cleaned up.
	DeviceSecret *string    `json:"device_secret"`
}
|
|
|
|
// DeviceStore persists authorized devices and in-flight device authorizations
// (the devices and pending_device_auth tables) through a SQLite *sql.DB.
type DeviceStore struct {
	db *sql.DB // handle supplied to NewDeviceStore
}

// NewDeviceStore returns a DeviceStore that issues its queries through db.
// The constructor neither opens nor closes the handle; the caller owns it.
func NewDeviceStore(db *sql.DB) *DeviceStore {
	store := new(DeviceStore)
	store.db = db
	return store
}
|
|
|
|
// CreatePendingAuth creates a new pending device authorization
|
|
func (s *DeviceStore) CreatePendingAuth(deviceName, ip, userAgent string) (*PendingAuthorization, error) {
|
|
// Generate device code (long, random)
|
|
deviceCodeBytes := make([]byte, 32)
|
|
if _, err := rand.Read(deviceCodeBytes); err != nil {
|
|
return nil, fmt.Errorf("failed to generate device code: %w", err)
|
|
}
|
|
deviceCode := base64.RawURLEncoding.EncodeToString(deviceCodeBytes)
|
|
|
|
// Generate user code (short, human-readable)
|
|
userCode := generateUserCode()
|
|
|
|
expiresAt := time.Now().Add(10 * time.Minute)
|
|
|
|
_, err := s.db.Exec(`
|
|
INSERT INTO pending_device_auth (device_code, user_code, device_name, ip_address, user_agent, expires_at, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, datetime('now'))
|
|
`, deviceCode, userCode, deviceName, ip, userAgent, expiresAt)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create pending auth: %w", err)
|
|
}
|
|
|
|
pending := &PendingAuthorization{
|
|
DeviceCode: deviceCode,
|
|
UserCode: userCode,
|
|
DeviceName: deviceName,
|
|
IPAddress: ip,
|
|
UserAgent: userAgent,
|
|
ExpiresAt: expiresAt,
|
|
}
|
|
|
|
return pending, nil
|
|
}
|
|
|
|
// GetPendingByUserCode retrieves a pending auth by user code
|
|
func (s *DeviceStore) GetPendingByUserCode(userCode string) (*PendingAuthorization, bool) {
|
|
var pending PendingAuthorization
|
|
|
|
err := s.db.QueryRow(`
|
|
SELECT device_code, user_code, device_name, ip_address, user_agent, expires_at, approved_did, approved_at, device_secret
|
|
FROM pending_device_auth
|
|
WHERE user_code = ?
|
|
`, userCode).Scan(
|
|
&pending.DeviceCode,
|
|
&pending.UserCode,
|
|
&pending.DeviceName,
|
|
&pending.IPAddress,
|
|
&pending.UserAgent,
|
|
&pending.ExpiresAt,
|
|
&pending.ApprovedDID,
|
|
&pending.ApprovedAt,
|
|
&pending.DeviceSecret,
|
|
)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return nil, false
|
|
}
|
|
if err != nil {
|
|
slog.Warn("Failed to query pending auth", "component", "device_store", "error", err)
|
|
return nil, false
|
|
}
|
|
|
|
// Check if expired
|
|
if time.Now().After(pending.ExpiresAt) {
|
|
return nil, false
|
|
}
|
|
|
|
return &pending, true
|
|
}
|
|
|
|
// GetPendingByDeviceCode retrieves a pending auth by device code
|
|
func (s *DeviceStore) GetPendingByDeviceCode(deviceCode string) (*PendingAuthorization, bool) {
|
|
var pending PendingAuthorization
|
|
|
|
err := s.db.QueryRow(`
|
|
SELECT device_code, user_code, device_name, ip_address, user_agent, expires_at, approved_did, approved_at, device_secret
|
|
FROM pending_device_auth
|
|
WHERE device_code = ?
|
|
`, deviceCode).Scan(
|
|
&pending.DeviceCode,
|
|
&pending.UserCode,
|
|
&pending.DeviceName,
|
|
&pending.IPAddress,
|
|
&pending.UserAgent,
|
|
&pending.ExpiresAt,
|
|
&pending.ApprovedDID,
|
|
&pending.ApprovedAt,
|
|
&pending.DeviceSecret,
|
|
)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return nil, false
|
|
}
|
|
if err != nil {
|
|
slog.Warn("Failed to query pending auth", "component", "device_store", "error", err)
|
|
return nil, false
|
|
}
|
|
|
|
// Check if expired
|
|
if time.Now().After(pending.ExpiresAt) {
|
|
return nil, false
|
|
}
|
|
|
|
return &pending, true
|
|
}
|
|
|
|
// ApprovePending approves a pending authorization and generates device secret
|
|
func (s *DeviceStore) ApprovePending(userCode, did, handle string) (deviceSecret string, err error) {
|
|
// Start transaction
|
|
tx, err := s.db.Begin()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to start transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Get pending auth
|
|
var pending PendingAuthorization
|
|
err = tx.QueryRow(`
|
|
SELECT device_code, user_code, device_name, ip_address, user_agent, expires_at, approved_did
|
|
FROM pending_device_auth
|
|
WHERE user_code = ?
|
|
`, userCode).Scan(
|
|
&pending.DeviceCode,
|
|
&pending.UserCode,
|
|
&pending.DeviceName,
|
|
&pending.IPAddress,
|
|
&pending.UserAgent,
|
|
&pending.ExpiresAt,
|
|
&pending.ApprovedDID,
|
|
)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return "", fmt.Errorf("pending authorization not found")
|
|
}
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to query pending auth: %w", err)
|
|
}
|
|
|
|
// Check expiration
|
|
if time.Now().After(pending.ExpiresAt) {
|
|
return "", fmt.Errorf("authorization expired")
|
|
}
|
|
|
|
// Check if already approved
|
|
if pending.ApprovedDID != nil && *pending.ApprovedDID != "" {
|
|
return "", fmt.Errorf("already approved")
|
|
}
|
|
|
|
// Generate device secret
|
|
secretBytes := make([]byte, 32)
|
|
if _, err := rand.Read(secretBytes); err != nil {
|
|
return "", fmt.Errorf("failed to generate device secret: %w", err)
|
|
}
|
|
deviceSecret = "atcr_device_" + base64.RawURLEncoding.EncodeToString(secretBytes)
|
|
|
|
// Hash for storage
|
|
secretHashBytes, err := bcrypt.GenerateFromPassword([]byte(deviceSecret), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to hash device secret: %w", err)
|
|
}
|
|
secretHash := string(secretHashBytes)
|
|
|
|
// Create device record
|
|
deviceID := uuid.New().String()
|
|
now := time.Now()
|
|
|
|
_, err = tx.Exec(`
|
|
INSERT INTO devices (id, did, handle, name, secret_hash, ip_address, user_agent, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
`, deviceID, did, handle, pending.DeviceName, secretHash, pending.IPAddress, pending.UserAgent, now)
|
|
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create device: %w", err)
|
|
}
|
|
|
|
// Update pending auth to mark as approved
|
|
_, err = tx.Exec(`
|
|
UPDATE pending_device_auth
|
|
SET approved_did = ?, approved_at = ?, device_secret = ?
|
|
WHERE user_code = ?
|
|
`, did, now, deviceSecret, userCode)
|
|
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to update pending auth: %w", err)
|
|
}
|
|
|
|
// Commit transaction
|
|
if err := tx.Commit(); err != nil {
|
|
return "", fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return deviceSecret, nil
|
|
}
|
|
|
|
// ValidateDeviceSecret validates a device secret and returns the device
|
|
func (s *DeviceStore) ValidateDeviceSecret(secret string) (*Device, error) {
|
|
// Query all devices and check bcrypt hash
|
|
rows, err := s.db.Query(`
|
|
SELECT id, did, handle, name, secret_hash, ip_address, location, user_agent, created_at, last_used
|
|
FROM devices
|
|
`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query devices: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var device Device
|
|
var lastUsed sql.NullTime
|
|
var location sql.NullString
|
|
|
|
err := rows.Scan(
|
|
&device.ID,
|
|
&device.DID,
|
|
&device.Handle,
|
|
&device.Name,
|
|
&device.SecretHash,
|
|
&device.IPAddress,
|
|
&location,
|
|
&device.UserAgent,
|
|
&device.CreatedAt,
|
|
&lastUsed,
|
|
)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if lastUsed.Valid {
|
|
device.LastUsed = lastUsed.Time
|
|
}
|
|
if location.Valid {
|
|
device.Location = location.String
|
|
}
|
|
|
|
// Check if this device's hash matches the secret
|
|
if err := bcrypt.CompareHashAndPassword([]byte(device.SecretHash), []byte(secret)); err == nil {
|
|
// Update last used asynchronously
|
|
go s.UpdateLastUsed(device.SecretHash)
|
|
|
|
return &device, nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("invalid device secret")
|
|
}
|
|
|
|
// ListDevices returns all devices for a DID
|
|
func (s *DeviceStore) ListDevices(did string) []*Device {
|
|
rows, err := s.db.Query(`
|
|
SELECT id, did, handle, name, ip_address, location, user_agent, created_at, last_used
|
|
FROM devices
|
|
WHERE did = ?
|
|
ORDER BY created_at DESC
|
|
`, did)
|
|
|
|
if err != nil {
|
|
return []*Device{}
|
|
}
|
|
defer rows.Close()
|
|
|
|
var devices []*Device
|
|
for rows.Next() {
|
|
var device Device
|
|
var lastUsed sql.NullTime
|
|
var location sql.NullString
|
|
|
|
err := rows.Scan(
|
|
&device.ID,
|
|
&device.DID,
|
|
&device.Handle,
|
|
&device.Name,
|
|
&device.IPAddress,
|
|
&location,
|
|
&device.UserAgent,
|
|
&device.CreatedAt,
|
|
&lastUsed,
|
|
)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if lastUsed.Valid {
|
|
device.LastUsed = lastUsed.Time
|
|
}
|
|
if location.Valid {
|
|
device.Location = location.String
|
|
}
|
|
|
|
devices = append(devices, &device)
|
|
}
|
|
|
|
return devices
|
|
}
|
|
|
|
// RevokeDevice removes a device
|
|
func (s *DeviceStore) RevokeDevice(did, deviceID string) error {
|
|
result, err := s.db.Exec(`
|
|
DELETE FROM devices
|
|
WHERE did = ? AND id = ?
|
|
`, did, deviceID)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to revoke device: %w", err)
|
|
}
|
|
|
|
rows, _ := result.RowsAffected()
|
|
if rows == 0 {
|
|
return fmt.Errorf("device not found")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateLastUsed updates the last used timestamp
|
|
func (s *DeviceStore) UpdateLastUsed(secretHash string) {
|
|
_, err := s.db.Exec(`
|
|
UPDATE devices
|
|
SET last_used = ?
|
|
WHERE secret_hash = ?
|
|
`, time.Now(), secretHash)
|
|
|
|
if err != nil {
|
|
slog.Warn("Failed to update device last used timestamp", "component", "device_store", "error", err)
|
|
}
|
|
}
|
|
|
|
// CleanupExpired removes expired pending authorizations
|
|
func (s *DeviceStore) CleanupExpired() {
|
|
result, err := s.db.Exec(`
|
|
DELETE FROM pending_device_auth
|
|
WHERE expires_at < datetime('now')
|
|
`)
|
|
|
|
if err != nil {
|
|
slog.Warn("Failed to cleanup expired pending auths", "component", "device_store", "error", err)
|
|
return
|
|
}
|
|
|
|
deleted, _ := result.RowsAffected()
|
|
if deleted > 0 {
|
|
slog.Info("Cleaned up expired pending device auths", "count", deleted)
|
|
}
|
|
}
|
|
|
|
// CleanupExpiredContext is a context-aware version for background workers
|
|
func (s *DeviceStore) CleanupExpiredContext(ctx context.Context) error {
|
|
result, err := s.db.ExecContext(ctx, `
|
|
DELETE FROM pending_device_auth
|
|
WHERE expires_at < datetime('now')
|
|
`)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to cleanup expired pending auths: %w", err)
|
|
}
|
|
|
|
deleted, _ := result.RowsAffected()
|
|
if deleted > 0 {
|
|
slog.Info("Cleaned up expired pending device auths", "count", deleted)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// generateUserCode creates a short, human-readable code for the user to type
// in when approving a device.
//
// Format: XXXX-XXXX (e.g., "WDJB-MJHT").
// Character set: A-Z and 2-9, excluding the easily confused 0, O, I, 1 and L.
// That gives 31 symbols, so the 8 symbols carry about 39.6 bits.
//
// Symbols are drawn uniformly via rejection sampling. A plain byte%31 would
// favour the first 256%31 = 8 symbols.
func generateUserCode() string {
	const alphabet = "ABCDEFGHJKMNPQRSTUVWXYZ23456789"
	const codeLen = 8
	// Largest multiple of len(alphabet) not exceeding 256; random bytes at or
	// above it are discarded so b%len(alphabet) is uniform.
	const rejectAt = 256 - 256%len(alphabet)

	code := make([]byte, 0, codeLen)
	buf := make([]byte, 2*codeLen) // ~3% rejection rate: one read almost always suffices
	for len(code) < codeLen {
		if _, err := rand.Read(buf); err != nil {
			// Fallback to timestamp-based generation if crypto rand fails, so
			// this function never errors. NOTE: such a code is predictable; it
			// exists only for the practically impossible case that crypto/rand
			// is unavailable.
			now := time.Now().UnixNano()
			for i := len(code); i < codeLen; i++ {
				code = append(code, alphabet[int(byte(now>>(i*8)))%len(alphabet)])
			}
			break
		}
		for _, b := range buf {
			if len(code) == codeLen {
				break
			}
			if int(b) < rejectAt {
				code = append(code, alphabet[int(b)%len(alphabet)])
			}
		}
	}

	return string(code[:4]) + "-" + string(code[4:])
}
|