Files
at-container-registry/cmd/credential-helper/main.go
2026-01-06 23:56:17 -06:00

1071 lines
31 KiB
Go

package main
import (
	"archive/zip"
	"bytes"
	"crypto/sha256"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"time"
)
// DeviceConfig represents the stored device configuration for one AppView:
// the account handle, the device secret issued by the device-authorization
// flow (authorizeDevice), and the AppView URL (scheme included) it belongs to.
// Instances are persisted, keyed by AppViewURL, in ~/.atcr/device.json.
type DeviceConfig struct {
	Handle       string `json:"handle"`
	DeviceSecret string `json:"device_secret"`
	AppViewURL   string `json:"appview_url"`
}
// DeviceCredentials stores multiple device configurations keyed by AppView URL
// (as produced by buildAppViewURL, i.e. including the http:// or https://
// scheme). This is the current on-disk format of ~/.atcr/device.json;
// loadDeviceCredentials also accepts the older single-DeviceConfig format.
type DeviceCredentials struct {
	Credentials map[string]DeviceConfig `json:"credentials"`
}
// DockerDaemonConfig represents the subset of Docker's daemon.json that this
// helper reads: only the insecure-registries list, which decides whether a
// registry is contacted over http:// instead of https://.
type DockerDaemonConfig struct {
	InsecureRegistries []string `json:"insecure-registries"`
}
// Docker credential helper protocol
// https://github.com/docker/docker-credential-helpers

// Credentials represents docker credentials as exchanged over the credential
// helper protocol: handleGet encodes one to stdout and handleStore decodes one
// from stdin. When this helper answers "get", Username carries the account
// handle and Secret carries the device secret.
type Credentials struct {
	ServerURL string `json:"ServerURL,omitempty"`
	Username  string `json:"Username,omitempty"`
	Secret    string `json:"Secret,omitempty"`
}
// Device authorization API types, exchanged as JSON with the AppView's
// /auth/device/code and /auth/device/token endpoints (see authorizeDevice).
type (
	// DeviceCodeRequest is POSTed to /auth/device/code to start authorization.
	DeviceCodeRequest struct {
		DeviceName string `json:"device_name"`
	}

	// DeviceCodeResponse carries the code pair and polling parameters.
	// ExpiresIn and Interval are in seconds.
	DeviceCodeResponse struct {
		DeviceCode      string `json:"device_code"`
		UserCode        string `json:"user_code"`
		VerificationURI string `json:"verification_uri"`
		ExpiresIn       int    `json:"expires_in"`
		Interval        int    `json:"interval"`
	}

	// DeviceTokenRequest is POSTed to /auth/device/token on every poll.
	DeviceTokenRequest struct {
		DeviceCode string `json:"device_code"`
	}

	// DeviceTokenResponse is one poll result. Error is set (for example
	// "authorization_pending") until the user approves the device, after which
	// DeviceSecret and Handle are filled in.
	DeviceTokenResponse struct {
		DeviceSecret string `json:"device_secret,omitempty"`
		Handle       string `json:"handle,omitempty"`
		DID          string `json:"did,omitempty"`
		Error        string `json:"error,omitempty"`
	}
)
// AuthErrorResponse is the JSON error response from /auth/token.
// validateCredentials inspects it on a 401: Error == "oauth_session_expired"
// means the device secret was accepted but the OAuth session must be renewed,
// and LoginURL (optional) is the page where the user can do that.
type AuthErrorResponse struct {
	Error    string `json:"error"`
	Message  string `json:"message"`
	LoginURL string `json:"login_url,omitempty"`
}
// ValidationResult represents the result of credential validation
// (validateCredentials). OAuthSessionExpired is only set together with
// Valid == false. It marks a rejection that a browser re-login can fix, as
// opposed to a plain rejection. LoginURL is where to send the user and may be
// empty.
type ValidationResult struct {
	Valid               bool
	OAuthSessionExpired bool
	LoginURL            string
}
// VersionAPIResponse is the response from /api/credential-helper/version.
// DownloadURLs and Checksums are keyed by the platform key built by
// getPlatformKey ("<GOOS>_<GOARCH>", e.g. "linux_amd64"). performUpdate skips
// verification when no checksum is listed for the platform.
// NOTE(review): checksums are presumably hex-encoded SHA-256 digests of the
// archive — confirm against the server.
type VersionAPIResponse struct {
	Latest       string            `json:"latest"`
	DownloadURLs map[string]string `json:"download_urls"`
	Checksums    map[string]string `json:"checksums"`
	ReleaseNotes string            `json:"release_notes,omitempty"`
}
// UpdateCheckCache stores the last update check result in
// ~/.atcr/update-check.json: when the check ran, the latest release it saw,
// and the helper version that performed it. checkAndNotifyUpdate only trusts
// an entry whose Current equals the running version and whose CheckedAt is
// within updateCheckCacheTTL.
type UpdateCheckCache struct {
	CheckedAt time.Time `json:"checked_at"`
	Latest    string    `json:"latest"`
	Current   string    `json:"current"`
}
// Build metadata reported by the "version" command. The defaults describe a
// local development build. NOTE(review): release builds presumably override
// these at link time (e.g. -ldflags "-X main.version=...") — confirm against
// the release tooling.
var (
	version = "dev"
	commit  = "none"
	date    = "unknown"
)

