diff --git a/age.go b/age.go index 1942db7..00c39f6 100644 --- a/age.go +++ b/age.go @@ -214,6 +214,7 @@ func (*NoIdentityMatchError) Error() string { // // It returns a Reader reading the decrypted plaintext of the age file read // from src. All identities will be tried until one successfully decrypts the file. +// Native, non-interactive identities are tried before any other identities. // // If no identity matches the encrypted file, the returned error will be of type // [NoIdentityMatchError]. @@ -240,6 +241,24 @@ func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) { if len(identities) == 0 { return nil, errors.New("no identities specified") } + slices.SortStableFunc(identities, func(a, b Identity) int { + var aIsNative, bIsNative bool + switch a.(type) { + case *X25519Identity, *HybridIdentity, *ScryptIdentity: + aIsNative = true + } + switch b.(type) { + case *X25519Identity, *HybridIdentity, *ScryptIdentity: + bIsNative = true + } + if aIsNative && !bIsNative { + return -1 + } + if !aIsNative && bIsNative { + return 1 + } + return 0 + }) stanzas := make([]*Stanza, 0, len(hdr.Recipients)) for _, s := range hdr.Recipients { diff --git a/age_test.go b/age_test.go index ef870d4..dfc753b 100644 --- a/age_test.go +++ b/age_test.go @@ -285,6 +285,50 @@ func TestLabels(t *testing.T) { } } +// testIdentity is a non-native identity that records if Unwrap is called. +type testIdentity struct { + called bool +} + +func (ti *testIdentity) Unwrap(stanzas []*age.Stanza) ([]byte, error) { + ti.called = true + return nil, age.ErrIncorrectIdentity +} + +func TestDecryptNativeIdentitiesFirst(t *testing.T) { + correct, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + unrelated, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + + buf := &bytes.Buffer{} + w, err := age.Encrypt(buf, correct.Recipient()) + if err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + + nonNative := &testIdentity{} + + // Pass identities: unrelated native, non-native, correct native. + // Native identities should be tried first, so correct should match + // before nonNative is ever called. + _, err = age.Decrypt(bytes.NewReader(buf.Bytes()), unrelated, nonNative, correct) + if err != nil { + t.Fatal(err) + } + + if nonNative.called { + t.Error("non-native identity was called, but native identities should be tried first") + } +} + func TestDetachedHeader(t *testing.T) { i, err := age.GenerateX25519Identity() if err != nil {