Files
2025-10-28 20:39:57 -05:00

167 lines
4.4 KiB
Go

// Package db provides the database layer for the AppView web UI, including
// SQLite schema initialization, migrations, and query functions for OAuth
// sessions, device flows, repository metadata, stars, pull counts, and
// user profiles.
package db
import (
"database/sql"
"embed"
"fmt"
"io/fs"
"log/slog"
"path/filepath"
"sort"
"strconv"
"strings"
_ "github.com/mattn/go-sqlite3"
"go.yaml.in/yaml/v4"
)
// migrationsFS embeds every migrations/*.yaml file; loadMigrations selects
// the NNNN_name.yaml entries from it.
//
//go:embed migrations/*.yaml
var migrationsFS embed.FS

// schemaSQL is the contents of the embedded schema.sql file, executed by
// InitDB to create the base schema before any migrations run.
//
//go:embed schema.sql
var schemaSQL string
// InitDB opens the SQLite database at path, enables foreign-key enforcement,
// executes the embedded base schema, and, unless skipMigrations is true, runs
// any pending migrations.
//
// If any step after opening fails, the handle is closed before the error is
// returned. Callers therefore never receive, or leak, a half-initialized
// *sql.DB. On success the caller owns the returned handle and must Close it.
func InitDB(path string, skipMigrations bool) (*sql.DB, error) {
	db, err := sql.Open("sqlite3", path)
	if err != nil {
		return nil, fmt.Errorf("failed to open database %q: %w", path, err)
	}

	// closeAndFail releases the pool before reporting err so a failed
	// initialization does not leak it.
	closeAndFail := func(err error) (*sql.DB, error) {
		_ = db.Close() // best-effort: err is the failure worth reporting
		return nil, err
	}

	// NOTE(review): SQLite's foreign_keys setting is per connection. Issued
	// through db.Exec it applies only to whichever pooled connection runs this
	// statement; connections opened later by the pool start with the default.
	// Consider enabling it via the driver DSN or a connect hook (confirm the
	// driver's option name).
	if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
		return closeAndFail(fmt.Errorf("failed to enable foreign keys: %w", err))
	}

	// Create the base schema from the embedded SQL file.
	if _, err := db.Exec(schemaSQL); err != nil {
		return closeAndFail(fmt.Errorf("failed to create schema: %w", err))
	}

	if !skipMigrations {
		// runMigrations already wraps its errors with context.
		if err := runMigrations(db); err != nil {
			return closeAndFail(err)
		}
	}
	return db, nil
}
// Migration is one schema change loaded from an embedded migration file.
// Version and Name come from the filename; Description and Query come from
// the file's YAML body.
type Migration struct {
	// Version is the numeric filename prefix; migrations are applied in
	// ascending Version order.
	Version int
	// Name is the filename remainder with underscores turned into spaces.
	Name string
	// Description is a free-form summary read from the YAML "description" key.
	Description string `yaml:"description"`
	// Query is the SQL executed to apply the migration, read from the YAML
	// "query" key.
	Query string `yaml:"query"`
}
// runMigrations applies, in ascending version order, every embedded
// migration whose version is not yet recorded in schema_migrations. After a
// migration's query executes successfully, its version is inserted into
// schema_migrations so later runs skip it. The first error aborts the run.
//
// NOTE(review): the migration query and the schema_migrations insert are two
// separate db.Exec calls with no enclosing transaction. A failure between
// them leaves the change applied but unrecorded, and it would be re-run on
// the next start. Wrapping both in one transaction would fix this, provided
// no migration file issues its own BEGIN/COMMIT. Verify before changing.
func runMigrations(db *sql.DB) error {
	pending, err := loadMigrations()
	if err != nil {
		return fmt.Errorf("failed to load migrations: %w", err)
	}

	// Order by numeric version rather than relying on file order.
	sort.Slice(pending, func(a, b int) bool {
		return pending[a].Version < pending[b].Version
	})

	const (
		appliedQuery = "SELECT COUNT(*) FROM schema_migrations WHERE version = ?"
		recordQuery  = "INSERT INTO schema_migrations (version) VALUES (?)"
	)

	for _, mig := range pending {
		var applied int
		if err := db.QueryRow(appliedQuery, mig.Version).Scan(&applied); err != nil {
			return fmt.Errorf("failed to check migration status: %w", err)
		}
		if applied > 0 {
			continue // recorded by an earlier run
		}

		slog.Info("Applying migration", "version", mig.Version, "name", mig.Name, "description", mig.Description)

		if _, err := db.Exec(mig.Query); err != nil {
			return fmt.Errorf("failed to apply migration %d (%s): %w", mig.Version, mig.Name, err)
		}
		if _, err := db.Exec(recordQuery, mig.Version); err != nil {
			return fmt.Errorf("failed to record migration %d: %w", mig.Version, err)
		}

		slog.Info("Migration applied successfully", "version", mig.Version)
	}
	return nil
}
// loadMigrations reads every embedded file matching
// migrations/NNNN_name.yaml, decodes its YAML body into a Migration, and
// returns them in the order fs.Glob yields the files.
//
// Version and Name are always taken from the filename, overriding anything
// the YAML body may set. An error is returned in any of these cases:
//   - a filename cannot be parsed;
//   - a file cannot be read or decoded;
//   - a file's query is empty or whitespace-only;
//   - two files declare the same version number.
func loadMigrations() ([]Migration, error) {
	entries, err := fs.Glob(migrationsFS, "migrations/[0-9][0-9][0-9][0-9]_*.yaml")
	if err != nil {
		return nil, fmt.Errorf("failed to list migration files: %w", err)
	}

	migrations := make([]Migration, 0, len(entries))
	// seen maps each version to the file that claimed it, for duplicate detection.
	seen := make(map[int]string, len(entries))

	for _, file := range entries {
		basename := filepath.Base(file)
		version, name, err := parseMigrationFilename(basename)
		if err != nil {
			return nil, fmt.Errorf("invalid migration filename %s: %w", basename, err)
		}

		// Without this check a duplicated version is silently lost.
		// runMigrations records the version after applying the first file,
		// then skips the second as already applied.
		if prev, dup := seen[version]; dup {
			return nil, fmt.Errorf("duplicate migration version %d in %s and %s", version, prev, file)
		}
		seen[version] = file

		data, err := migrationsFS.ReadFile(file)
		if err != nil {
			return nil, fmt.Errorf("failed to read migration file %s: %w", file, err)
		}

		var m Migration
		if err := yaml.Unmarshal(data, &m); err != nil {
			return nil, fmt.Errorf("failed to parse migration file %s: %w", file, err)
		}

		// The filename is authoritative for identity.
		m.Version = version
		m.Name = name

		// A whitespace-only query would "apply" as a no-op and still be
		// recorded, so treat it the same as a missing one.
		if strings.TrimSpace(m.Query) == "" {
			return nil, fmt.Errorf("missing migration 'query' in %s", file)
		}

		migrations = append(migrations, m)
	}
	return migrations, nil
}
// parseMigrationFilename extracts the version and name from a migration
// filename of the form NNNN_migration_name.yaml. Any extension is stripped.
//
// The version is the run of decimal digits before the first underscore. The
// name is everything after it, with underscores replaced by spaces and
// surrounding whitespace trimmed. For example, "0001_add_stars.yaml" yields
// (1, "add stars").
//
// An error is returned in any of these cases:
//   - the filename has no underscore separator;
//   - the version part is empty or contains anything other than ASCII digits;
//   - the version does not fit in an int.
//
// Callers normally pre-filter with a glob. The function still validates its
// input itself rather than panicking on short or malformed names.
func parseMigrationFilename(filename string) (int, string, error) {
	stem := strings.TrimSuffix(filename, filepath.Ext(filename))

	versionPart, namePart, ok := strings.Cut(stem, "_")
	if !ok {
		return 0, "", fmt.Errorf("missing '_' separator between version and name")
	}
	// strconv.Atoi alone would accept a leading sign such as "+1" or "-1",
	// so require plain digits first.
	if versionPart == "" || strings.TrimLeft(versionPart, "0123456789") != "" {
		return 0, "", fmt.Errorf("version %q is not a decimal number", versionPart)
	}
	version, err := strconv.Atoi(versionPart)
	if err != nil {
		return 0, "", fmt.Errorf("parsing version %q: %w", versionPart, err)
	}

	name := strings.TrimSpace(strings.ReplaceAll(namePart, "_", " "))
	return version, name, nil
}