// updateCheckCacheTTL is how long a cached update check is trusted before the
// AppView's version endpoint is queried again (24 hours).
var updateCheckCacheTTL = 24 * time.Hour
// main dispatches on the first command-line argument. get, store and erase
// implement the Docker credential helper protocol. version prints build
// metadata. update installs a newer release ("update --check" only reports
// one). A missing or unknown command exits with status 1.
func main() {
	args := os.Args
	if len(args) < 2 {
		fmt.Fprintf(os.Stderr, "Usage: docker-credential-atcr <get|store|erase|version|update>\n")
		os.Exit(1)
	}

	switch cmd := args[1]; cmd {
	case "get":
		handleGet()
	case "store":
		handleStore()
	case "erase":
		handleErase()
	case "version":
		fmt.Printf("docker-credential-atcr %s (commit: %s, built: %s)\n", version, commit, date)
	case "update":
		handleUpdate(len(args) > 2 && args[2] == "--check")
	default:
		fmt.Fprintf(os.Stderr, "Unknown command: %s\n", cmd)
		os.Exit(1)
	}
}
// handleGet implements the "get" command of the Docker credential helper
// protocol. It reads the registry server URL from stdin and writes a
// Credentials JSON object (handle + device secret) to stdout.
//
// Stored credentials for the matching AppView are validated first:
//   - valid: they are returned as-is;
//   - OAuth session expired: the browser is opened for re-login and
//     validation is retried for up to two minutes (reauthenticateOAuth);
//   - otherwise rejected: the entry is deleted from the config file and
//     device authorization starts. This only happens when ATCR_AUTO_AUTH=1
//     or stderr is a terminal.
//
// All diagnostics go to stderr; unrecoverable failures exit with status 1.
func handleGet() {
	// Docker sends the server URL as a plain string on stdin (not JSON)
	var serverURL string
	if _, err := fmt.Fscanln(os.Stdin, &serverURL); err != nil {
		fmt.Fprintf(os.Stderr, "Error reading server URL: %v\n", err)
		os.Exit(1)
	}

	// Credentials are stored keyed by the AppView URL (scheme included).
	appViewURL := buildAppViewURL(serverURL)

	configPath := getConfigPath()
	allCreds, err := loadDeviceCredentials(configPath)
	if err != nil {
		// No (readable) credentials file yet: start from an empty set.
		allCreds = &DeviceCredentials{
			Credentials: make(map[string]DeviceConfig),
		}
	}

	deviceConfig, found := getDeviceConfig(allCreds, appViewURL)

	if found && deviceConfig.DeviceSecret != "" {
		result := validateCredentials(appViewURL, deviceConfig.Handle, deviceConfig.DeviceSecret)
		switch {
		case result.Valid:
			// Stored credentials are usable as-is.
		case result.OAuthSessionExpired:
			// The device secret is still accepted; only the OAuth session
			// behind it must be restored, so the stored entry is kept.
			if !reauthenticateOAuth(appViewURL, deviceConfig, result.LoginURL) {
				fmt.Fprintf(os.Stderr, "\nAuthentication timed out. Please try again.\n")
				os.Exit(1)
			}
		default:
			// Generic auth failure - delete credentials and re-authorize
			fmt.Fprintf(os.Stderr, "Stored credentials for %s are invalid or expired\n", appViewURL)
			delete(allCreds.Credentials, appViewURL)
			if err := saveDeviceCredentials(configPath, allCreds); err != nil {
				fmt.Fprintf(os.Stderr, "Warning: failed to save updated credentials: %v\n", err)
			}
			// Mark as not found so device authorization runs below.
			found = false
		}
	}

	if !found || deviceConfig.DeviceSecret == "" {
		// Interactive authorization is only attempted when explicitly enabled
		// (ATCR_AUTO_AUTH=1) or when someone can see the prompts (stderr is a
		// terminal); otherwise print instructions and fail.
		shouldAutoAuth := os.Getenv("ATCR_AUTO_AUTH") == "1" || isTerminal(os.Stderr)
		if !shouldAutoAuth {
			fmt.Fprintf(os.Stderr, "No valid credentials found for %s\n", appViewURL)
			fmt.Fprintf(os.Stderr, "\nTo authenticate, run:\n")
			fmt.Fprintf(os.Stderr, " export ATCR_AUTO_AUTH=1\n")
			fmt.Fprintf(os.Stderr, " docker push %s/<user>/<image>:<tag>\n", serverURL)
			fmt.Fprintf(os.Stderr, "\nThis will trigger device authorization in your browser.\n")
			os.Exit(1)
		}

		fmt.Fprintf(os.Stderr, "Starting device authorization for %s...\n", appViewURL)
		newConfig, err := authorizeDevice(serverURL)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Device authorization failed: %v\n", err)
			fmt.Fprintf(os.Stderr, "\nFallback: Use 'docker login %s' with your ATProto app-password\n", serverURL)
			os.Exit(1)
		}

		if err := saveDeviceConfig(configPath, newConfig); err != nil {
			fmt.Fprintf(os.Stderr, "Failed to save device config: %v\n", err)
			os.Exit(1)
		}
		fmt.Fprintf(os.Stderr, "✓ Device authorized successfully for %s!\n", appViewURL)
		deviceConfig = newConfig
	}

	// Check for updates. The result is cached for 24h, so this is usually just
	// a local file read, but a stale cache triggers a network request.
	checkAndNotifyUpdate(appViewURL)

	// Return credentials for Docker
	creds := Credentials{
		ServerURL: serverURL,
		Username:  deviceConfig.Handle,
		Secret:    deviceConfig.DeviceSecret,
	}
	if err := json.NewEncoder(os.Stdout).Encode(creds); err != nil {
		fmt.Fprintf(os.Stderr, "Error encoding response: %v\n", err)
		os.Exit(1)
	}
}

