220 lines
5.8 KiB
Go
220 lines
5.8 KiB
Go
// Package db provides the database layer for the AppView web UI, including
|
|
// SQLite schema initialization, migrations, and query functions for OAuth
|
|
// sessions, device flows, repository metadata, stars, pull counts, and
|
|
// user profiles.
|
|
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"embed"
|
|
"fmt"
|
|
"io/fs"
|
|
"log/slog"
|
|
"path/filepath"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"go.yaml.in/yaml/v4"
|
|
)
|
|
|
|
// migrationsFS holds the *.yaml files from the migrations/ directory,
// embedded into the binary at build time; loadMigrations reads from it.
//go:embed migrations/*.yaml
var migrationsFS embed.FS
|
|
|
|
// schemaSQL is the contents of schema.sql, embedded at build time. InitDB
// executes it on every call, including against an existing database, so it
// must be safe to re-run (presumably CREATE ... IF NOT EXISTS — confirm).
//go:embed schema.sql
var schemaSQL string
|
|
|
|
// InitDB initializes the SQLite database with the schema
|
|
func InitDB(path string, skipMigrations bool) (*sql.DB, error) {
|
|
db, err := sql.Open("sqlite3", path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Enable foreign keys
|
|
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create schema from embedded SQL file
|
|
if _, err := db.Exec(schemaSQL); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Run migrations unless skipped
|
|
if !skipMigrations {
|
|
if err := runMigrations(db); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
// Migration represents a single database migration loaded from an embedded
// migrations/NNNN_name.yaml file.
type Migration struct {
	// Version is set by loadMigrations from the filename's 4-digit prefix.
	// Migrations are applied in ascending Version order, and Version is the
	// value recorded in schema_migrations.
	Version int
	// Name is set by loadMigrations from the rest of the filename, with
	// underscores replaced by spaces; used in log and error messages.
	Name string
	// Description comes from the YAML "description" key; informational.
	Description string `yaml:"description"`
	// Query comes from the YAML "query" key: one or more SQL statements
	// separated by semicolons.
	Query string `yaml:"query"`
}
|
|
|
|
// runMigrations applies any pending database migrations
|
|
func runMigrations(db *sql.DB) error {
|
|
// Load migrations from files
|
|
migrations, err := loadMigrations()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load migrations: %w", err)
|
|
}
|
|
|
|
// Sort migrations by version
|
|
sort.Slice(migrations, func(i, j int) bool {
|
|
return migrations[i].Version < migrations[j].Version
|
|
})
|
|
|
|
for _, m := range migrations {
|
|
// Check if migration already applied
|
|
var count int
|
|
err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = ?", m.Version).Scan(&count)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check migration status: %w", err)
|
|
}
|
|
|
|
if count > 0 {
|
|
// Migration already applied
|
|
continue
|
|
}
|
|
|
|
// Apply migration in a transaction
|
|
slog.Info("Applying migration", "version", m.Version, "name", m.Name, "description", m.Description)
|
|
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction for migration %d: %w", m.Version, err)
|
|
}
|
|
|
|
// Split query into individual statements and execute each
|
|
// go-sqlite3's Exec() doesn't reliably execute all statements in multi-statement queries
|
|
statements := splitSQLStatements(m.Query)
|
|
for i, stmt := range statements {
|
|
if _, err := tx.Exec(stmt); err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to apply migration %d (%s) statement %d: %w", m.Version, m.Name, i+1, err)
|
|
}
|
|
}
|
|
|
|
// Record migration
|
|
if _, err := tx.Exec("INSERT INTO schema_migrations (version) VALUES (?)", m.Version); err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to record migration %d: %w", m.Version, err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit migration %d: %w", m.Version, err)
|
|
}
|
|
|
|
slog.Info("Migration applied successfully", "version", m.Version)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// loadMigrations loads all migration files from embedded filesystem
|
|
func loadMigrations() ([]Migration, error) {
|
|
// Read all migration files from embedded FS
|
|
entries, err := fs.Glob(migrationsFS, "migrations/[0-9][0-9][0-9][0-9]_*.yaml")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list migration files: %w", err)
|
|
}
|
|
|
|
var migrations []Migration
|
|
for _, file := range entries {
|
|
// Parse version and name from filename
|
|
basename := filepath.Base(file)
|
|
version, name, err := parseMigrationFilename(basename)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid migration filename %s: %w", basename, err)
|
|
}
|
|
|
|
// Read YAML content from embedded FS
|
|
data, err := migrationsFS.ReadFile(file)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read migration file %s: %w", file, err)
|
|
}
|
|
|
|
var m Migration
|
|
if err := yaml.Unmarshal(data, &m); err != nil {
|
|
return nil, fmt.Errorf("failed to parse migration file %s: %w", file, err)
|
|
}
|
|
|
|
// Set version and name from filename
|
|
m.Version = version
|
|
m.Name = name
|
|
|
|
// Validate migration
|
|
if m.Query == "" {
|
|
return nil, fmt.Errorf("missing migration 'query' in %s", file)
|
|
}
|
|
|
|
migrations = append(migrations, m)
|
|
}
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
// splitSQLStatements splits a SQL script into individual statements at
// top-level semicolons.
//
// Semicolons inside single-quoted string literals, double-quoted, backtick-
// or bracket-quoted identifiers, "--" line comments, and "/* */" block
// comments are not treated as separators. Each returned statement is trimmed
// of surrounding whitespace and keeps any comments it contains. Pieces made
// up only of whitespace and comments (e.g. after a trailing semicolon) are
// dropped. An empty or comment-only script yields a nil slice.
//
// Limitation: semicolons inside CREATE TRIGGER ... BEGIN ... END bodies are
// still treated as separators, so trigger definitions are not split correctly.
func splitSQLStatements(query string) []string {
	var statements []string
	start := 0       // byte offset where the current statement begins
	hasCode := false // whether the current statement has non-comment content

	// emit closes the statement query[start:end] and begins the next one just
	// past end (the separator's position).
	emit := func(end int) {
		if hasCode {
			if stmt := strings.TrimSpace(query[start:end]); stmt != "" {
				statements = append(statements, stmt)
			}
		}
		start, hasCode = end+1, false
	}

	// skipPast returns the index of the last byte of the first occurrence of
	// closer at or after from (or the final index if it never occurs), so the
	// loop's i++ resumes just after the closing delimiter.
	skipPast := func(from int, closer string) int {
		if j := strings.Index(query[from:], closer); j >= 0 {
			return from + j + len(closer) - 1
		}
		return len(query) - 1
	}

	for i := 0; i < len(query); i++ {
		c := query[i]
		var next byte
		if i+1 < len(query) {
			next = query[i+1]
		}
		switch {
		case c == ';':
			emit(i)
		case c == '-' && next == '-':
			i = skipPast(i+2, "\n")
		case c == '/' && next == '*':
			i = skipPast(i+2, "*/")
		case c == '\'', c == '"', c == '`':
			// A doubled quote ('') closes and immediately reopens the
			// literal, so escapes need no special handling.
			hasCode = true
			i = skipPast(i+1, string(c))
		case c == '[':
			hasCode = true
			i = skipPast(i+1, "]")
		case c == ' ', c == '\t', c == '\n', c == '\r', c == '\f', c == '\v':
			// Whitespace alone does not make a statement.
		default:
			hasCode = true
		}
	}
	emit(len(query))

	return statements
}
|
|
|
|
// parseMigrationFilename extracts the version and display name from a
// migration filename of the form NNNN_migration_name.yaml (or .yml).
//
// The version is the leading four decimal digits. The name is the text after
// the following underscore, with the extension removed, underscores replaced
// by spaces, and surrounding whitespace trimmed; it may be empty. For
// example, "0007_add_user_profiles.yaml" yields (7, "add user profiles", nil).
//
// It returns an error, rather than panicking, when the name is shorter than
// the NNNN_ prefix, lacks the underscore separator, or has a non-digit
// version, so it is safe on names that did not come from the loadMigrations
// glob.
func parseMigrationFilename(filename string) (int, string, error) {
	const versionDigits = 4

	stem := strings.TrimSuffix(filename, filepath.Ext(filename))
	if len(stem) <= versionDigits || stem[versionDigits] != '_' {
		return 0, "", fmt.Errorf("expected NNNN_name, got %q", filename)
	}

	digits := stem[:versionDigits]
	// Checked explicitly because strconv.Atoi would also accept a sign.
	for i := 0; i < len(digits); i++ {
		if digits[i] < '0' || digits[i] > '9' {
			return 0, "", fmt.Errorf("version %q is not a %d-digit number", digits, versionDigits)
		}
	}
	version, err := strconv.Atoi(digits)
	if err != nil {
		return 0, "", fmt.Errorf("parse version %q: %w", digits, err)
	}

	name := strings.TrimSpace(strings.ReplaceAll(stem[versionDigits+1:], "_", " "))
	return version, name, nil
}
|