Files
2026-02-15 22:28:36 -06:00

411 lines
12 KiB
Go

package main
import (
	"bytes"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/spf13/cobra"
)
// updateCmd implements `update [target]`: it builds the selected binaries
// locally and deploys them via cmdUpdate. target defaults to "all" when no
// argument is given. The --with-scanner flag (registered in init) is forwarded
// to cmdUpdate.
var updateCmd = &cobra.Command{
	Use:       "update [target]",
	Short:     "Deploy updates to servers",
	Args:      cobra.MaximumNArgs(1),
	ValidArgs: []string{"all", "appview", "hold"},
	RunE: func(cmd *cobra.Command, args []string) error {
		target := "all"
		if len(args) > 0 {
			target = args[0]
		}
		// Only fails if the flag is unregistered or has the wrong type, which is
		// a programming error; surface it instead of silently treating it as false.
		withScanner, err := cmd.Flags().GetBool("with-scanner")
		if err != nil {
			return fmt.Errorf("read --with-scanner flag: %w", err)
		}
		return cmdUpdate(target, withScanner)
	},
}
// sshCmd implements `ssh <target>`: exactly one argument naming the server
// ("appview" or "hold") is required, and the work is delegated to cmdSSH.
var sshCmd = &cobra.Command{
	Use:       "ssh <target>",
	Short:     "SSH into a server",
	ValidArgs: []string{"appview", "hold"},
	Args:      cobra.ExactArgs(1),
	RunE: func(_ *cobra.Command, args []string) error {
		server := args[0]
		return cmdSSH(server)
	},
}
// init registers the --with-scanner flag on the update command and attaches
// the update and ssh commands to the root command.
func init() {
	// Lets an existing deployment opt into the scanner at update time.
	updateCmd.Flags().Bool("with-scanner", false, "Enable and deploy vulnerability scanner alongside hold")
	rootCmd.AddCommand(updateCmd, sshCmd)
}
// cmdUpdate builds the requested binaries locally, then pushes config keys,
// systemd units and binaries to each target server and restarts its services.
//
// target is one of "all", "appview" or "hold"; anything else is an error.
// When withScanner is true and the scanner is not yet enabled in state, it is
// enabled (generating a shared secret if none exists) and the state is saved
// before anything is built or deployed. The scanner is deployed alongside hold
// whenever state.ScannerEnabled is set.
//
// `go generate` and every build run before any server is touched, so a compile
// failure leaves all servers unchanged. A failed post-restart health check is
// reported but does not abort the remaining targets; a failed restart does.
func cmdUpdate(target string, withScanner bool) error {
	state, err := loadState()
	if err != nil {
		return err
	}
	naming := state.Naming()
	rootDir := projectRoot()

	// Enable scanner retroactively via --with-scanner on update. Persist the
	// state before deploying so a freshly generated secret is never pushed to
	// servers without also being recorded locally.
	if withScanner && !state.ScannerEnabled {
		state.ScannerEnabled = true
		if state.ScannerSecret == "" {
			secret, err := generateScannerSecret()
			if err != nil {
				return fmt.Errorf("generate scanner secret: %w", err)
			}
			state.ScannerSecret = secret
			fmt.Printf("Generated scanner shared secret\n")
		}
		if err := saveState(state); err != nil {
			return fmt.Errorf("save state: %w", err)
		}
	}
	vals := configValsFromState(state)

	// Per-server deployment parameters.
	targets := map[string]struct {
		ip          string
		binaryName  string
		buildCmd    string
		localBinary string
		serviceName string
		healthURL   string
		configTmpl  string
		configPath  string
		unitTmpl    string
	}{
		"appview": {
			ip:          state.Appview.PublicIP,
			binaryName:  naming.Appview(),
			buildCmd:    "appview",
			localBinary: "atcr-appview",
			serviceName: naming.Appview(),
			healthURL:   "http://localhost:5000/health",
			configTmpl:  appviewConfigTmpl,
			configPath:  naming.AppviewConfigPath(),
			unitTmpl:    appviewServiceTmpl,
		},
		"hold": {
			ip:          state.Hold.PublicIP,
			binaryName:  naming.Hold(),
			buildCmd:    "hold",
			localBinary: "atcr-hold",
			serviceName: naming.Hold(),
			healthURL:   "http://localhost:8080/xrpc/_health",
			configTmpl:  holdConfigTmpl,
			configPath:  naming.HoldConfigPath(),
			unitTmpl:    holdServiceTmpl,
		},
	}

	var toUpdate []string
	switch target {
	case "all":
		toUpdate = []string{"appview", "hold"}
	case "appview", "hold":
		toUpdate = []string{target}
	default:
		return fmt.Errorf("unknown target: %s (use: all, appview, hold)", target)
	}

	// Fail fast, before building anything, if a target has no recorded
	// address; an empty IP would otherwise surface as an opaque scp/ssh error.
	for _, name := range toUpdate {
		if targets[name].ip == "" {
			return fmt.Errorf("no public IP recorded for %s in state", name)
		}
	}

	// Run go generate before building
	if err := runGenerate(rootDir); err != nil {
		return fmt.Errorf("go generate: %w", err)
	}

	// Build all binaries locally before touching servers
	fmt.Println("Building locally (GOOS=linux GOARCH=amd64)...")
	for _, name := range toUpdate {
		t := targets[name]
		outputPath := filepath.Join(rootDir, "bin", t.localBinary)
		if err := buildLocal(rootDir, outputPath, "./cmd/"+t.buildCmd); err != nil {
			return fmt.Errorf("build %s: %w", name, err)
		}
	}

	// Build scanner locally if needed
	needScanner := false
	for _, name := range toUpdate {
		if name == "hold" && state.ScannerEnabled {
			needScanner = true
			break
		}
	}
	if needScanner {
		outputPath := filepath.Join(rootDir, "bin", "atcr-scanner")
		if err := buildLocal(filepath.Join(rootDir, "scanner"), outputPath, "./cmd/scanner"); err != nil {
			return fmt.Errorf("build scanner: %w", err)
		}
	}

	// Deploy each target
	for _, name := range toUpdate {
		t := targets[name]
		fmt.Printf("\nDeploying %s (%s)...\n", name, t.ip)

		// Sync config keys (adds missing keys from template, never overwrites)
		configYAML, err := renderConfig(t.configTmpl, vals)
		if err != nil {
			return fmt.Errorf("render %s config: %w", name, err)
		}
		if err := syncConfigKeys(name, t.ip, t.configPath, configYAML); err != nil {
			return fmt.Errorf("%s config sync: %w", name, err)
		}

		// Sync systemd service unit
		renderedUnit, err := renderServiceUnit(t.unitTmpl, serviceUnitParams{
			DisplayName: naming.DisplayName(),
			User:        naming.SystemUser(),
			BinaryPath:  naming.InstallDir() + "/bin/" + t.binaryName,
			ConfigPath:  t.configPath,
			DataDir:     naming.BasePath(),
			ServiceName: t.serviceName,
		})
		if err != nil {
			return fmt.Errorf("render %s service unit: %w", name, err)
		}
		unitChanged, err := syncServiceUnit(name, t.ip, t.serviceName, renderedUnit)
		if err != nil {
			return fmt.Errorf("%s service unit sync: %w", name, err)
		}

		// Upload binary
		localPath := filepath.Join(rootDir, "bin", t.localBinary)
		remotePath := naming.InstallDir() + "/bin/" + t.binaryName
		if err := scpFile(localPath, t.ip, remotePath); err != nil {
			return fmt.Errorf("upload %s: %w", name, err)
		}

		// Only reload systemd when a unit file actually changed.
		daemonReload := ""
		if unitChanged {
			daemonReload = "systemctl daemon-reload"
		}

		// Scanner additions for hold server
		scannerRestart := ""
		scannerHealthCheck := ""
		if name == "hold" && state.ScannerEnabled {
			// Sync scanner config keys
			scannerConfigYAML, err := renderConfig(scannerConfigTmpl, vals)
			if err != nil {
				return fmt.Errorf("render scanner config: %w", err)
			}
			if err := syncConfigKeys("scanner", t.ip, naming.ScannerConfigPath(), scannerConfigYAML); err != nil {
				return fmt.Errorf("scanner config sync: %w", err)
			}

			// Sync scanner service unit
			scannerUnit, err := renderScannerServiceUnit(scannerServiceUnitParams{
				DisplayName:     naming.DisplayName(),
				User:            naming.SystemUser(),
				BinaryPath:      naming.InstallDir() + "/bin/" + naming.Scanner(),
				ConfigPath:      naming.ScannerConfigPath(),
				DataDir:         naming.BasePath(),
				ServiceName:     naming.Scanner(),
				HoldServiceName: naming.Hold(),
			})
			if err != nil {
				return fmt.Errorf("render scanner service unit: %w", err)
			}
			scannerUnitChanged, err := syncServiceUnit("scanner", t.ip, naming.Scanner(), scannerUnit)
			if err != nil {
				return fmt.Errorf("scanner service unit sync: %w", err)
			}
			if scannerUnitChanged {
				daemonReload = "systemctl daemon-reload"
			}

			// Upload scanner binary
			scannerLocal := filepath.Join(rootDir, "bin", "atcr-scanner")
			scannerRemote := naming.InstallDir() + "/bin/" + naming.Scanner()
			if err := scpFile(scannerLocal, t.ip, scannerRemote); err != nil {
				return fmt.Errorf("upload scanner: %w", err)
			}

			// Ensure scanner data dirs exist on server
			scannerSetup := fmt.Sprintf(`mkdir -p %s/vulndb %s/tmp
chown -R %s:%s %s`,
				naming.ScannerDataDir(), naming.ScannerDataDir(),
				naming.SystemUser(), naming.SystemUser(), naming.ScannerDataDir())
			if _, err := runSSH(t.ip, scannerSetup, false); err != nil {
				return fmt.Errorf("scanner dir setup: %w", err)
			}

			scannerRestart = fmt.Sprintf("\nsystemctl restart %s", naming.Scanner())
			scannerHealthCheck = `
sleep 2
curl -sf http://localhost:9090/healthz > /dev/null && echo "SCANNER_HEALTH_OK" || echo "SCANNER_HEALTH_FAIL"
`
		}

		// Restart services and health check. Each check echoes a marker on its
		// own line; the markers are parsed below.
		restartScript := fmt.Sprintf(`set -euo pipefail
%s
systemctl restart %s%s
sleep 2
curl -sf %s > /dev/null && echo "HEALTH_OK" || echo "HEALTH_FAIL"
%s`, daemonReload, t.serviceName, scannerRestart, t.healthURL, scannerHealthCheck)
		output, err := runSSH(t.ip, restartScript, true)
		if err != nil {
			fmt.Printf(" ERROR: %v\n", err)
			fmt.Printf(" Output: %s\n", output)
			return fmt.Errorf("restart %s failed: %w", name, err)
		}

		// Whole-line matching: a substring search for "HEALTH_OK" would also
		// match "SCANNER_HEALTH_OK" and misreport a failed service as healthy.
		switch {
		case outputHasMarker(output, "HEALTH_OK"):
			fmt.Printf(" %s: updated and healthy\n", name)
		case outputHasMarker(output, "HEALTH_FAIL"):
			fmt.Printf(" %s: updated but health check failed!\n", name)
			fmt.Printf(" Check: ssh root@%s journalctl -u %s -n 50\n", t.ip, t.serviceName)
		default:
			fmt.Printf(" %s: updated (health check inconclusive)\n", name)
		}

		// Scanner health reporting
		if name == "hold" && state.ScannerEnabled {
			switch {
			case outputHasMarker(output, "SCANNER_HEALTH_OK"):
				fmt.Printf(" scanner: updated and healthy\n")
			case outputHasMarker(output, "SCANNER_HEALTH_FAIL"):
				fmt.Printf(" scanner: updated but health check failed!\n")
				fmt.Printf(" Check: ssh root@%s journalctl -u %s -n 50\n", t.ip, naming.Scanner())
			}
		}
	}
	return nil
}