// reauthenticateOAuth restores an expired OAuth session for cfg. It opens
// loginURL in a browser (default: appViewURL + "/auth/oauth/login") and
// re-validates the stored device secret every 2 seconds, at most 60 times
// (about two minutes). It reports whether validation succeeded in that window.
func reauthenticateOAuth(appViewURL string, cfg *DeviceConfig, loginURL string) bool {
	fmt.Fprintf(os.Stderr, "OAuth session expired. Opening browser to re-authenticate...\n")
	if loginURL == "" {
		loginURL = appViewURL + "/auth/oauth/login"
	}

	if err := openBrowser(loginURL); err != nil {
		fmt.Fprintf(os.Stderr, "Could not open browser automatically.\n")
		fmt.Fprintf(os.Stderr, "Please visit: %s\n", loginURL)
	} else {
		fmt.Fprintf(os.Stderr, "Please complete authentication in your browser.\n")
	}

	fmt.Fprintf(os.Stderr, "Waiting for authentication")
	for range 60 {
		time.Sleep(2 * time.Second)
		fmt.Fprintf(os.Stderr, ".")
		if validateCredentials(appViewURL, cfg.Handle, cfg.DeviceSecret).Valid {
			fmt.Fprintf(os.Stderr, "\n✓ Re-authenticated successfully!\n")
			return true
		}
	}
	return false
}
// handleStore implements the "store" command, which Docker invokes after a
// "docker login" with a Credentials object on stdin.
//
// The payload is decoded only to validate it; nothing is persisted. Device
// credentials come from the automatic device-authorization flow in handleGet,
// and app-password logins go through the AppView's /auth/token endpoint
// directly. Malformed input is reported on stderr and exits with status 1.
func handleStore() {
	dec := json.NewDecoder(os.Stdin)
	var incoming Credentials
	err := dec.Decode(&incoming)
	if err == nil {
		return
	}
	fmt.Fprintf(os.Stderr, "Error decoding credentials: %v\n", err)
	os.Exit(1)
}
// handleErase implements the "erase" command. It reads the server URL from
// stdin and removes the stored device credentials for its AppView. When that
// was the last entry, the config file is deleted altogether. A missing or
// unreadable config file means there is nothing to erase. Write failures exit
// with status 1.
func handleErase() {
	// Docker sends the server URL as a plain string on stdin (not JSON)
	var serverURL string
	if _, err := fmt.Fscanln(os.Stdin, &serverURL); err != nil {
		fmt.Fprintf(os.Stderr, "Error reading server URL: %v\n", err)
		os.Exit(1)
	}
	key := buildAppViewURL(serverURL)

	configPath := getConfigPath()
	allCreds, loadErr := loadDeviceCredentials(configPath)
	if loadErr != nil {
		return // nothing stored, so nothing to erase
	}
	delete(allCreds.Credentials, key)

	if len(allCreds.Credentials) > 0 {
		if err := saveDeviceCredentials(configPath, allCreds); err != nil {
			fmt.Fprintf(os.Stderr, "Error saving device config: %v\n", err)
			os.Exit(1)
		}
		return
	}

	// Last entry gone: drop the file instead of keeping an empty map around.
	if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) {
		fmt.Fprintf(os.Stderr, "Error removing device config: %v\n", err)
		os.Exit(1)
	}
}
// authorizeDevice performs the device authorization flow against the AppView
// for serverURL:
//
//  1. POST {appView}/auth/device/code with this machine's hostname.
//  2. Show the verification URL and user code on stderr, and try to open the
//     URL in a browser.
//  3. Poll {appView}/auth/device/token until the user approves, the server
//     reports an error, or the code's lifetime (expires_in) runs out.
//
// On success it returns a DeviceConfig holding the issued device secret.
// Progress output goes to stderr because stdout is reserved for the
// credential helper protocol.
func authorizeDevice(serverURL string) (*DeviceConfig, error) {
	const (
		// defaultPollInterval is used when the server omits or zeroes
		// "interval"; without it the loop below would poll with no delay.
		// RFC 8628 section 3.2 specifies 5 seconds as the default.
		defaultPollInterval = 5 * time.Second
		// slowDownIncrement is added to the interval when the server answers
		// "slow_down" (RFC 8628 section 3.5).
		slowDownIncrement = 5 * time.Second
		// requestTimeout bounds each HTTP round trip so a stalled server
		// cannot hang the helper indefinitely.
		requestTimeout = 30 * time.Second
	)

	appViewURL := buildAppViewURL(serverURL)
	client := &http.Client{Timeout: requestTimeout}

	// Get device name (hostname)
	deviceName, err := os.Hostname()
	if err != nil {
		deviceName = "Unknown Device"
	}

	// 1. Request device code
	fmt.Fprintf(os.Stderr, "Requesting device authorization...\n")
	reqBody, err := json.Marshal(DeviceCodeRequest{DeviceName: deviceName})
	if err != nil {
		return nil, fmt.Errorf("failed to encode device code request: %w", err)
	}
	resp, err := client.Post(appViewURL+"/auth/device/code", "application/json", bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("failed to request device code: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body) // best effort: only used in the error text
		return nil, fmt.Errorf("device code request failed: %s", string(body))
	}

	var codeResp DeviceCodeResponse
	if err := json.NewDecoder(resp.Body).Decode(&codeResp); err != nil {
		return nil, fmt.Errorf("failed to decode device code response: %w", err)
	}

	// 2. Display authorization URL and user code
	verificationURL := codeResp.VerificationURI + "?user_code=" + codeResp.UserCode
	fmt.Fprintf(os.Stderr, "\n╔════════════════════════════════════════════════════════════════╗\n")
	fmt.Fprintf(os.Stderr, "║ Device Authorization Required ║\n")
	fmt.Fprintf(os.Stderr, "╚════════════════════════════════════════════════════════════════╝\n\n")
	fmt.Fprintf(os.Stderr, "Visit this URL in your browser:\n")
	fmt.Fprintf(os.Stderr, " %s\n\n", verificationURL)
	fmt.Fprintf(os.Stderr, "Your code: %s\n\n", codeResp.UserCode)

	// Try to open browser (may fail on headless systems)
	if err := openBrowser(verificationURL); err == nil {
		fmt.Fprintf(os.Stderr, "Opening browser...\n\n")
	} else {
		fmt.Fprintf(os.Stderr, "Could not open browser automatically (%v)\n", err)
		fmt.Fprintf(os.Stderr, "Please open the URL above manually.\n\n")
	}

	fmt.Fprintf(os.Stderr, "Waiting for authorization")

	// 3. Poll for authorization completion
	pollInterval := time.Duration(codeResp.Interval) * time.Second
	if pollInterval <= 0 {
		pollInterval = defaultPollInterval
	}
	deadline := time.Now().Add(time.Duration(codeResp.ExpiresIn) * time.Second)

	// The poll body never changes, so encode it once.
	tokenReqBody, err := json.Marshal(DeviceTokenRequest{DeviceCode: codeResp.DeviceCode})
	if err != nil {
		return nil, fmt.Errorf("failed to encode device token request: %w", err)
	}

	dots := 0
	for time.Now().Before(deadline) {
		time.Sleep(pollInterval)

		// Show progress dots
		dots = (dots + 1) % 4
		fmt.Fprintf(os.Stderr, "\rWaiting for authorization%s ", strings.Repeat(".", dots))

		tokenResp, err := client.Post(appViewURL+"/auth/device/token", "application/json", bytes.NewReader(tokenReqBody))
		if err != nil {
			fmt.Fprintf(os.Stderr, "\nPoll failed: %v\n", err)
			continue
		}

		// The HTTP status is deliberately not checked: pending/error states
		// are carried in the JSON "error" field.
		var tokenResult DeviceTokenResponse
		decodeErr := json.NewDecoder(tokenResp.Body).Decode(&tokenResult)
		tokenResp.Body.Close()
		if decodeErr != nil {
			fmt.Fprintf(os.Stderr, "\nFailed to decode response: %v\n", decodeErr)
			continue
		}

		switch tokenResult.Error {
		case "authorization_pending":
			// Still waiting for the user.
			continue
		case "slow_down":
			// The server wants us to back off before polling again.
			pollInterval += slowDownIncrement
			continue
		case "":
			// No error: the device was authorized; handled below.
		default:
			fmt.Fprintf(os.Stderr, "\n")
			return nil, fmt.Errorf("authorization failed: %s", tokenResult.Error)
		}

		fmt.Fprintf(os.Stderr, "\n")
		if tokenResult.DeviceSecret == "" {
			// Saving this would store credentials that can never work.
			return nil, fmt.Errorf("authorization response did not include a device secret")
		}
		return &DeviceConfig{
			Handle:       tokenResult.Handle,
			DeviceSecret: tokenResult.DeviceSecret,
			AppViewURL:   appViewURL,
		}, nil
	}

	fmt.Fprintf(os.Stderr, "\n")
	return nil, fmt.Errorf("authorization timeout")
}
// getConfigPath returns the path to the device configuration file
func getConfigPath() string {
homeDir, err := os.UserHomeDir()
if err != nil {
fmt.Fprintf(os.Stderr, "Error getting home directory: %v\n", err)
os.Exit(1)
}
atcrDir := filepath.Join(homeDir, ".atcr")
if err := os.MkdirAll(atcrDir, 0700); err != nil {
fmt.Fprintf(os.Stderr, "Error creating .atcr directory: %v\n", err)
os.Exit(1)
}
return filepath.Join(atcrDir, "device.json")
}
// loadDeviceCredentials reads and parses the device credentials file at path.
//
// Two formats are accepted:
//   - current: {"credentials": {<appViewURL>: DeviceConfig, ...}};
//   - legacy: a single top-level DeviceConfig. It is migrated in memory into
//     a one-entry map keyed by its AppViewURL, and only if it has a non-empty
//     device secret.
//
// Read errors are returned unchanged; content matching neither format yields
// an "invalid device credentials format" error.
func loadDeviceCredentials(path string) (*DeviceCredentials, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	current := &DeviceCredentials{}
	if json.Unmarshal(raw, current) == nil && current.Credentials != nil {
		return current, nil
	}

	var legacy DeviceConfig
	if json.Unmarshal(raw, &legacy) == nil && legacy.DeviceSecret != "" {
		return &DeviceCredentials{
			Credentials: map[string]DeviceConfig{legacy.AppViewURL: legacy},
		}, nil
	}

	return nil, fmt.Errorf("invalid device credentials format")
}
// getDeviceConfig looks up the stored config for appViewURL.
//
// It returns (nil, false) when creds or its map is nil. Otherwise it returns
// a pointer to a copy of the entry (zero-valued when absent) and whether the
// entry was present. Mutating the result does not change creds.
func getDeviceConfig(creds *DeviceCredentials, appViewURL string) (*DeviceConfig, bool) {
	var stored map[string]DeviceConfig
	if creds != nil {
		stored = creds.Credentials
	}
	if stored == nil {
		return nil, false
	}

	cfg, ok := stored[appViewURL]
	return &cfg, ok
}
// saveDeviceCredentials saves all device credentials to path as indented
// JSON.
//
// Because the file holds device secrets, the write is atomic and owner-only.
// The JSON goes to a temporary file in the same directory (os.CreateTemp
// creates files with mode 0600), which is then renamed over path. An
// interrupted or concurrent write therefore never leaves a truncated
// credentials file behind. An existing file with looser permissions is
// replaced rather than rewritten in place.
func saveDeviceCredentials(path string, creds *DeviceCredentials) error {
	data, err := json.MarshalIndent(creds, "", " ")
	if err != nil {
		return err
	}

	tmp, err := os.CreateTemp(filepath.Dir(path), ".device-*.json.tmp")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	// Clean up on every failure path. After a successful rename the temp name
	// no longer exists, so this Remove fails harmlessly (error ignored).
	defer os.Remove(tmpName)

	if _, err := tmp.Write(data); err != nil {
		tmp.Close()
		return err
	}
	// Flush to stable storage before the rename makes the new content visible.
	if err := tmp.Sync(); err != nil {
		tmp.Close()
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	return os.Rename(tmpName, path)
}
// saveDeviceConfig adds or replaces the entry for config.AppViewURL in the
// credentials file at path, keeping every other stored AppView. If the
// existing file is missing or unreadable, a fresh credential set is started,
// so its previous contents are replaced.
func saveDeviceConfig(path string, config *DeviceConfig) error {
	creds, loadErr := loadDeviceCredentials(path)
	if loadErr != nil {
		creds = &DeviceCredentials{Credentials: map[string]DeviceConfig{}}
	}

	creds.Credentials[config.AppViewURL] = *config
	return saveDeviceCredentials(path, creds)
}
// openBrowser opens url in the user's default browser. It uses xdg-open on
// Linux and the BSDs, open on macOS, and rundll32's FileProtocolHandler on
// Windows. Any other platform returns an error.
//
// It only waits for the launcher process to start, not for the browser. Wait
// is called in the background because os/exec requires Wait after a
// successful Start to release the child's resources. The launcher's exit
// status is not used.
func openBrowser(url string) error {
	var cmd *exec.Cmd
	switch runtime.GOOS {
	case "linux", "freebsd", "openbsd", "netbsd", "dragonfly":
		cmd = exec.Command("xdg-open", url)
	case "darwin":
		cmd = exec.Command("open", url)
	case "windows":
		cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
	default:
		return fmt.Errorf("unsupported platform")
	}

	if err := cmd.Start(); err != nil {
		return err
	}
	go func() { _ = cmd.Wait() }()
	return nil
}
// buildAppViewURL turns a registry server address into the AppView base URL.
// An address that already carries an http:// or https:// scheme is returned
// unchanged. Otherwise http:// is prefixed for registries isInsecureRegistry
// accepts, and https:// for everything else (mirroring Docker's default).
func buildAppViewURL(serverURL string) string {
	for _, scheme := range []string{"http://", "https://"} {
		if strings.HasPrefix(serverURL, scheme) {
			return serverURL
		}
	}

	scheme := "https://"
	if isInsecureRegistry(serverURL) {
		scheme = "http://"
	}
	return scheme + serverURL
}
// isInsecureRegistry reports whether serverURL ("host" or "host:port") should
// be contacted over plain HTTP. That is the case when either:
//   - it is listed in Docker's insecure-registries, matched as given or by
//     bare host; or
//   - its host is localhost, a loopback address, or a private IP address.
func isInsecureRegistry(serverURL string) bool {
	host := stripPort(serverURL)

	for _, entry := range getDockerInsecureRegistries() {
		if entry == serverURL || entry == host {
			return true
		}
	}

	switch host {
	case "localhost", "127.0.0.1", "::1":
		return true
	}

	ip := net.ParseIP(host)
	return ip != nil && (ip.IsLoopback() || ip.IsPrivate())
}
// getDockerInsecureRegistries returns the insecure-registries entries from
// every known Docker daemon.json. Entries are merged in the order the files
// are listed below, with duplicates removed. Missing or unparseable files are
// skipped, and nil is returned when no entries are found.
//
// Locations checked:
//   - Windows: %ProgramData%\docker\config\daemon.json
//   - elsewhere:
//   - /etc/docker/daemon.json (system daemon)
//   - $XDG_CONFIG_HOME/docker/daemon.json and ~/.config/docker/daemon.json
//     (rootless daemon)
//   - ~/.docker/daemon.json (Docker Desktop)
func getDockerInsecureRegistries() []string {
	var paths []string
	switch runtime.GOOS {
	case "windows":
		if programData := os.Getenv("ProgramData"); programData != "" {
			paths = append(paths, filepath.Join(programData, "docker", "config", "daemon.json"))
		}
	default:
		paths = append(paths, "/etc/docker/daemon.json")
		if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" {
			paths = append(paths, filepath.Join(xdg, "docker", "daemon.json"))
		}
		if homeDir, err := os.UserHomeDir(); err == nil {
			paths = append(paths,
				filepath.Join(homeDir, ".config", "docker", "daemon.json"),
				filepath.Join(homeDir, ".docker", "daemon.json"),
			)
		}
	}

	// Union of all files: a registry marked insecure in any known daemon
	// config should be treated as insecure, not just those in the first file.
	var registries []string
	seen := make(map[string]bool)
	for _, path := range paths {
		config := readDockerDaemonConfig(path)
		if config == nil {
			continue
		}
		for _, reg := range config.InsecureRegistries {
			if !seen[reg] {
				seen[reg] = true
				registries = append(registries, reg)
			}
		}
	}
	return registries
}
// readDockerDaemonConfig reads and parses a Docker daemon.json file
func readDockerDaemonConfig(path string) *DockerDaemonConfig {
data, err := os.ReadFile(path)
if err != nil {
return nil
}
var config DockerDaemonConfig
if err := json.Unmarshal(data, &config); err != nil {
return nil
}
return &config
}
// stripPort returns the host part of a "host:port" string.
//
// Bracketed IPv6 literals are unwrapped: "[::1]:5000" and "[::1]" both become
// "::1". Input without a port, including a bare IPv6 address such as "::1",
// is returned unchanged. Previously a bracketed IPv6 address with a port was
// returned as-is, so the loopback/private-IP heuristics never matched it.
func stripPort(hostPort string) string {
	if host, _, err := net.SplitHostPort(hostPort); err == nil {
		return host
	}
	// No port (or not parseable as host:port): unwrap a bare bracketed IPv6
	// literal, otherwise leave the input alone.
	if strings.HasPrefix(hostPort, "[") && strings.HasSuffix(hostPort, "]") {
		return hostPort[1 : len(hostPort)-1]
	}
	return hostPort
}
// isTerminal reports whether f looks like an interactive terminal, meaning
// its file mode has os.ModeCharDevice set. A failed Stat counts as "not a
// terminal".
//
// NOTE: this is a heuristic; other character devices such as /dev/null also
// pass the check.
func isTerminal(f *os.File) bool {
	info, err := f.Stat()
	return err == nil && info.Mode()&os.ModeCharDevice == os.ModeCharDevice
}
// validateCredentials checks the stored device credentials. It requests a
// token from the AppView's /auth/token endpoint using HTTP Basic auth
// (handle:deviceSecret). /v2/ itself expects a JWT, so the token endpoint is
// where the device secret is actually checked.
//
// Only a 401 counts as invalid. Its JSON body separates an expired OAuth
// session (error "oauth_session_expired", with an optional login URL) from a
// plain rejection. Network errors and any status other than 401 are reported
// as valid, so an unreachable or failing server never forces re-authorization.
// The request times out after 5 seconds.
func validateCredentials(appViewURL, handle, deviceSecret string) ValidationResult {
	httpClient := &http.Client{Timeout: 5 * time.Second}

	// Minimal scope: only the service itself (access to /v2/).
	tokenURL := appViewURL + "/auth/token?service=" + appViewURL
	req, err := http.NewRequest(http.MethodGet, tokenURL, nil)
	if err != nil {
		return ValidationResult{Valid: false}
	}
	req.SetBasicAuth(handle, deviceSecret)

	resp, err := httpClient.Do(req)
	if err != nil {
		// Server unreachable: don't trigger re-auth on a network problem.
		return ValidationResult{Valid: true}
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusUnauthorized {
		// 200 means valid; anything else is a server issue, not a credential one.
		return ValidationResult{Valid: true}
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return ValidationResult{Valid: false}
	}
	var authErr AuthErrorResponse
	if json.Unmarshal(body, &authErr) != nil || authErr.Error != "oauth_session_expired" {
		return ValidationResult{Valid: false}
	}
	return ValidationResult{
		Valid:               false,
		OAuthSessionExpired: true,
		LoginURL:            authErr.LoginURL,
	}
}
// handleUpdate implements the "update" command. It asks an AppView for the
// latest helper release and, unless checkOnly is set ("update --check"),
// downloads and installs it over the running binary.
//
// If any credentials are stored, that AppView's version endpoint is used:
// specifically the lexicographically smallest stored AppView URL. Map
// iteration order is random, so sorting keeps the choice stable across runs.
// Otherwise https://atcr.io is queried. Failures are reported on stderr and
// exit with status 1.
func handleUpdate(checkOnly bool) {
	// Default API URL
	apiURL := "https://atcr.io/api/credential-helper/version"

	configPath := getConfigPath()
	if allCreds, err := loadDeviceCredentials(configPath); err == nil {
		chosen := ""
		for _, cred := range allCreds.Credentials {
			if cred.AppViewURL != "" && (chosen == "" || cred.AppViewURL < chosen) {
				chosen = cred.AppViewURL
			}
		}
		if chosen != "" {
			apiURL = chosen + "/api/credential-helper/version"
		}
	}

	versionInfo, err := fetchVersionInfo(apiURL)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to check for updates: %v\n", err)
		os.Exit(1)
	}

	if !isNewerVersion(versionInfo.Latest, version) {
		fmt.Printf("You're already running the latest version (%s)\n", version)
		return
	}

	fmt.Printf("New version available: %s (current: %s)\n", versionInfo.Latest, version)
	if checkOnly {
		return
	}

	if err := performUpdate(versionInfo); err != nil {
		fmt.Fprintf(os.Stderr, "Update failed: %v\n", err)
		os.Exit(1)
	}
	fmt.Println("Update completed successfully!")
}
// fetchVersionInfo fetches version info from the AppView API
func fetchVersionInfo(apiURL string) (*VersionAPIResponse, error) {
client := &http.Client{
Timeout: 10 * time.Second,
}
resp, err := client.Get(apiURL)
if err != nil {
return nil, fmt.Errorf("failed to fetch version info: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("version API returned status %d", resp.StatusCode)
}
var versionInfo VersionAPIResponse
if err := json.NewDecoder(resp.Body).Decode(&versionInfo); err != nil {
return nil, fmt.Errorf("failed to parse version info: %w", err)
}
return &versionInfo, nil
}
// isNewerVersion reports whether newVersion is strictly newer than
// currentVersion, comparing dot-separated numeric components left to right.
//
//   - A leading "v" is ignored ("v1.2.0" equals "1.2.0").
//   - Missing trailing components count as zero, so "1.0.0" is not newer
//     than "1.0" but "1.0.1" is.
//   - Only the leading digits of each component are compared. A suffix such
//     as "-rc1" in "1.2.3-rc1" is ignored, and a component without leading
//     digits counts as zero.
//   - An empty newVersion is never newer.
//   - A development build (currentVersion == "dev") is otherwise always
//     considered outdated.
func isNewerVersion(newVersion, currentVersion string) bool {
	if newVersion == "" {
		return false
	}
	if currentVersion == "dev" {
		return true
	}

	newParts := strings.Split(strings.TrimPrefix(newVersion, "v"), ".")
	curParts := strings.Split(strings.TrimPrefix(currentVersion, "v"), ".")

	for i := range max(len(newParts), len(curParts)) {
		newNum := versionComponent(newParts, i)
		curNum := versionComponent(curParts, i)
		if newNum != curNum {
			return newNum > curNum
		}
	}
	return false
}

// versionComponent returns the numeric value of the leading digits of
// parts[i]. It returns 0 when i is out of range, when the component has no
// leading digits, or when the number does not fit in an int.
func versionComponent(parts []string, i int) int {
	if i >= len(parts) {
		return 0
	}
	s := parts[i]
	end := 0
	for end < len(s) && s[end] >= '0' && s[end] <= '9' {
		end++
	}
	n, err := strconv.Atoi(s[:end])
	if err != nil {
		return 0
	}
	return n
}
// getPlatformKey returns the "<GOOS>_<GOARCH>" key (e.g. "linux_amd64",
// "darwin_arm64") used to index VersionAPIResponse.DownloadURLs and Checksums.
//
// The Go toolchain names are used verbatim. The previous version shadowed the
// os package with a local variable and ran an identity-only arch switch.
func getPlatformKey() string {
	return runtime.GOOS + "_" + runtime.GOARCH
}
// performUpdate downloads and installs the new version described by
// versionInfo for the current platform. Steps:
//
//  1. Download the release archive (.zip when the URL ends in ".zip",
//     otherwise .tar.gz) into a temporary directory.
//  2. Verify its checksum when one is published; otherwise warn on stderr.
//  3. Extract it. The archive must contain docker-credential-atcr (".exe" on
//     Windows) at its top level.
//  4. Smoke-test the new binary by running "<binary> version".
//  5. Replace the running executable, after resolving symlinks. The old
//     binary is kept as "<path>.bak" and restored if installation fails.
func performUpdate(versionInfo *VersionAPIResponse) error {
	platformKey := getPlatformKey()
	downloadURL, ok := versionInfo.DownloadURLs[platformKey]
	if !ok {
		return fmt.Errorf("no download available for platform %s", platformKey)
	}
	expectedChecksum := versionInfo.Checksums[platformKey]

	fmt.Printf("Downloading update from %s...\n", downloadURL)

	tmpDir, err := os.MkdirTemp("", "atcr-update-")
	if err != nil {
		return fmt.Errorf("failed to create temp directory: %w", err)
	}
	defer os.RemoveAll(tmpDir)

	// 1. Download the archive
	isZip := strings.HasSuffix(downloadURL, ".zip")
	archivePath := filepath.Join(tmpDir, "archive.tar.gz")
	if isZip {
		archivePath = filepath.Join(tmpDir, "archive.zip")
	}
	if err := downloadFile(downloadURL, archivePath); err != nil {
		return fmt.Errorf("failed to download: %w", err)
	}

	// 2. Verify checksum if provided
	if expectedChecksum != "" {
		if err := verifyChecksum(archivePath, expectedChecksum); err != nil {
			return fmt.Errorf("checksum verification failed: %w", err)
		}
		fmt.Println("Checksum verified.")
	} else {
		fmt.Fprintf(os.Stderr, "Warning: no checksum published for %s; the download was not verified.\n", platformKey)
	}

	// 3. Extract the binary
	if isZip {
		if err := extractZip(archivePath, tmpDir); err != nil {
			return fmt.Errorf("failed to extract archive: %w", err)
		}
	} else {
		if err := extractTarGz(archivePath, tmpDir); err != nil {
			return fmt.Errorf("failed to extract archive: %w", err)
		}
	}

	binaryPath := filepath.Join(tmpDir, "docker-credential-atcr")
	if runtime.GOOS == "windows" {
		binaryPath += ".exe"
	}
	if _, err := os.Stat(binaryPath); err != nil {
		return fmt.Errorf("archive does not contain %s: %w", filepath.Base(binaryPath), err)
	}
	// The archive may not preserve the executable bit; the smoke test below
	// needs it.
	if err := os.Chmod(binaryPath, 0755); err != nil {
		return fmt.Errorf("failed to make new binary executable: %w", err)
	}

	// Get the current executable path
	currentPath, err := os.Executable()
	if err != nil {
		return fmt.Errorf("failed to get current executable path: %w", err)
	}
	currentPath, err = filepath.EvalSymlinks(currentPath)
	if err != nil {
		return fmt.Errorf("failed to resolve symlinks: %w", err)
	}

	// 4. Verify the new binary works
	fmt.Println("Verifying new binary...")
	output, err := exec.Command(binaryPath, "version").Output()
	if err != nil {
		return fmt.Errorf("new binary verification failed: %w", err)
	}
	fmt.Printf("New binary version: %s", string(output))

	// 5. Backup current binary, then install the new one in its place
	backupPath := currentPath + ".bak"
	if err := os.Rename(currentPath, backupPath); err != nil {
		return fmt.Errorf("failed to backup current binary: %w", err)
	}

	if err := copyFile(binaryPath, currentPath); err != nil {
		if renameErr := os.Rename(backupPath, currentPath); renameErr != nil {
			fmt.Fprintf(os.Stderr, "Warning: failed to restore backup: %v\n", renameErr)
		}
		return fmt.Errorf("failed to install new binary: %w", err)
	}

	if err := os.Chmod(currentPath, 0755); err != nil {
		// Best effort: drop the half-installed file so the backup can return.
		_ = os.Remove(currentPath)
		if renameErr := os.Rename(backupPath, currentPath); renameErr != nil {
			fmt.Fprintf(os.Stderr, "Warning: failed to restore backup: %v\n", renameErr)
		}
		return fmt.Errorf("failed to set permissions: %w", err)
	}

	// Remove backup on success. Best effort: this may fail, e.g. on Windows
	// while this process is still running from it.
	_ = os.Remove(backupPath)
	return nil
}
// downloadFile downloads a file from a URL to a local path
func downloadFile(url, destPath string) error {
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned status %d", resp.StatusCode)
}
out, err := os.Create(destPath)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
return err
}
// verifyChecksum verifies that the SHA-256 digest of the file at filePath
// matches expected, a hex-encoded digest. Surrounding whitespace and an
// optional "sha256:" prefix are ignored, and the hex comparison is
// case-insensitive. An empty expected value skips verification because
// checksums are optional on the server side.
//
// The previous implementation read the file but never compared anything, so
// every download was accepted regardless of its content.
func verifyChecksum(filePath, expected string) error {
	want := strings.TrimPrefix(strings.TrimSpace(expected), "sha256:")
	if want == "" {
		return nil
	}

	f, err := os.Open(filePath)
	if err != nil {
		return err
	}
	defer f.Close()

	// Stream the file through the hash instead of loading it into memory.
	hasher := sha256.New()
	if _, err := io.Copy(hasher, f); err != nil {
		return err
	}

	got := fmt.Sprintf("%x", hasher.Sum(nil))
	if !strings.EqualFold(got, want) {
		return fmt.Errorf("expected sha256 %s, got %s", want, got)
	}
	return nil
}
// extractTarGz unpacks the gzip-compressed tarball at archivePath into
// destDir using the system "tar" command. On failure the command's combined
// output is included in the error.
func extractTarGz(archivePath, destDir string) error {
	out, err := exec.Command("tar", "-xzf", archivePath, "-C", destDir).CombinedOutput()
	if err == nil {
		return nil
	}
	return fmt.Errorf("tar failed: %s: %w", string(out), err)
}
// extractZip extracts the .zip archive at archivePath into destDir, which
// must already exist, overwriting existing files.
//
// It uses archive/zip instead of shelling out to `unzip`, which is not
// installed by default on Windows. The archive is downloaded from the
// network, so entries whose names resolve outside destDir (e.g. "../x", the
// "zip slip" attack) are rejected.
func extractZip(archivePath, destDir string) error {
	reader, err := zip.OpenReader(archivePath)
	if err != nil {
		if reader != nil {
			// OpenReader can return a reader together with
			// zip.ErrInsecurePath; release it before bailing out.
			reader.Close()
		}
		return fmt.Errorf("unzip failed: %w", err)
	}
	defer reader.Close()

	for _, entry := range reader.File {
		if err := extractZipEntry(entry, destDir); err != nil {
			return fmt.Errorf("unzip failed: %s: %w", entry.Name, err)
		}
	}
	return nil
}

// extractZipEntry writes one archive entry (file or directory) below destDir,
// creating parent directories as needed.
func extractZipEntry(entry *zip.File, destDir string) error {
	target := filepath.Join(destDir, entry.Name)
	rel, err := filepath.Rel(destDir, target)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return fmt.Errorf("illegal path in archive")
	}

	if entry.FileInfo().IsDir() {
		return os.MkdirAll(target, 0755)
	}
	if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
		return err
	}

	src, err := entry.Open()
	if err != nil {
		return err
	}
	defer src.Close()

	mode := entry.Mode().Perm()
	if mode == 0 {
		mode = 0644
	}
	dst, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode)
	if err != nil {
		return err
	}
	if _, err := io.Copy(dst, src); err != nil {
		dst.Close()
		return err
	}
	return dst.Close()
}
// copyFile copies a file from src to dst
func copyFile(src, dst string) error {
input, err := os.ReadFile(src)
if err != nil {
return err
}
return os.WriteFile(dst, input, 0755)
}
// checkAndNotifyUpdate prints a notice on stderr when the AppView publishes a
// newer docker-credential-atcr release than the running version.
//
// Each check's outcome is cached for updateCheckCacheTTL and tied to the
// running version, so most "get" invocations do no network I/O. Failed checks
// are cached too, keeping the last known release if any. Otherwise an AppView
// without the version endpoint, or an unreachable one, would add a full HTTP
// timeout to every credential lookup. Errors are swallowed so credential
// retrieval is never interrupted.
func checkAndNotifyUpdate(appViewURL string) {
	cache := loadUpdateCheckCache()
	cacheIsCurrent := cache != nil && cache.Current == version
	if cacheIsCurrent && time.Since(cache.CheckedAt) < updateCheckCacheTTL {
		notifyIfNewer(cache.Latest)
		return
	}

	latest := ""
	if cacheIsCurrent {
		// Fall back to the last known release if this check fails.
		latest = cache.Latest
	}
	if versionInfo, err := fetchVersionInfo(appViewURL + "/api/credential-helper/version"); err == nil {
		latest = versionInfo.Latest
	}

	saveUpdateCheckCache(&UpdateCheckCache{
		CheckedAt: time.Now(),
		Latest:    latest,
		Current:   version,
	})
	notifyIfNewer(latest)
}

