diff --git a/age.go b/age.go index eafe090..d2a264d 100644 --- a/age.go +++ b/age.go @@ -209,7 +209,10 @@ type NoIdentityMatchError struct { StanzaTypes []string } -func (*NoIdentityMatchError) Error() string { +func (e *NoIdentityMatchError) Error() string { + if len(e.Errors) == 1 { + return "identity did not match any of the recipients: " + e.Errors[0].Error() + } return "no identity matched any of the recipients" } diff --git a/age_test.go b/age_test.go index da65e3a..48d6034 100644 --- a/age_test.go +++ b/age_test.go @@ -379,6 +379,71 @@ func TestNoIdentityMatchErrorStanzaTypes(t *testing.T) { } } +func TestScryptIdentityErrors(t *testing.T) { + t.Run("not passphrase-encrypted", func(t *testing.T) { + i, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + + buf := &bytes.Buffer{} + w, err := age.Encrypt(buf, i.Recipient()) + if err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + + scryptID, err := age.NewScryptIdentity("password") + if err != nil { + t.Fatal(err) + } + _, err = age.Decrypt(bytes.NewReader(buf.Bytes()), scryptID) + if err == nil { + t.Fatal("expected decryption to fail") + } + if !errors.Is(err, age.ErrIncorrectIdentity) { + t.Errorf("expected ErrIncorrectIdentity, got %v", err) + } + if !strings.Contains(err.Error(), "not passphrase-encrypted") { + t.Errorf("expected error to mention 'not passphrase-encrypted', got %v", err) + } + }) + + t.Run("incorrect passphrase", func(t *testing.T) { + r, err := age.NewScryptRecipient("correct-password") + if err != nil { + t.Fatal(err) + } + r.SetWorkFactor(10) // Low for fast test + + buf := &bytes.Buffer{} + w, err := age.Encrypt(buf, r) + if err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + + scryptID, err := age.NewScryptIdentity("wrong-password") + if err != nil { + t.Fatal(err) + } + _, err = age.Decrypt(bytes.NewReader(buf.Bytes()), scryptID) + if err == nil { + t.Fatal("expected decryption to fail") + } + if !errors.Is(err, age.ErrIncorrectIdentity) { + t.Errorf("expected ErrIncorrectIdentity, got %v", err) + } + if !strings.Contains(err.Error(), "incorrect passphrase") { + t.Errorf("expected error to mention 'incorrect passphrase', got %v", err) + } + }) +} + func TestDetachedHeader(t *testing.T) { i, err := age.GenerateX25519Identity() if err != nil { diff --git a/scrypt.go b/scrypt.go index 0ed2859..46bbc21 100644 --- a/scrypt.go +++ b/scrypt.go @@ -150,14 +150,20 @@ func (i *ScryptIdentity) Unwrap(stanzas []*Stanza) ([]byte, error) { return nil, errors.New("an scrypt recipient must be the only one") } } - return multiUnwrap(i.unwrap, stanzas) + for _, s := range stanzas { + if s.Type != "scrypt" { + continue + } + return i.unwrap(s) + } + return nil, fmt.Errorf("%w: file is not passphrase-encrypted", ErrIncorrectIdentity) } var digitsRe = regexp.MustCompile(`^[1-9][0-9]*$`) func (i *ScryptIdentity) unwrap(block *Stanza) ([]byte, error) { if block.Type != "scrypt" { - return nil, ErrIncorrectIdentity + return nil, errors.New("internal error: unwrap called on non-scrypt stanza") } if len(block.Args) != 2 { return nil, errors.New("invalid scrypt recipient block") @@ -200,7 +206,9 @@ func (i *ScryptIdentity) unwrap(block *Stanza) ([]byte, error) { if err == errIncorrectCiphertextSize { return nil, errors.New("invalid scrypt recipient block: incorrect file key size") } else if err != nil { - return nil, ErrIncorrectIdentity + // Wrap [ErrIncorrectIdentity] so that multiple passphrases can be tried + // in sequence by passing multiple [ScryptIdentity] values to [Decrypt]. + return nil, fmt.Errorf("%w: incorrect passphrase", ErrIncorrectIdentity) } return fileKey, nil }