diff --git a/age.go b/age.go index 948390f..52d4e36 100644 --- a/age.go +++ b/age.go @@ -51,10 +51,9 @@ import ( // An Identity is a private key or other value that can decrypt an opaque file // key from a recipient stanza. // -// Unwrap must return ErrIncorrectIdentity for recipient blocks that don't match -// the identity, any other error might be considered fatal. +// Unwrap must return ErrIncorrectIdentity for recipient stanzas that don't +// match the identity, any other error might be considered fatal. type Identity interface { - Type() string Unwrap(block *Stanza) (fileKey []byte, err error) } @@ -75,7 +74,6 @@ var ErrIncorrectIdentity = errors.New("incorrect identity for recipient block") // A Recipient is a public key or other value that can encrypt an opaque file // key to a recipient stanza. type Recipient interface { - Type() string Wrap(fileKey []byte) (*Stanza, error) } @@ -109,15 +107,15 @@ func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) { hdr := &format.Header{} for i, r := range recipients { - if r.Type() == "scrypt" && len(recipients) != 1 { - return nil, errors.New("an scrypt recipient must be the only one") - } - block, err := r.Wrap(fileKey) if err != nil { return nil, fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err) } hdr.Recipients = append(hdr.Recipients, (*format.Stanza)(block)) + + if block.Type == "scrypt" && len(recipients) != 1 { + return nil, errors.New("an scrypt recipient must be the only one") + } } if mac, err := headerMAC(fileKey, hdr); err != nil { return nil, fmt.Errorf("failed to compute header MAC: %v", err) @@ -163,10 +161,6 @@ RecipientsLoop: return nil, errors.New("an scrypt recipient must be the only one") } for _, i := range identities { - if i.Type() != r.Type { - continue - } - if i, ok := i.(IdentityMatcher); ok { err := i.Match((*Stanza)(r)) if err != nil { diff --git a/agessh/agessh.go b/agessh/agessh.go index 055cf43..dc457ab 100644 --- a/agessh/agessh.go +++ b/agessh/agessh.go @@ -48,8 +48,6 @@ type RSARecipient struct { var _ age.Recipient = &RSARecipient{} -func (*RSARecipient) Type() string { return "ssh-rsa" } - func NewRSARecipient(pk ssh.PublicKey) (*RSARecipient, error) { if pk.Type() != "ssh-rsa" { return nil, errors.New("SSH public key is not an RSA key") @@ -93,8 +91,6 @@ type RSAIdentity struct { var _ age.Identity = &RSAIdentity{} -func (*RSAIdentity) Type() string { return "ssh-rsa" } - func NewRSAIdentity(key *rsa.PrivateKey) (*RSAIdentity, error) { s, err := ssh.NewSignerFromKey(key) if err != nil { @@ -133,8 +129,6 @@ type Ed25519Recipient struct { var _ age.Recipient = &Ed25519Recipient{} -func (*Ed25519Recipient) Type() string { return "ssh-ed25519" } - func NewEd25519Recipient(pk ssh.PublicKey) (*Ed25519Recipient, error) { if pk.Type() != "ssh-ed25519" { return nil, errors.New("SSH public key is not an Ed25519 key") @@ -246,8 +240,6 @@ type Ed25519Identity struct { var _ age.Identity = &Ed25519Identity{} -func (*Ed25519Identity) Type() string { return "ssh-ed25519" } - func NewEd25519Identity(key ed25519.PrivateKey) (*Ed25519Identity, error) { s, err := ssh.NewSignerFromKey(key) if err != nil { diff --git a/agessh/agessh_test.go b/agessh/agessh_test.go index 97358c3..a0c25a9 100644 --- a/agessh/agessh_test.go +++ b/agessh/agessh_test.go @@ -37,10 +37,6 @@ func TestSSHRSARoundTrip(t *testing.T) { t.Fatal(err) } - if r.Type() != i.Type() || r.Type() != "ssh-rsa" { - t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type()) - } - fileKey := make([]byte, 16) if _, err := rand.Read(fileKey); err != nil { t.Fatal(err) @@ -82,10 +78,6 @@ func TestSSHEd25519RoundTrip(t *testing.T) { t.Fatal(err) } - if r.Type() != i.Type() || r.Type() != "ssh-ed25519" { - t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type()) - } - fileKey := make([]byte, 16) if _, err := rand.Read(fileKey); err != nil { t.Fatal(err) diff --git a/agessh/encrypted_keys.go b/agessh/encrypted_keys.go index 2e4f734..48454a4 100644 --- a/agessh/encrypted_keys.go +++ b/agessh/encrypted_keys.go @@ -56,11 +56,6 @@ func NewEncryptedSSHIdentity(pubKey ssh.PublicKey, pemBytes []byte, passphrase f var _ age.IdentityMatcher = &EncryptedSSHIdentity{} -// Type returns the type of the underlying private key, "ssh-ed25519" or "ssh-rsa". -func (i *EncryptedSSHIdentity) Type() string { - return i.pubKey.Type() -} - // Unwrap implements age.Identity. If the private key is still encrypted, it // will request the passphrase. The decrypted private key will be cached after // the first successful invocation. @@ -81,17 +76,20 @@ func (i *EncryptedSSHIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err er switch k := k.(type) { case *ed25519.PrivateKey: i.decrypted, err = NewEd25519Identity(*k) + if i.pubKey.Type() != ssh.KeyAlgoED25519 { + return nil, fmt.Errorf("mismatched SSH key type: got %q, expected %q", ssh.KeyAlgoED25519, i.pubKey.Type()) + } case *rsa.PrivateKey: i.decrypted, err = NewRSAIdentity(k) + if i.pubKey.Type() != ssh.KeyAlgoRSA { + return nil, fmt.Errorf("mismatched SSH key type: got %q, expected %q", ssh.KeyAlgoRSA, i.pubKey.Type()) + } default: return nil, fmt.Errorf("unexpected SSH key type: %T", k) } if err != nil { return nil, fmt.Errorf("invalid SSH key: %v", err) } - if i.decrypted.Type() != i.pubKey.Type() { - return nil, fmt.Errorf("mismatched SSH key type: got %q, expected %q", i.decrypted.Type(), i.pubKey.Type()) - } return i.decrypted.Unwrap(block) } @@ -99,11 +97,11 @@ func (i *EncryptedSSHIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err er // Match implements age.IdentityMatcher without decrypting the private key, to // ensure the passphrase is only obtained if necessary. func (i *EncryptedSSHIdentity) Match(block *age.Stanza) error { - if block.Type != i.Type() { + if block.Type != i.pubKey.Type() { return age.ErrIncorrectIdentity } if len(block.Args) < 1 { - return fmt.Errorf("invalid %v recipient block", i.Type()) + return fmt.Errorf("invalid %v recipient block", i.pubKey.Type()) } if block.Args[0] != sshFingerprint(i.pubKey) { diff --git a/cmd/age/encrypted_keys.go b/cmd/age/encrypted_keys.go index 6783170..0be899e 100644 --- a/cmd/age/encrypted_keys.go +++ b/cmd/age/encrypted_keys.go @@ -20,11 +20,10 @@ type LazyScryptIdentity struct { var _ age.Identity = &LazyScryptIdentity{} -func (i *LazyScryptIdentity) Type() string { - return "scrypt" -} - func (i *LazyScryptIdentity) Unwrap(block *age.Stanza) (fileKey []byte, err error) { + if block.Type != "scrypt" { + return nil, age.ErrIncorrectIdentity + } pass, err := i.Passphrase() if err != nil { return nil, fmt.Errorf("could not read passphrase: %v", err) diff --git a/recipients_test.go b/recipients_test.go index 1b9436b..462e2be 100644 --- a/recipients_test.go +++ b/recipients_test.go @@ -22,10 +22,6 @@ func TestX25519RoundTrip(t *testing.T) { } r := i.Recipient() - if r.Type() != i.Type() || r.Type() != "X25519" { - t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type()) - } - if r1, err := age.ParseX25519Recipient(r.String()); err != nil { t.Fatal(err) } else if r1.String() != r.String() { @@ -72,10 +68,6 @@ func TestScryptRoundTrip(t *testing.T) { t.Fatal(err) } - if r.Type() != i.Type() || r.Type() != "scrypt" { - t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type()) - } - fileKey := make([]byte, 16) if _, err := rand.Read(fileKey); err != nil { t.Fatal(err) diff --git a/scrypt.go b/scrypt.go index 3d71025..25bc42e 100644 --- a/scrypt.go +++ b/scrypt.go @@ -35,8 +35,6 @@ type ScryptRecipient struct { var _ Recipient = &ScryptRecipient{} -func (*ScryptRecipient) Type() string { return "scrypt" } - // NewScryptRecipient returns a new ScryptRecipient with the provided password. func NewScryptRecipient(password string) (*ScryptRecipient, error) { if len(password) == 0 { @@ -98,8 +96,6 @@ type ScryptIdentity struct { var _ Identity = &ScryptIdentity{} -func (*ScryptIdentity) Type() string { return "scrypt" } - // NewScryptIdentity returns a new ScryptIdentity with the provided password. func NewScryptIdentity(password string) (*ScryptIdentity, error) { if len(password) == 0 { diff --git a/x25519.go b/x25519.go index 4362694..1a363a9 100644 --- a/x25519.go +++ b/x25519.go @@ -34,8 +34,6 @@ type X25519Recipient struct { var _ Recipient = &X25519Recipient{} -func (*X25519Recipient) Type() string { return "X25519" } - // newX25519RecipientFromPoint returns a new X25519Recipient from a raw Curve25519 point. func newX25519RecipientFromPoint(publicKey []byte) (*X25519Recipient, error) { if len(publicKey) != curve25519.PointSize { @@ -117,8 +115,6 @@ type X25519Identity struct { var _ Identity = &X25519Identity{} -func (*X25519Identity) Type() string { return "X25519" } - // newX25519IdentityFromScalar returns a new X25519Identity from a raw Curve25519 scalar. func newX25519IdentityFromScalar(secretKey []byte) (*X25519Identity, error) { if len(secretKey) != curve25519.ScalarSize {