// notifyIfNewer prints the upgrade hint when latest is newer than the running
// version. An empty latest (no successful check yet) prints nothing.
func notifyIfNewer(latest string) {
	if latest == "" || !isNewerVersion(latest, version) {
		return
	}
	fmt.Fprintf(os.Stderr, "\nNote: A new version of docker-credential-atcr is available (%s).\n", latest)
	fmt.Fprintf(os.Stderr, "Run 'docker-credential-atcr update' to upgrade.\n\n")
}
// getUpdateCheckCachePath returns ~/.atcr/update-check.json, or "" when the
// home directory cannot be determined. Unlike getConfigPath it never creates
// directories or exits.
func getUpdateCheckCachePath() string {
	if home, err := os.UserHomeDir(); err == nil {
		return filepath.Join(home, ".atcr", "update-check.json")
	}
	return ""
}
// loadUpdateCheckCache loads the update check cache from disk
func loadUpdateCheckCache() *UpdateCheckCache {
path := getUpdateCheckCachePath()
if path == "" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
return nil
}
var cache UpdateCheckCache
if err := json.Unmarshal(data, &cache); err != nil {
return nil
}
return &cache
}
// saveUpdateCheckCache writes cache as indented JSON to the update-check
// cache file, creating ~/.atcr (mode 0700) if needed. The file is written
// with mode 0600. Everything is best effort: any failure is silently ignored,
// and the only consequence is that a later run checks again.
func saveUpdateCheckCache(cache *UpdateCheckCache) {
	cachePath := getUpdateCheckCachePath()
	if cachePath == "" {
		return
	}

	encoded, err := json.MarshalIndent(cache, "", " ")
	if err != nil {
		return
	}

	if os.MkdirAll(filepath.Dir(cachePath), 0700) != nil {
		return
	}
	_ = os.WriteFile(cachePath, encoded, 0600) // cache write failure is non-critical
}