diff --git a/cmd/age/age.go b/cmd/age/age.go index 24b4309..0a45c3b 100644 --- a/cmd/age/age.go +++ b/cmd/age/age.go @@ -8,6 +8,7 @@ package main import ( "flag" + "fmt" "io" "log" "os" @@ -15,61 +16,95 @@ import ( "github.com/FiloSottile/age/internal/age" ) +type multiFlag []string + +func (f *multiFlag) String() string { return fmt.Sprint(*f) } + +func (f *multiFlag) Set(value string) error { + *f = append(*f, value) + return nil +} + func main() { log.SetFlags(0) - decryptFlag := flag.Bool("d", false, "decrypt the input") - outFlag := flag.String("o", "", "output to `FILE` (default stdout)") - inFlag := flag.String("i", "", "read from `FILE` (default stdin)") - armorFlag := flag.Bool("a", false, "generate an armored file") + var ( + outFlag string + decryptFlag, armorFlag bool + recipientFlags, identityFlags multiFlag + ) + + flag.BoolVar(&decryptFlag, "d", false, "decrypt the input") + flag.BoolVar(&decryptFlag, "decrypt", false, "decrypt the input") + flag.StringVar(&outFlag, "o", "", "output to `FILE` (default stdout)") + flag.BoolVar(&armorFlag, "a", false, "generate an armored file") + flag.BoolVar(&armorFlag, "armor", false, "generate an armored file") + flag.Var(&recipientFlags, "r", "recipient (can be repeated)") + flag.Var(&recipientFlags, "recipient", "recipient (can be repeated)") + flag.Var(&identityFlags, "i", "identity (can be repeated)") + flag.Var(&identityFlags, "identity", "identity (can be repeated)") flag.Parse() + if flag.NArg() > 1 { + log.Printf("Error: too many arguments.") + log.Fatalf("age accepts a single optional argument for the input file.") + } switch { - case *decryptFlag: - if *armorFlag { - log.Fatalf("Invalid flag combination") + case decryptFlag: + if armorFlag { + log.Printf("Error: -a/--armor can't be used with -d/--decrypt.") + log.Fatalf("Note that armored files are detected automatically.") + } + if len(recipientFlags) > 0 { + log.Printf("Error: -r/--recipient can't be used with -d/--decrypt.") + log.Fatalf("Did you mean to use -i/--identity to specify a private key?") } default: // encrypt + if len(identityFlags) > 0 { + log.Printf("Error: -i/--identity can't be used in encryption mode.") + log.Fatalf("Did you forget to specify -d/--decrypt?") + } + if len(recipientFlags) == 0 { + log.Printf("Error: missing recipients.") + log.Fatalf("Did you forget to specify -r/--recipient?") + } } in, out := os.Stdin, os.Stdout - if name := *inFlag; name != "" { + if name := flag.Arg(0); name != "" && name != "-" { f, err := os.Open(name) if err != nil { - log.Fatalf("Failed to open input file %q: %v", name, err) + log.Fatalf("Error: failed to open input file %q: %v", name, err) } defer f.Close() in = f } - if name := *outFlag; name != "" { + if name := outFlag; name != "" && name != "-" { f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666) if err != nil { - log.Fatalf("Failed to open output file %q: %v", name, err) + log.Fatalf("Error: failed to open output file %q: %v", name, err) } defer f.Close() out = f } switch { - case *decryptFlag: - decrypt(in, out) + case decryptFlag: + decrypt(identityFlags, in, out) default: - encrypt(in, out, *armorFlag) + encrypt(recipientFlags, in, out, armorFlag) } } -func encrypt(in io.Reader, out io.Writer, armor bool) { +func encrypt(args []string, in io.Reader, out io.Writer, armor bool) { var recipients []age.Recipient - for _, arg := range flag.Args() { + for _, arg := range args { r, err := parseRecipient(arg) if err != nil { log.Fatalf("Error: %v", err) } recipients = append(recipients, r) } - if len(recipients) == 0 { - log.Fatalf("Missing recipients!") - } ageEncrypt := age.Encrypt if armor { @@ -77,21 +112,21 @@ func encrypt(in io.Reader, out io.Writer, armor bool) { } w, err := ageEncrypt(out, recipients...) if err != nil { - log.Fatalf("Error initializing encryption: %v", err) + log.Fatalf("Error: %v", err) } if _, err := io.Copy(w, in); err != nil { - log.Fatalf("Error encrypting the input: %v", err) + log.Fatalf("Error: %v", err) } if err := w.Close(); err != nil { - log.Fatalf("Error finalizing encryption: %v", err) + log.Fatalf("Error: %v", err) } } -func decrypt(in io.Reader, out io.Writer) { +func decrypt(args []string, in io.Reader, out io.Writer) { var identities []age.Identity // TODO: use the default location if no arguments are provided: // os.UserConfigDir()/age/keys.txt, ~/.ssh/id_rsa, ~/.ssh/id_ed25519 - for _, name := range flag.Args() { + for _, name := range args { ids, err := parseIdentitiesFile(name) if err != nil { log.Fatalf("Error: %v", err) @@ -101,9 +136,9 @@ func decrypt(in io.Reader, out io.Writer) { r, err := age.Decrypt(in, identities...) if err != nil { - log.Fatalf("Error initializing decryption: %v", err) + log.Fatalf("Error: %v", err) } if _, err := io.Copy(out, r); err != nil { - log.Fatalf("Error decrypting the input: %v", err) + log.Fatalf("Error: %v", err) } }