diff --git a/cmd/age/age.go b/cmd/age/age.go index 15b8909..fb05255 100644 --- a/cmd/age/age.go +++ b/cmd/age/age.go @@ -59,6 +59,10 @@ func main() { log.Printf("Error: -a/--armor can't be used with -d/--decrypt.") log.Fatalf("Note that armored files are detected automatically.") } + if passFlag { + log.Printf("Error: -p/--passphrase can't be used with -d/--decrypt.") + log.Fatalf("Note that password protected files are detected automatically.") + } if len(recipientFlags) > 0 { log.Printf("Error: -r/--recipient can't be used with -d/--decrypt.") log.Fatalf("Did you mean to use -i/--identity to specify a private key?") @@ -111,24 +115,28 @@ func main() { } switch { - case passFlag: - fmt.Fprintf(os.Stderr, "Enter passphrase: ") - pass, err := readPassphrase() - if err != nil { - log.Fatalf("Error: could not read passphrase: %v", err) - } - if decryptFlag { - decryptPass(string(pass), in, out) - } else { - encryptPass(string(pass), in, out, armorFlag) - } case decryptFlag: - decryptKeys(identityFlags, in, out) + decrypt(identityFlags, in, out) + case passFlag: + pass, err := passphrasePrompt() + if err != nil { + log.Fatalf("Error: %v", err) + } + encryptPass(pass, in, out, armorFlag) default: encryptKeys(recipientFlags, in, out, armorFlag) } } +func passphrasePrompt() (string, error) { + fmt.Fprintf(os.Stderr, "Enter passphrase: ") + pass, err := readPassphrase() + if err != nil { + return "", fmt.Errorf("could not read passphrase: %v", err) + } + return string(pass), nil +} + func encryptKeys(keys []string, in io.Reader, out io.Writer, armor bool) { var recipients []age.Recipient for _, arg := range keys { @@ -166,8 +174,13 @@ func encrypt(recipients []age.Recipient, in io.Reader, out io.Writer, armor bool } } -func decryptKeys(keys []string, in io.Reader, out io.Writer) { - var identities []age.Identity +func decrypt(keys []string, in io.Reader, out io.Writer) { + identities := []age.Identity{ + // If there is an scrypt recipient (it will have to be the only one and) + // this identity will be invoked. + &LazyScryptIdentity{passphrasePrompt}, + } + // TODO: use the default location if no arguments are provided: // os.UserConfigDir()/age/keys.txt, ~/.ssh/id_rsa, ~/.ssh/id_ed25519 for _, name := range keys { @@ -177,18 +190,7 @@ func decryptKeys(keys []string, in io.Reader, out io.Writer) { } identities = append(identities, ids...) } - decrypt(identities, in, out) -} -func decryptPass(pass string, in io.Reader, out io.Writer) { - i, err := age.NewScryptIdentity(pass) - if err != nil { - log.Fatalf("Error: %v", err) - } - decrypt([]age.Identity{i}, in, out) -} - -func decrypt(identities []age.Identity, in io.Reader, out io.Writer) { r, err := age.Decrypt(in, identities...) if err != nil { log.Fatalf("Error: %v", err) diff --git a/cmd/age/encrypted_keys.go b/cmd/age/encrypted_keys.go index 74642e4..3bc160e 100644 --- a/cmd/age/encrypted_keys.go +++ b/cmd/age/encrypted_keys.go @@ -91,6 +91,28 @@ func (i *EncryptedSSHIdentity) Matches(block *format.Recipient) error { return nil } +type LazyScryptIdentity struct { + Passphrase func() (string, error) +} + +var _ age.Identity = &LazyScryptIdentity{} + +func (i *LazyScryptIdentity) Type() string { + return "scrypt" +} + +func (i *LazyScryptIdentity) Unwrap(block *format.Recipient) (fileKey []byte, err error) { + pass, err := i.Passphrase() + if err != nil { + return nil, fmt.Errorf("could not read passphrase: %v", err) + } + ii, err := age.NewScryptIdentity(pass) + if err != nil { + return nil, err + } + return ii.Unwrap(block) +} + // stdinInUse is set in main. It's a singleton like os.Stdin. var stdinInUse bool