Files
at-container-registry/cmd/db-migrate/main.go

375 lines
9.7 KiB
Go

// db-migrate copies all tables and data from a local SQLite database to a
// remote libsql database (e.g. Bunny Database, Turso). It reads the schema
// from sqlite_master, creates tables on the remote, and inserts all rows
// in batches. Generic — works with any SQLite DB (appview, hold, etc.).
//
// Usage:
//
// go run ./cmd/db-migrate --local /path/to/local.db --remote "libsql://..." --token "..."
// go run ./cmd/db-migrate --local /path/to/local.db --remote "libsql://..." --token "..." --skip-existing
package main
import (
"database/sql"
"flag"
"fmt"
"log"
"os"
"strings"
"time"
_ "github.com/tursodatabase/go-libsql"
)
// main drives the migration: parse flags, open both databases, then run
// three phases — create all tables, copy the data in batches, create
// indexes. Any unrecoverable error aborts the process via log.Fatalf.
func main() {
	localPath := flag.String("local", "", "Path to local SQLite database file")
	remoteURL := flag.String("remote", "", "Remote libsql URL (libsql://...)")
	authToken := flag.String("token", "", "Auth token for remote database")
	skipExisting := flag.Bool("skip-existing", false, "Skip tables that already have data on remote")
	batchSize := flag.Int("batch-size", 100, "Number of rows per INSERT batch")
	dryRun := flag.Bool("dry-run", false, "Show what would be migrated without writing")
	flag.Parse()
	// --local, --remote and --token are all required.
	if *localPath == "" || *remoteURL == "" || *authToken == "" {
		flag.Usage()
		os.Exit(1)
	}
	// Open local database read-only so the source can never be modified;
	// the driver expects a file: URI.
	localDSN := *localPath
	if !strings.HasPrefix(localDSN, "file:") {
		localDSN = "file:" + localDSN
	}
	localDSN += "?mode=ro" // NOTE(review): assumes the path carries no query string of its own — confirm
	localDB, err := sql.Open("libsql", localDSN)
	if err != nil {
		log.Fatalf("Failed to open local database: %v", err)
	}
	defer localDB.Close()
	// Ping eagerly: sql.Open alone does not establish a connection.
	if err := localDB.Ping(); err != nil {
		log.Fatalf("Failed to ping local database: %v", err)
	}
	// Open remote database; the auth token is passed on the DSN query string.
	remoteDSN := fmt.Sprintf("%s?authToken=%s", *remoteURL, *authToken)
	remoteDB, err := sql.Open("libsql", remoteDSN)
	if err != nil {
		log.Fatalf("Failed to open remote database: %v", err)
	}
	defer remoteDB.Close()
	if err := remoteDB.Ping(); err != nil {
		log.Fatalf("Failed to ping remote database: %v", err)
	}
	// Get all user tables from local, already ordered so FK parent tables
	// come first (see getTables).
	tables, err := getTables(localDB)
	if err != nil {
		log.Fatalf("Failed to list tables: %v", err)
	}
	if len(tables) == 0 {
		log.Println("No tables found in local database")
		return
	}
	fmt.Printf("Found %d tables to migrate\n\n", len(tables))
	start := time.Now()
	if !*dryRun {
		// Phase 1: Create all tables first so FK references resolve
		fmt.Println("Creating tables...")
		for _, t := range tables {
			if err := createTable(remoteDB, t); err != nil {
				log.Fatalf("Failed to create table %s: %v", t.name, err)
			}
		}
		fmt.Println()
	}
	// Phase 2: Copy data (in dry-run mode migrateTable only counts rows).
	fmt.Println("Migrating data...")
	totalRows := 0
	for _, t := range tables {
		count, err := migrateTable(localDB, remoteDB, t, *batchSize, *skipExisting, *dryRun)
		if err != nil {
			log.Fatalf("Failed to migrate table %s: %v", t.name, err)
		}
		totalRows += count
	}
	if !*dryRun {
		// Phase 3: Create indexes after data is loaded (faster than indexing during insert)
		fmt.Println("\nCreating indexes...")
		for _, t := range tables {
			if err := createIndexes(localDB, remoteDB, t.name); err != nil {
				log.Fatalf("Failed to create indexes for %s: %v", t.name, err)
			}
		}
	}
	fmt.Printf("\nDone. %d total rows across %d tables in %s\n", totalRows, len(tables), time.Since(start).Round(time.Millisecond))
	if *dryRun {
		fmt.Println("(dry run — nothing was written)")
	}
}
// tableInfo pairs a user table's name with the CREATE TABLE statement
// recorded for it in sqlite_master.
type tableInfo struct {
	name string // table name as reported by sqlite_master
	ddl  string // original CREATE TABLE DDL; empty when sqlite_master stored NULL sql
}
// getTables returns every user table in db together with its stored DDL,
// ordered so that foreign-key parent tables precede the tables that
// reference them. Internal sqlite_*, _litestream_* and libsql_* tables
// are filtered out.
func getTables(db *sql.DB) ([]tableInfo, error) {
	rows, err := db.Query(`
SELECT name, sql FROM sqlite_master
WHERE type = 'table'
AND name NOT LIKE 'sqlite_%'
AND name NOT LIKE '_litestream_%'
AND name NOT LIKE 'libsql_%'
ORDER BY name
`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var result []tableInfo
	for rows.Next() {
		var (
			name string
			ddl  sql.NullString // sql may be NULL (e.g. shadow tables)
		)
		if err := rows.Scan(&name, &ddl); err != nil {
			return nil, err
		}
		// NullString.String is "" when the value was NULL, which is
		// exactly the "no DDL" sentinel tableInfo uses.
		result = append(result, tableInfo{name: name, ddl: ddl.String})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	// Order parents before children so tables referenced by foreign keys
	// exist (and are populated) before their dependents are inserted.
	return topoSortTables(db, result)
}
// topoSortTables orders tables so that referenced (parent) tables come before
// tables that reference them via foreign keys. Within independent groups the
// incoming (alphabetical) order is preserved. Reference cycles do not error:
// every table is still emitted exactly once, in an arbitrary cycle-breaking
// order.
func topoSortTables(db *sql.DB, tables []tableInfo) ([]tableInfo, error) {
	byName := make(map[string]tableInfo, len(tables))
	for _, t := range tables {
		byName[t.name] = t
	}
	// Build dependency graph: table -> distinct tables it references
	deps := make(map[string][]string)
	for _, t := range tables {
		fkRows, err := db.Query(fmt.Sprintf("PRAGMA foreign_key_list([%s])", t.name))
		if err != nil {
			// PRAGMA might not return rows for tables without FKs; a
			// failed query is treated as "no dependencies" (best-effort).
			continue
		}
		seen := make(map[string]bool)
		for fkRows.Next() {
			// PRAGMA foreign_key_list column layout:
			// id, seq, table, from, to, on_update, on_delete, match.
			var id, seq int
			var table, from, to, onUpdate, onDelete, match string
			// NOTE(review): the "to" column is NULL when an FK references a
			// table's implicit primary key; scanning NULL into a plain
			// string would error here — confirm all FKs name their columns.
			if err := fkRows.Scan(&id, &seq, &table, &from, &to, &onUpdate, &onDelete, &match); err != nil {
				fkRows.Close()
				return nil, err
			}
			if !seen[table] {
				deps[t.name] = append(deps[t.name], table)
				seen[table] = true
			}
		}
		fkRows.Close()
	}
	// Depth-first topological sort: emit each table only after everything
	// it depends on. (This is post-order DFS, not Kahn's in-degree
	// algorithm; `visited` also guarantees cycle termination.)
	visited := make(map[string]bool)
	var sorted []tableInfo
	var visit func(name string)
	visit = func(name string) {
		if visited[name] {
			return
		}
		visited[name] = true
		for _, dep := range deps[name] {
			visit(dep)
		}
		// Dependencies may name tables outside the filtered set; only
		// emit tables that are actually being migrated.
		if t, ok := byName[name]; ok {
			sorted = append(sorted, t)
		}
	}
	for _, t := range tables {
		visit(t.name)
	}
	return sorted, nil
}
// getIndexes returns the CREATE INDEX statements recorded in sqlite_master
// for tableName. Auto-created indexes (NULL sql, e.g. those backing UNIQUE
// or PRIMARY KEY constraints) are excluded.
func getIndexes(db *sql.DB, tableName string) ([]string, error) {
	idxRows, err := db.Query(`
SELECT sql FROM sqlite_master
WHERE type = 'index'
AND tbl_name = ?
AND sql IS NOT NULL
`, tableName)
	if err != nil {
		return nil, err
	}
	defer idxRows.Close()
	var ddls []string
	for idxRows.Next() {
		var stmt string
		if err := idxRows.Scan(&stmt); err != nil {
			return nil, err
		}
		ddls = append(ddls, stmt)
	}
	if err := idxRows.Err(); err != nil {
		return nil, err
	}
	return ddls, nil
}
// createTable creates table t on the remote database, rewriting the DDL to
// CREATE TABLE IF NOT EXISTS so the migration can be re-run safely.
// Tables with no stored DDL are skipped silently.
func createTable(remoteDB *sql.DB, t tableInfo) error {
	if t.ddl == "" {
		return nil
	}
	// sqlite_master stores the DDL exactly as originally typed, so match
	// the keywords case-insensitively. (The previous strings.Replace on
	// the uppercase literal missed lowercase "create table" DDL, which
	// then failed on re-runs against an existing table.) The insertion
	// preserves the DDL's original casing.
	ddl := t.ddl
	upper := strings.ToUpper(ddl)
	if !strings.Contains(upper, "IF NOT EXISTS") {
		if i := strings.Index(upper, "CREATE TABLE"); i >= 0 {
			pos := i + len("CREATE TABLE")
			ddl = ddl[:pos] + " IF NOT EXISTS" + ddl[pos:]
		}
	}
	if _, err := remoteDB.Exec(ddl); err != nil {
		return fmt.Errorf("create table %s: %w", t.name, err)
	}
	fmt.Printf(" %s\n", t.name)
	return nil
}
// createIndexes copies tableName's index definitions from the local DB to
// the remote, rewriting each statement to include IF NOT EXISTS so the
// migration can be re-run safely. Returns the first execution error.
func createIndexes(localDB, remoteDB *sql.DB, tableName string) error {
	indexes, err := getIndexes(localDB, tableName)
	if err != nil {
		return err
	}
	for _, idx := range indexes {
		ddl := idx
		// sqlite_master stores DDL verbatim as typed, so match keywords
		// case-insensitively. Inserting after "INDEX" covers both
		// CREATE INDEX and CREATE UNIQUE INDEX in any casing (the previous
		// uppercase-literal Replace missed lowercase DDL entirely). The
		// original casing of the statement is preserved.
		upper := strings.ToUpper(ddl)
		if !strings.Contains(upper, "IF NOT EXISTS") {
			if i := strings.Index(upper, "INDEX"); i >= 0 {
				pos := i + len("INDEX")
				ddl = ddl[:pos] + " IF NOT EXISTS" + ddl[pos:]
			}
		}
		if _, err := remoteDB.Exec(ddl); err != nil {
			return fmt.Errorf("create index on %s: %w", tableName, err)
		}
	}
	if len(indexes) > 0 {
		fmt.Printf(" %s: %d indexes\n", tableName, len(indexes))
	}
	return nil
}
// migrateTable copies every row of table t from localDB to remoteDB using
// multi-row INSERTs of at most batchSize rows each.
//
// Returns the number of rows written. Flags:
//   - dryRun: only count local rows and report; nothing is streamed or
//     written (the remote is never touched).
//   - skipExisting: if the remote table already holds any rows, skip the
//     table entirely and return 0.
func migrateTable(localDB, remoteDB *sql.DB, t tableInfo, batchSize int, skipExisting, dryRun bool) (int, error) {
	var localCount int
	if err := localDB.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM [%s]", t.name)).Scan(&localCount); err != nil {
		return 0, fmt.Errorf("count local rows: %w", err)
	}
	// Nothing to copy.
	if localCount == 0 {
		fmt.Printf(" %-30s %6d rows (empty)\n", t.name, 0)
		return 0, nil
	}
	if dryRun {
		fmt.Printf(" %-30s %6d rows (would migrate)\n", t.name, localCount)
		return localCount, nil
	}
	if skipExisting {
		var remoteCount int
		if err := remoteDB.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM [%s]", t.name)).Scan(&remoteCount); err != nil {
			return 0, fmt.Errorf("count remote rows: %w", err)
		}
		if remoteCount > 0 {
			fmt.Printf(" %-30s %6d rows (skipped, %d on remote)\n", t.name, localCount, remoteCount)
			return 0, nil
		}
	}
	// Stream all rows from the local table. The column order reported by
	// rows.Columns() drives both the INSERT column list and the scan
	// targets, so values always line up.
	rows, err := localDB.Query(fmt.Sprintf("SELECT * FROM [%s]", t.name))
	if err != nil {
		return 0, fmt.Errorf("select: %w", err)
	}
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return 0, fmt.Errorf("columns: %w", err)
	}
	// Build the reusable INSERT fragments once per table:
	// "INSERT INTO [t] ([c1], [c2]) VALUES " and "(?, ?)".
	placeholders := make([]string, len(cols))
	quotedCols := make([]string, len(cols))
	for i, c := range cols {
		placeholders[i] = "?"
		quotedCols[i] = fmt.Sprintf("[%s]", c)
	}
	insertPrefix := fmt.Sprintf("INSERT INTO [%s] (%s) VALUES ", t.name, strings.Join(quotedCols, ", "))
	rowPlaceholder := "(" + strings.Join(placeholders, ", ") + ")"
	inserted := 0
	batch := make([][]any, 0, batchSize)
	for rows.Next() {
		// Scan into []any through pointer indirection; the driver picks
		// the concrete Go type for each column value.
		vals := make([]any, len(cols))
		ptrs := make([]any, len(cols))
		for i := range vals {
			ptrs[i] = &vals[i]
		}
		if err := rows.Scan(ptrs...); err != nil {
			return 0, fmt.Errorf("scan: %w", err)
		}
		batch = append(batch, vals)
		// Flush a full batch; truncating to [:0] reuses the backing array.
		if len(batch) >= batchSize {
			if err := insertBatch(remoteDB, insertPrefix, rowPlaceholder, batch); err != nil {
				return 0, fmt.Errorf("insert batch at row %d: %w", inserted, err)
			}
			inserted += len(batch)
			batch = batch[:0]
		}
	}
	// Flush the final partial batch, if any.
	if len(batch) > 0 {
		if err := insertBatch(remoteDB, insertPrefix, rowPlaceholder, batch); err != nil {
			return 0, fmt.Errorf("insert final batch: %w", err)
		}
		inserted += len(batch)
	}
	// Surface any iteration error that terminated rows.Next() early.
	if err := rows.Err(); err != nil {
		return 0, fmt.Errorf("rows iteration: %w", err)
	}
	fmt.Printf(" %-30s %6d rows migrated\n", t.name, inserted)
	return inserted, nil
}
// insertBatch executes one multi-row INSERT: prefix is the
// "INSERT INTO [t] ([cols]) VALUES " fragment, rowPlaceholder is the
// "(?, ?, ...)" group repeated once per row, and batch holds each row's
// values in column order. An empty or nil batch is a no-op.
func insertBatch(db *sql.DB, prefix, rowPlaceholder string, batch [][]any) error {
	if len(batch) == 0 {
		return nil
	}
	// Pre-size both slices: one placeholder group per row and
	// rows×columns args total (all rows share the same column count).
	placeholders := make([]string, len(batch))
	args := make([]any, 0, len(batch)*len(batch[0]))
	for i, row := range batch {
		placeholders[i] = rowPlaceholder
		args = append(args, row...)
	}
	query := prefix + strings.Join(placeholders, ", ")
	_, err := db.Exec(query, args...)
	return err
}