mirror of
https://tangled.org/evan.jarrett.net/at-container-registry
synced 2026-04-20 16:40:29 +00:00
375 lines
9.7 KiB
Go
375 lines
9.7 KiB
Go
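// db-migrate copies all user tables, their rows, and their indexes from a
// local SQLite database to a remote libsql database (e.g. Bunny Database,
// Turso). It reads the schema from sqlite_master, creates the tables on the
// remote (tables referenced by foreign keys first), inserts all rows in
// batches, and finally recreates the indexes. Views and triggers are not
// copied. It is generic and works with any SQLite database (appview, hold,
// etc.).
//
// Usage:
//
//	go run ./cmd/db-migrate --local /path/to/local.db --remote "libsql://..." --token "..."
//	go run ./cmd/db-migrate --local /path/to/local.db --remote "libsql://..." --token "..." --skip-existing
//	go run ./cmd/db-migrate --local /path/to/local.db --remote "libsql://..." --token "..." --dry-run
//
// Use --batch-size to change the number of rows per INSERT (default 100).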
|
|
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/tursodatabase/go-libsql"
|
|
)
|
|
|
|
func main() {
|
|
localPath := flag.String("local", "", "Path to local SQLite database file")
|
|
remoteURL := flag.String("remote", "", "Remote libsql URL (libsql://...)")
|
|
authToken := flag.String("token", "", "Auth token for remote database")
|
|
skipExisting := flag.Bool("skip-existing", false, "Skip tables that already have data on remote")
|
|
batchSize := flag.Int("batch-size", 100, "Number of rows per INSERT batch")
|
|
dryRun := flag.Bool("dry-run", false, "Show what would be migrated without writing")
|
|
flag.Parse()
|
|
|
|
if *localPath == "" || *remoteURL == "" || *authToken == "" {
|
|
flag.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Open local database read-only
|
|
localDSN := *localPath
|
|
if !strings.HasPrefix(localDSN, "file:") {
|
|
localDSN = "file:" + localDSN
|
|
}
|
|
localDSN += "?mode=ro"
|
|
|
|
localDB, err := sql.Open("libsql", localDSN)
|
|
if err != nil {
|
|
log.Fatalf("Failed to open local database: %v", err)
|
|
}
|
|
defer localDB.Close()
|
|
|
|
if err := localDB.Ping(); err != nil {
|
|
log.Fatalf("Failed to ping local database: %v", err)
|
|
}
|
|
|
|
// Open remote database
|
|
remoteDSN := fmt.Sprintf("%s?authToken=%s", *remoteURL, *authToken)
|
|
remoteDB, err := sql.Open("libsql", remoteDSN)
|
|
if err != nil {
|
|
log.Fatalf("Failed to open remote database: %v", err)
|
|
}
|
|
defer remoteDB.Close()
|
|
|
|
if err := remoteDB.Ping(); err != nil {
|
|
log.Fatalf("Failed to ping remote database: %v", err)
|
|
}
|
|
// Get all user tables from local
|
|
tables, err := getTables(localDB)
|
|
if err != nil {
|
|
log.Fatalf("Failed to list tables: %v", err)
|
|
}
|
|
|
|
if len(tables) == 0 {
|
|
log.Println("No tables found in local database")
|
|
return
|
|
}
|
|
|
|
fmt.Printf("Found %d tables to migrate\n\n", len(tables))
|
|
|
|
start := time.Now()
|
|
|
|
if !*dryRun {
|
|
// Phase 1: Create all tables first so FK references resolve
|
|
fmt.Println("Creating tables...")
|
|
for _, t := range tables {
|
|
if err := createTable(remoteDB, t); err != nil {
|
|
log.Fatalf("Failed to create table %s: %v", t.name, err)
|
|
}
|
|
}
|
|
fmt.Println()
|
|
}
|
|
|
|
// Phase 2: Copy data
|
|
fmt.Println("Migrating data...")
|
|
totalRows := 0
|
|
for _, t := range tables {
|
|
count, err := migrateTable(localDB, remoteDB, t, *batchSize, *skipExisting, *dryRun)
|
|
if err != nil {
|
|
log.Fatalf("Failed to migrate table %s: %v", t.name, err)
|
|
}
|
|
totalRows += count
|
|
}
|
|
|
|
if !*dryRun {
|
|
// Phase 3: Create indexes after data is loaded (faster than indexing during insert)
|
|
fmt.Println("\nCreating indexes...")
|
|
for _, t := range tables {
|
|
if err := createIndexes(localDB, remoteDB, t.name); err != nil {
|
|
log.Fatalf("Failed to create indexes for %s: %v", t.name, err)
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
fmt.Printf("\nDone. %d total rows across %d tables in %s\n", totalRows, len(tables), time.Since(start).Round(time.Millisecond))
|
|
if *dryRun {
|
|
fmt.Println("(dry run — nothing was written)")
|
|
}
|
|
}
|
|
|
|
// tableInfo is one user table as listed in the local sqlite_master.
type tableInfo struct {
	// name is the table's name (sqlite_master.name).
	name string
	// ddl is the table's CREATE statement (sqlite_master.sql); it is left
	// empty when sqlite_master has no SQL text for the table.
	ddl string
}
|
|
|
|
func getTables(db *sql.DB) ([]tableInfo, error) {
|
|
rows, err := db.Query(`
|
|
SELECT name, sql FROM sqlite_master
|
|
WHERE type = 'table'
|
|
AND name NOT LIKE 'sqlite_%'
|
|
AND name NOT LIKE '_litestream_%'
|
|
AND name NOT LIKE 'libsql_%'
|
|
ORDER BY name
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tables []tableInfo
|
|
for rows.Next() {
|
|
var t tableInfo
|
|
var ddl sql.NullString
|
|
if err := rows.Scan(&t.name, &ddl); err != nil {
|
|
return nil, err
|
|
}
|
|
if ddl.Valid {
|
|
t.ddl = ddl.String
|
|
}
|
|
tables = append(tables, t)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Sort tables so those referenced by foreign keys come first.
|
|
// Tables with FK references depend on other tables existing and
|
|
// having data, so we insert referenced tables first.
|
|
return topoSortTables(db, tables)
|
|
}
|
|
|
|
// topoSortTables orders tables so that referenced (parent) tables come before
|
|
// tables that reference them via foreign keys.
|
|
func topoSortTables(db *sql.DB, tables []tableInfo) ([]tableInfo, error) {
|
|
byName := make(map[string]tableInfo, len(tables))
|
|
for _, t := range tables {
|
|
byName[t.name] = t
|
|
}
|
|
|
|
// Build dependency graph: table -> tables it references
|
|
deps := make(map[string][]string)
|
|
for _, t := range tables {
|
|
fkRows, err := db.Query(fmt.Sprintf("PRAGMA foreign_key_list([%s])", t.name))
|
|
if err != nil {
|
|
// PRAGMA might not return rows for tables without FKs
|
|
continue
|
|
}
|
|
seen := make(map[string]bool)
|
|
for fkRows.Next() {
|
|
var id, seq int
|
|
var table, from, to, onUpdate, onDelete, match string
|
|
if err := fkRows.Scan(&id, &seq, &table, &from, &to, &onUpdate, &onDelete, &match); err != nil {
|
|
fkRows.Close()
|
|
return nil, err
|
|
}
|
|
if !seen[table] {
|
|
deps[t.name] = append(deps[t.name], table)
|
|
seen[table] = true
|
|
}
|
|
}
|
|
fkRows.Close()
|
|
}
|
|
|
|
// Topological sort (Kahn's algorithm)
|
|
visited := make(map[string]bool)
|
|
var sorted []tableInfo
|
|
var visit func(name string)
|
|
visit = func(name string) {
|
|
if visited[name] {
|
|
return
|
|
}
|
|
visited[name] = true
|
|
for _, dep := range deps[name] {
|
|
visit(dep)
|
|
}
|
|
if t, ok := byName[name]; ok {
|
|
sorted = append(sorted, t)
|
|
}
|
|
}
|
|
for _, t := range tables {
|
|
visit(t.name)
|
|
}
|
|
return sorted, nil
|
|
}
|
|
|
|
// getIndexes returns the CREATE INDEX statements that sqlite_master records
// for tableName. Entries whose sql column is NULL are skipped (in SQLite these
// are the automatically created indexes backing UNIQUE / PRIMARY KEY
// constraints). If row iteration fails, the statements read so far are
// returned together with the error.
func getIndexes(db *sql.DB, tableName string) ([]string, error) {
	const indexQuery = `
		SELECT sql FROM sqlite_master
		WHERE type = 'index'
		  AND tbl_name = ?
		  AND sql IS NOT NULL
	`

	rs, err := db.Query(indexQuery, tableName)
	if err != nil {
		return nil, err
	}
	defer rs.Close()

	var stmts []string
	for rs.Next() {
		var stmt string
		if scanErr := rs.Scan(&stmt); scanErr != nil {
			return nil, scanErr
		}
		stmts = append(stmts, stmt)
	}
	return stmts, rs.Err()
}
|
|
|
|
func createTable(remoteDB *sql.DB, t tableInfo) error {
|
|
if t.ddl == "" {
|
|
return nil
|
|
}
|
|
ddl := t.ddl
|
|
if !strings.Contains(strings.ToUpper(ddl), "IF NOT EXISTS") {
|
|
ddl = strings.Replace(ddl, "CREATE TABLE", "CREATE TABLE IF NOT EXISTS", 1)
|
|
}
|
|
if _, err := remoteDB.Exec(ddl); err != nil {
|
|
return fmt.Errorf("create table %s: %w", t.name, err)
|
|
}
|
|
fmt.Printf(" %s\n", t.name)
|
|
return nil
|
|
}
|
|
|
|
func createIndexes(localDB, remoteDB *sql.DB, tableName string) error {
|
|
indexes, err := getIndexes(localDB, tableName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, idx := range indexes {
|
|
ddl := idx
|
|
if !strings.Contains(strings.ToUpper(ddl), "IF NOT EXISTS") {
|
|
ddl = strings.Replace(ddl, "CREATE INDEX", "CREATE INDEX IF NOT EXISTS", 1)
|
|
ddl = strings.Replace(ddl, "CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1)
|
|
}
|
|
if _, err := remoteDB.Exec(ddl); err != nil {
|
|
return fmt.Errorf("create index on %s: %w", tableName, err)
|
|
}
|
|
}
|
|
if len(indexes) > 0 {
|
|
fmt.Printf(" %s: %d indexes\n", tableName, len(indexes))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func migrateTable(localDB, remoteDB *sql.DB, t tableInfo, batchSize int, skipExisting, dryRun bool) (int, error) {
|
|
var localCount int
|
|
if err := localDB.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM [%s]", t.name)).Scan(&localCount); err != nil {
|
|
return 0, fmt.Errorf("count local rows: %w", err)
|
|
}
|
|
|
|
if localCount == 0 {
|
|
fmt.Printf(" %-30s %6d rows (empty)\n", t.name, 0)
|
|
return 0, nil
|
|
}
|
|
|
|
if dryRun {
|
|
fmt.Printf(" %-30s %6d rows (would migrate)\n", t.name, localCount)
|
|
return localCount, nil
|
|
}
|
|
|
|
if skipExisting {
|
|
var remoteCount int
|
|
if err := remoteDB.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM [%s]", t.name)).Scan(&remoteCount); err != nil {
|
|
return 0, fmt.Errorf("count remote rows: %w", err)
|
|
}
|
|
if remoteCount > 0 {
|
|
fmt.Printf(" %-30s %6d rows (skipped, %d on remote)\n", t.name, localCount, remoteCount)
|
|
return 0, nil
|
|
}
|
|
}
|
|
|
|
rows, err := localDB.Query(fmt.Sprintf("SELECT * FROM [%s]", t.name))
|
|
if err != nil {
|
|
return 0, fmt.Errorf("select: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
cols, err := rows.Columns()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("columns: %w", err)
|
|
}
|
|
|
|
placeholders := make([]string, len(cols))
|
|
quotedCols := make([]string, len(cols))
|
|
for i, c := range cols {
|
|
placeholders[i] = "?"
|
|
quotedCols[i] = fmt.Sprintf("[%s]", c)
|
|
}
|
|
insertPrefix := fmt.Sprintf("INSERT INTO [%s] (%s) VALUES ", t.name, strings.Join(quotedCols, ", "))
|
|
rowPlaceholder := "(" + strings.Join(placeholders, ", ") + ")"
|
|
|
|
inserted := 0
|
|
batch := make([][]any, 0, batchSize)
|
|
|
|
for rows.Next() {
|
|
vals := make([]any, len(cols))
|
|
ptrs := make([]any, len(cols))
|
|
for i := range vals {
|
|
ptrs[i] = &vals[i]
|
|
}
|
|
if err := rows.Scan(ptrs...); err != nil {
|
|
return 0, fmt.Errorf("scan: %w", err)
|
|
}
|
|
batch = append(batch, vals)
|
|
|
|
if len(batch) >= batchSize {
|
|
if err := insertBatch(remoteDB, insertPrefix, rowPlaceholder, batch); err != nil {
|
|
return 0, fmt.Errorf("insert batch at row %d: %w", inserted, err)
|
|
}
|
|
inserted += len(batch)
|
|
batch = batch[:0]
|
|
}
|
|
}
|
|
|
|
if len(batch) > 0 {
|
|
if err := insertBatch(remoteDB, insertPrefix, rowPlaceholder, batch); err != nil {
|
|
return 0, fmt.Errorf("insert final batch: %w", err)
|
|
}
|
|
inserted += len(batch)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return 0, fmt.Errorf("rows iteration: %w", err)
|
|
}
|
|
|
|
fmt.Printf(" %-30s %6d rows migrated\n", t.name, inserted)
|
|
return inserted, nil
|
|
}
|
|
|
|
// insertBatch runs a single multi-row INSERT on db. The statement is prefix
// followed by rowPlaceholder repeated once per row, separated by ", ". The
// values of all rows are flattened, in order, into the argument list. An
// empty batch does nothing and returns nil.
func insertBatch(db *sql.DB, prefix, rowPlaceholder string, batch [][]any) error {
	n := len(batch)
	if n == 0 {
		return nil
	}

	var sb strings.Builder
	sb.WriteString(prefix)
	var flat []any
	for idx := 0; idx < n; idx++ {
		if idx > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(rowPlaceholder)
		flat = append(flat, batch[idx]...)
	}

	_, execErr := db.Exec(sb.String(), flat...)
	return execErr
}
|