diff --git a/cmd/age/age.go b/cmd/age/age.go index 43fa119..03be199 100644 --- a/cmd/age/age.go +++ b/cmd/age/age.go @@ -23,25 +23,51 @@ func main() { generateFlag := flag.Bool("generate", false, "generate a new age key pair") decryptFlag := flag.Bool("d", false, "decrypt the input") + outFlag := flag.String("o", "", "output to `FILE` (default stdout)") + inFlag := flag.String("i", "", "read from `FILE` (default stdin)") flag.Parse() switch { case *generateFlag: - if *decryptFlag { + if *decryptFlag || *inFlag != "" { log.Fatalf("Invalid flag combination") } - generate() case *decryptFlag: if *generateFlag { log.Fatalf("Invalid flag combination") } - decrypt() + default: // encrypt + } + + in, out := os.Stdin, os.Stdout + if name := *inFlag; name != "" { + f, err := os.Open(name) + if err != nil { + log.Fatalf("Failed to open input file %q: %v", name, err) + } + defer f.Close() + in = f + } + if name := *outFlag; name != "" { + f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666) + if err != nil { + log.Fatalf("Failed to open output file %q: %v", name, err) + } + defer f.Close() + out = f + } + + switch { + case *generateFlag: + generate(out) + case *decryptFlag: + decrypt(in, out) default: - encrypt() + encrypt(in, out) } } -func generate() { +func generate(out io.Writer) { if len(flag.Args()) != 0 { log.Fatalf("-generate takes no arguments") } @@ -51,12 +77,12 @@ func generate() { log.Fatalf("Internal error: %v", err) } - fmt.Printf("# created: %s\n", time.Now().Format(time.RFC3339)) - fmt.Printf("# %s\n", k.Recipient()) - fmt.Printf("%s\n", k) + fmt.Fprintf(out, "# created: %s\n", time.Now().Format(time.RFC3339)) + fmt.Fprintf(out, "# %s\n", k.Recipient()) + fmt.Fprintf(out, "%s\n", k) } -func encrypt() { +func encrypt(in io.Reader, out io.Writer) { var recipients []age.Recipient for _, arg := range flag.Args() { r, err := parseRecipient(arg) @@ -69,11 +95,11 @@ func encrypt() { log.Fatalf("Missing recipients!") } - w, err := age.Encrypt(os.Stdout, recipients...) + w, err := age.Encrypt(out, recipients...) if err != nil { log.Fatalf("Error initializing encryption: %v", err) } - if _, err := io.Copy(w, os.Stdin); err != nil { + if _, err := io.Copy(w, in); err != nil { log.Fatalf("Error encrypting the input: %v", err) } if err := w.Close(); err != nil { @@ -81,7 +107,7 @@ func encrypt() { } } -func decrypt() { +func decrypt(in io.Reader, out io.Writer) { var identities []age.Identity // TODO: use the default location if no arguments are provided. for _, name := range flag.Args() { @@ -102,11 +128,11 @@ func decrypt() { identities = append(identities, ids...) } - r, err := age.Decrypt(os.Stdin, identities...) + r, err := age.Decrypt(in, identities...) if err != nil { log.Fatalf("Error initializing decryption: %v", err) } - if _, err := io.Copy(os.Stdout, r); err != nil { + if _, err := io.Copy(out, r); err != nil { log.Fatalf("Error decrypting the input: %v", err) } } diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 05ed0ec..29a7fb7 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -178,7 +178,6 @@ func (w *Writer) Write(p []byte) (n int, err error) { } func (w *Writer) Close() error { - // TODO: close w.dst if it can be interface upgraded to io.Closer. if w.err != nil { return w.err }