411 lines
12 KiB
Go
411 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
var updateCmd = &cobra.Command{
|
|
Use: "update [target]",
|
|
Short: "Deploy updates to servers",
|
|
Args: cobra.MaximumNArgs(1),
|
|
ValidArgs: []string{"all", "appview", "hold"},
|
|
RunE: func(cmd *cobra.Command, args []string) error {
|
|
target := "all"
|
|
if len(args) > 0 {
|
|
target = args[0]
|
|
}
|
|
withScanner, _ := cmd.Flags().GetBool("with-scanner")
|
|
return cmdUpdate(target, withScanner)
|
|
},
|
|
}
|
|
|
|
var sshCmd = &cobra.Command{
|
|
Use: "ssh <target>",
|
|
Short: "SSH into a server",
|
|
Args: cobra.ExactArgs(1),
|
|
ValidArgs: []string{"appview", "hold"},
|
|
RunE: func(cmd *cobra.Command, args []string) error {
|
|
return cmdSSH(args[0])
|
|
},
|
|
}
|
|
|
|
func init() {
|
|
updateCmd.Flags().Bool("with-scanner", false, "Enable and deploy vulnerability scanner alongside hold")
|
|
rootCmd.AddCommand(updateCmd)
|
|
rootCmd.AddCommand(sshCmd)
|
|
}
|
|
|
|
func cmdUpdate(target string, withScanner bool) error {
|
|
state, err := loadState()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
naming := state.Naming()
|
|
rootDir := projectRoot()
|
|
|
|
// Enable scanner retroactively via --with-scanner on update
|
|
if withScanner && !state.ScannerEnabled {
|
|
state.ScannerEnabled = true
|
|
if state.ScannerSecret == "" {
|
|
secret, err := generateScannerSecret()
|
|
if err != nil {
|
|
return fmt.Errorf("generate scanner secret: %w", err)
|
|
}
|
|
state.ScannerSecret = secret
|
|
fmt.Printf("Generated scanner shared secret\n")
|
|
}
|
|
_ = saveState(state)
|
|
}
|
|
|
|
vals := configValsFromState(state)
|
|
|
|
targets := map[string]struct {
|
|
ip string
|
|
binaryName string
|
|
buildCmd string
|
|
localBinary string
|
|
serviceName string
|
|
healthURL string
|
|
configTmpl string
|
|
configPath string
|
|
unitTmpl string
|
|
}{
|
|
"appview": {
|
|
ip: state.Appview.PublicIP,
|
|
binaryName: naming.Appview(),
|
|
buildCmd: "appview",
|
|
localBinary: "atcr-appview",
|
|
serviceName: naming.Appview(),
|
|
healthURL: "http://localhost:5000/health",
|
|
configTmpl: appviewConfigTmpl,
|
|
configPath: naming.AppviewConfigPath(),
|
|
unitTmpl: appviewServiceTmpl,
|
|
},
|
|
"hold": {
|
|
ip: state.Hold.PublicIP,
|
|
binaryName: naming.Hold(),
|
|
buildCmd: "hold",
|
|
localBinary: "atcr-hold",
|
|
serviceName: naming.Hold(),
|
|
healthURL: "http://localhost:8080/xrpc/_health",
|
|
configTmpl: holdConfigTmpl,
|
|
configPath: naming.HoldConfigPath(),
|
|
unitTmpl: holdServiceTmpl,
|
|
},
|
|
}
|
|
|
|
var toUpdate []string
|
|
switch target {
|
|
case "all":
|
|
toUpdate = []string{"appview", "hold"}
|
|
case "appview", "hold":
|
|
toUpdate = []string{target}
|
|
default:
|
|
return fmt.Errorf("unknown target: %s (use: all, appview, hold)", target)
|
|
}
|
|
|
|
// Run go generate before building
|
|
if err := runGenerate(rootDir); err != nil {
|
|
return fmt.Errorf("go generate: %w", err)
|
|
}
|
|
|
|
// Build all binaries locally before touching servers
|
|
fmt.Println("Building locally (GOOS=linux GOARCH=amd64)...")
|
|
for _, name := range toUpdate {
|
|
t := targets[name]
|
|
outputPath := filepath.Join(rootDir, "bin", t.localBinary)
|
|
if err := buildLocal(rootDir, outputPath, "./cmd/"+t.buildCmd); err != nil {
|
|
return fmt.Errorf("build %s: %w", name, err)
|
|
}
|
|
}
|
|
|
|
// Build scanner locally if needed
|
|
needScanner := false
|
|
for _, name := range toUpdate {
|
|
if name == "hold" && state.ScannerEnabled {
|
|
needScanner = true
|
|
break
|
|
}
|
|
}
|
|
if needScanner {
|
|
outputPath := filepath.Join(rootDir, "bin", "atcr-scanner")
|
|
if err := buildLocal(filepath.Join(rootDir, "scanner"), outputPath, "./cmd/scanner"); err != nil {
|
|
return fmt.Errorf("build scanner: %w", err)
|
|
}
|
|
}
|
|
|
|
// Deploy each target
|
|
for _, name := range toUpdate {
|
|
t := targets[name]
|
|
fmt.Printf("\nDeploying %s (%s)...\n", name, t.ip)
|
|
|
|
// Sync config keys (adds missing keys from template, never overwrites)
|
|
configYAML, err := renderConfig(t.configTmpl, vals)
|
|
if err != nil {
|
|
return fmt.Errorf("render %s config: %w", name, err)
|
|
}
|
|
if err := syncConfigKeys(name, t.ip, t.configPath, configYAML); err != nil {
|
|
return fmt.Errorf("%s config sync: %w", name, err)
|
|
}
|
|
|
|
// Sync systemd service unit
|
|
renderedUnit, err := renderServiceUnit(t.unitTmpl, serviceUnitParams{
|
|
DisplayName: naming.DisplayName(),
|
|
User: naming.SystemUser(),
|
|
BinaryPath: naming.InstallDir() + "/bin/" + t.binaryName,
|
|
ConfigPath: t.configPath,
|
|
DataDir: naming.BasePath(),
|
|
ServiceName: t.serviceName,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("render %s service unit: %w", name, err)
|
|
}
|
|
unitChanged, err := syncServiceUnit(name, t.ip, t.serviceName, renderedUnit)
|
|
if err != nil {
|
|
return fmt.Errorf("%s service unit sync: %w", name, err)
|
|
}
|
|
|
|
// Upload binary
|
|
localPath := filepath.Join(rootDir, "bin", t.localBinary)
|
|
remotePath := naming.InstallDir() + "/bin/" + t.binaryName
|
|
if err := scpFile(localPath, t.ip, remotePath); err != nil {
|
|
return fmt.Errorf("upload %s: %w", name, err)
|
|
}
|
|
|
|
daemonReload := ""
|
|
if unitChanged {
|
|
daemonReload = "systemctl daemon-reload"
|
|
}
|
|
|
|
// Scanner additions for hold server
|
|
scannerRestart := ""
|
|
scannerHealthCheck := ""
|
|
if name == "hold" && state.ScannerEnabled {
|
|
// Sync scanner config keys
|
|
scannerConfigYAML, err := renderConfig(scannerConfigTmpl, vals)
|
|
if err != nil {
|
|
return fmt.Errorf("render scanner config: %w", err)
|
|
}
|
|
if err := syncConfigKeys("scanner", t.ip, naming.ScannerConfigPath(), scannerConfigYAML); err != nil {
|
|
return fmt.Errorf("scanner config sync: %w", err)
|
|
}
|
|
|
|
// Sync scanner service unit
|
|
scannerUnit, err := renderScannerServiceUnit(scannerServiceUnitParams{
|
|
DisplayName: naming.DisplayName(),
|
|
User: naming.SystemUser(),
|
|
BinaryPath: naming.InstallDir() + "/bin/" + naming.Scanner(),
|
|
ConfigPath: naming.ScannerConfigPath(),
|
|
DataDir: naming.BasePath(),
|
|
ServiceName: naming.Scanner(),
|
|
HoldServiceName: naming.Hold(),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("render scanner service unit: %w", err)
|
|
}
|
|
scannerUnitChanged, err := syncServiceUnit("scanner", t.ip, naming.Scanner(), scannerUnit)
|
|
if err != nil {
|
|
return fmt.Errorf("scanner service unit sync: %w", err)
|
|
}
|
|
if scannerUnitChanged {
|
|
daemonReload = "systemctl daemon-reload"
|
|
}
|
|
|
|
// Upload scanner binary
|
|
scannerLocal := filepath.Join(rootDir, "bin", "atcr-scanner")
|
|
scannerRemote := naming.InstallDir() + "/bin/" + naming.Scanner()
|
|
if err := scpFile(scannerLocal, t.ip, scannerRemote); err != nil {
|
|
return fmt.Errorf("upload scanner: %w", err)
|
|
}
|
|
|
|
// Ensure scanner data dirs exist on server
|
|
scannerSetup := fmt.Sprintf(`mkdir -p %s/vulndb %s/tmp
|
|
chown -R %s:%s %s`,
|
|
naming.ScannerDataDir(), naming.ScannerDataDir(),
|
|
naming.SystemUser(), naming.SystemUser(), naming.ScannerDataDir())
|
|
if _, err := runSSH(t.ip, scannerSetup, false); err != nil {
|
|
return fmt.Errorf("scanner dir setup: %w", err)
|
|
}
|
|
|
|
scannerRestart = fmt.Sprintf("\nsystemctl restart %s", naming.Scanner())
|
|
scannerHealthCheck = `
|
|
sleep 2
|
|
curl -sf http://localhost:9090/healthz > /dev/null && echo "SCANNER_HEALTH_OK" || echo "SCANNER_HEALTH_FAIL"
|
|
`
|
|
}
|
|
|
|
// Restart services and health check
|
|
restartScript := fmt.Sprintf(`set -euo pipefail
|
|
%s
|
|
systemctl restart %s%s
|
|
sleep 2
|
|
curl -sf %s > /dev/null && echo "HEALTH_OK" || echo "HEALTH_FAIL"
|
|
%s`, daemonReload, t.serviceName, scannerRestart, t.healthURL, scannerHealthCheck)
|
|
|
|
output, err := runSSH(t.ip, restartScript, true)
|
|
if err != nil {
|
|
fmt.Printf(" ERROR: %v\n", err)
|
|
fmt.Printf(" Output: %s\n", output)
|
|
return fmt.Errorf("restart %s failed", name)
|
|
}
|
|
|
|
if strings.Contains(output, "HEALTH_OK") {
|
|
fmt.Printf(" %s: updated and healthy\n", name)
|
|
} else if strings.Contains(output, "HEALTH_FAIL") {
|
|
fmt.Printf(" %s: updated but health check failed!\n", name)
|
|
fmt.Printf(" Check: ssh root@%s journalctl -u %s -n 50\n", t.ip, t.serviceName)
|
|
} else {
|
|
fmt.Printf(" %s: updated (health check inconclusive)\n", name)
|
|
}
|
|
|
|
// Scanner health reporting
|
|
if name == "hold" && state.ScannerEnabled {
|
|
if strings.Contains(output, "SCANNER_HEALTH_OK") {
|
|
fmt.Printf(" scanner: updated and healthy\n")
|
|
} else if strings.Contains(output, "SCANNER_HEALTH_FAIL") {
|
|
fmt.Printf(" scanner: updated but health check failed!\n")
|
|
fmt.Printf(" Check: ssh root@%s journalctl -u %s -n 50\n", t.ip, naming.Scanner())
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// configValsFromState builds ConfigValues from persisted state.
|
|
// S3SecretKey is intentionally left empty — syncConfigKeys only adds missing
|
|
// keys and never overwrites, so the server's existing secret is preserved.
|
|
func configValsFromState(state *InfraState) *ConfigValues {
|
|
naming := state.Naming()
|
|
_, baseDomain, _, _ := extractFromAppviewTemplate()
|
|
holdDomain := state.Zone + ".cove." + baseDomain
|
|
|
|
return &ConfigValues{
|
|
S3Endpoint: state.ObjectStorage.Endpoint,
|
|
S3Region: state.ObjectStorage.Region,
|
|
S3Bucket: state.ObjectStorage.Bucket,
|
|
S3AccessKey: state.ObjectStorage.AccessKeyID,
|
|
S3SecretKey: "", // not persisted in state; existing value on server is preserved
|
|
Zone: state.Zone,
|
|
HoldDomain: holdDomain,
|
|
HoldDid: "did:web:" + holdDomain,
|
|
BasePath: naming.BasePath(),
|
|
ScannerSecret: state.ScannerSecret,
|
|
}
|
|
}
|
|
|
|
// runGenerate runs `go generate ./...` with dir as the working directory and
// streams its output to the local stdout/stderr.
//
// Unlike buildLocal, it adds no cross-compilation variables to the
// environment, because generator tools are executed on the build machine.
func runGenerate(dir string) error {
	fmt.Println("Running go generate ./...")

	gen := exec.Command("go", "generate", "./...")
	gen.Dir, gen.Stdout, gen.Stderr = dir, os.Stdout, os.Stderr
	return gen.Run()
}
|
|
|
|
// buildLocal compiles buildPkg (a package path relative to dir) into
// outputPath for linux/amd64.
//
// Symbol and debug info are stripped (-ldflags=-s -w) and local file paths
// are removed (-trimpath). CGO_ENABLED=1 is set, so if any package uses cgo a
// C toolchain able to target linux/amd64 must be available. Compiler output is
// streamed to the local stdout/stderr.
func buildLocal(dir, outputPath, buildPkg string) error {
	fmt.Printf(" building %s...\n", filepath.Base(outputPath))

	buildEnv := append(os.Environ(), "GOOS=linux", "GOARCH=amd64", "CGO_ENABLED=1")
	build := exec.Command("go", "build", "-ldflags=-s -w", "-trimpath", "-o", outputPath, buildPkg)
	build.Dir = dir
	build.Env = buildEnv
	build.Stdout, build.Stderr = os.Stdout, os.Stderr
	return build.Run()
}
|
|
|
|
// scpFile uploads a local file to a remote server via SCP.
|
|
// Removes the remote file first to avoid ETXTBSY when overwriting a running binary.
|
|
func scpFile(localPath, ip, remotePath string) error {
|
|
fmt.Printf(" uploading %s → %s:%s\n", filepath.Base(localPath), ip, remotePath)
|
|
_, _ = runSSH(ip, fmt.Sprintf("rm -f %s", remotePath), false)
|
|
cmd := exec.Command("scp",
|
|
"-o", "StrictHostKeyChecking=accept-new",
|
|
"-o", "ConnectTimeout=10",
|
|
localPath,
|
|
"root@"+ip+":"+remotePath,
|
|
)
|
|
cmd.Stdout = os.Stdout
|
|
cmd.Stderr = os.Stderr
|
|
return cmd.Run()
|
|
}
|
|
|
|
func cmdSSH(target string) error {
|
|
state, err := loadState()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var ip string
|
|
switch target {
|
|
case "appview":
|
|
ip = state.Appview.PublicIP
|
|
case "hold":
|
|
ip = state.Hold.PublicIP
|
|
default:
|
|
return fmt.Errorf("unknown target: %s (use: appview, hold)", target)
|
|
}
|
|
|
|
fmt.Printf("Connecting to %s (%s)...\n", target, ip)
|
|
cmd := exec.Command("ssh",
|
|
"-o", "StrictHostKeyChecking=accept-new",
|
|
"root@"+ip,
|
|
)
|
|
cmd.Stdin = os.Stdin
|
|
cmd.Stdout = os.Stdout
|
|
cmd.Stderr = os.Stderr
|
|
return cmd.Run()
|
|
}
|
|
|
|
// runSSH pipes script to `bash -s` on root@ip and returns what the remote
// command printed.
//
// With stream=false, stdout and stderr are captured into one buffer in
// arrival order. With stream=true they are also copied live to the local
// stdout/stderr. In that mode each stream is captured separately and the
// returned string is stdout followed by stderr.
//
// The ssh process is killed if it has not finished within 5 minutes. The
// output collected up to that point is returned together with an error.
func runSSH(ip, script string, stream bool) (string, error) {
	cmd := exec.Command("ssh",
		"-o", "StrictHostKeyChecking=accept-new",
		"-o", "ConnectTimeout=10",
		"root@"+ip,
		"bash -s",
	)
	cmd.Stdin = strings.NewReader(script)

	var stdout, stderr bytes.Buffer
	if stream {
		// Two distinct writers are copied by two goroutines in os/exec, so
		// they must not share one bytes.Buffer.
		cmd.Stdout = io.MultiWriter(os.Stdout, &stdout)
		cmd.Stderr = io.MultiWriter(os.Stderr, &stderr)
	} else {
		// When Stdout == Stderr, os/exec serializes the writes, which keeps
		// both streams interleaved in one buffer safely.
		cmd.Stdout = &stdout
		cmd.Stderr = &stdout
	}
	output := func() string { return stdout.String() + stderr.String() }

	// Start synchronously so cmd.Process is set before the timeout path can
	// use it, then wait in the background.
	if err := cmd.Start(); err != nil {
		return "", err
	}
	done := make(chan error, 1)
	go func() { done <- cmd.Wait() }()

	// Give deploys up to 5 minutes (SCP + restart, much faster than remote builds)
	timer := time.NewTimer(5 * time.Minute)
	defer timer.Stop()

	select {
	case err := <-done:
		return output(), err
	case <-timer.C:
		_ = cmd.Process.Kill() // best effort; the process may have just exited
		// Wait reaps ssh and finishes the output copies, so reading the
		// buffers below does not race with them.
		<-done
		return output(), fmt.Errorf("SSH command timed out after 5 minutes")
	}
}
|