// outputHasMarker reports whether any line of output, after trimming
// surrounding whitespace, is exactly marker.
func outputHasMarker(output, marker string) bool {
	for _, line := range strings.Split(output, "\n") {
		if strings.TrimSpace(line) == marker {
			return true
		}
	}
	return false
}
// configValsFromState derives the values used to render server config
// templates from what is persisted in state.
//
// The hold domain is "<zone>.cove.<base domain>", with the base domain taken
// from extractFromAppviewTemplate; the hold DID is the did:web form of that
// domain. S3SecretKey is deliberately left empty because it is not stored in
// state. syncConfigKeys only adds missing keys and never overwrites, so the
// secret already present on the server is kept.
func configValsFromState(state *InfraState) *ConfigValues {
	n := state.Naming()
	// NOTE(review): all other return values of extractFromAppviewTemplate
	// (possibly including an error) are discarded; if baseDomain comes back
	// empty, holdDomain ends in ".cove." — confirm this is acceptable.
	_, baseDomain, _, _ := extractFromAppviewTemplate()
	holdDomain := state.Zone + ".cove." + baseDomain
	storage := state.ObjectStorage

	vals := &ConfigValues{
		BasePath:      n.BasePath(),
		Zone:          state.Zone,
		HoldDomain:    holdDomain,
		HoldDid:       "did:web:" + holdDomain,
		ScannerSecret: state.ScannerSecret,
		S3Endpoint:    storage.Endpoint,
		S3Region:      storage.Region,
		S3Bucket:      storage.Bucket,
		S3AccessKey:   storage.AccessKeyID,
		S3SecretKey:   "", // not persisted; the server's existing value wins
	}
	return vals
}
// runGenerate executes `go generate ./...` with dir as the working directory,
// streaming the tool's output to this process's stdout/stderr. It deliberately
// sets no GOOS/GOARCH overrides: generators execute on the build machine, so
// they must be built for the host platform.
func runGenerate(dir string) error {
	fmt.Println("Running go generate ./...")
	gen := exec.Command("go", "generate", "./...")
	gen.Dir, gen.Stdout, gen.Stderr = dir, os.Stdout, os.Stderr
	return gen.Run()
}
// buildLocal compiles package buildPkg (resolved relative to dir) into
// outputPath for linux/amd64. The binary is stripped (-s -w) and built with
// -trimpath. Compiler output is streamed to this process's stdout/stderr.
//
// NOTE(review): CGO_ENABLED=1 combined with GOOS=linux GOARCH=amd64 requires
// a C toolchain that targets linux/amd64 when the host is a different
// platform — confirm the build hosts are set up for this.
func buildLocal(dir, outputPath, buildPkg string) error {
	fmt.Printf(" building %s...\n", filepath.Base(outputPath))

	args := []string{"build", "-ldflags=-s -w", "-trimpath", "-o", outputPath, buildPkg}
	crossEnv := []string{"GOOS=linux", "GOARCH=amd64", "CGO_ENABLED=1"}

	build := exec.Command("go", args...)
	build.Dir = dir
	build.Env = append(os.Environ(), crossEnv...)
	build.Stdout = os.Stdout
	build.Stderr = os.Stderr
	return build.Run()
}
// scpFile uploads localPath to remotePath on the server at ip (as root) via
// scp.
//
// The file is first copied to remotePath+".new" and then moved over
// remotePath with `mv -f`. Writing to a fresh file avoids ETXTBSY when
// remotePath is a running binary. Because the rename replaces the directory
// entry in one step, remotePath is never missing, and a failed upload leaves
// the previously installed file untouched.
func scpFile(localPath, ip, remotePath string) error {
	fmt.Printf(" uploading %s → %s:%s\n", filepath.Base(localPath), ip, remotePath)
	stagingPath := remotePath + ".new"
	cmd := exec.Command("scp",
		"-o", "StrictHostKeyChecking=accept-new",
		"-o", "ConnectTimeout=10",
		localPath,
		"root@"+ip+":"+stagingPath,
	)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		return err
	}
	out, err := runSSH(ip, fmt.Sprintf("mv -f %s %s", stagingPath, remotePath), false)
	if err != nil {
		return fmt.Errorf("install %s: %w (output: %s)", remotePath, err, strings.TrimSpace(out))
	}
	return nil
}
// cmdSSH opens an interactive root SSH session to the named server ("appview"
// or "hold"), using the public IP recorded in state. The local terminal's
// stdin, stdout and stderr are attached to the ssh process. It returns an
// error for an unknown target or when no IP has been recorded for it.
func cmdSSH(target string) error {
	state, err := loadState()
	if err != nil {
		return err
	}
	var ip string
	switch target {
	case "appview":
		ip = state.Appview.PublicIP
	case "hold":
		ip = state.Hold.PublicIP
	default:
		return fmt.Errorf("unknown target: %s (use: appview, hold)", target)
	}
	// Without this check ssh is invoked as "root@" and fails with an
	// unhelpful message.
	if ip == "" {
		return fmt.Errorf("no public IP recorded for %s in state", target)
	}
	fmt.Printf("Connecting to %s (%s)...\n", target, ip)
	cmd := exec.Command("ssh",
		"-o", "StrictHostKeyChecking=accept-new",
		"root@"+ip,
	)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	return cmd.Run()
}
// lockedBuffer is a bytes.Buffer guarded by a mutex. When a command's Stdout
// and Stderr are different writers, os/exec copies each on its own goroutine,
// so a buffer shared by both must serialize writes.
type lockedBuffer struct {
	mu  sync.Mutex
	buf bytes.Buffer
}

// Write appends p to the buffer.
func (b *lockedBuffer) Write(p []byte) (int, error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.buf.Write(p)
}

// String returns everything written so far.
func (b *lockedBuffer) String() string {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.buf.String()
}

// runSSH runs script as root on the server at ip by piping it to `bash -s`
// over ssh. It returns the remote stdout and stderr combined, together with
// the command's error.
//
// When stream is true, the remote output is also echoed live to this
// process's stdout/stderr. The command is limited to 5 minutes. On timeout
// the ssh process is killed, and whatever output was captured is returned
// along with a timeout error.
func runSSH(ip, script string, stream bool) (string, error) {
	cmd := exec.Command("ssh",
		"-o", "StrictHostKeyChecking=accept-new",
		"-o", "ConnectTimeout=10",
		"root@"+ip,
		"bash -s",
	)
	cmd.Stdin = strings.NewReader(script)
	out := &lockedBuffer{}
	if stream {
		cmd.Stdout = io.MultiWriter(os.Stdout, out)
		cmd.Stderr = io.MultiWriter(os.Stderr, out)
	} else {
		cmd.Stdout = out
		cmd.Stderr = out
	}

	// Start synchronously so cmd.Process is set before any other goroutine
	// can read it.
	if err := cmd.Start(); err != nil {
		return out.String(), err
	}

	// Give deploys up to 5 minutes (SCP + restart, much faster than remote builds)
	done := make(chan error, 1)
	go func() { done <- cmd.Wait() }()
	select {
	case err := <-done:
		return out.String(), err
	case <-time.After(5 * time.Minute):
		_ = cmd.Process.Kill() // best effort; the process may have just exited
		// Let Wait finish flushing output, but don't block forever if the
		// pipes are still held open elsewhere.
		select {
		case <-done:
		case <-time.After(5 * time.Second):
		}
		return out.String(), fmt.Errorf("SSH command timed out after 5 minutes")
	}
}