diff --git a/cmd/age/parse.go b/cmd/age/parse.go index 4a59e7a..2707484 100644 --- a/cmd/age/parse.go +++ b/cmd/age/parse.go @@ -16,6 +16,7 @@ import ( "filippo.io/age/agessh" "filippo.io/age/armor" "filippo.io/age/plugin" + "filippo.io/age/tag" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/ssh" ) @@ -30,6 +31,8 @@ func (gitHubRecipientError) Error() string { func parseRecipient(arg string) (age.Recipient, error) { switch { + case strings.HasPrefix(arg, "age1tag1"): + return tag.ParseRecipient(arg) case strings.HasPrefix(arg, "age1") && strings.Count(arg, "1") > 1: return plugin.NewRecipient(arg, pluginTerminalUI) case strings.HasPrefix(arg, "age1"): diff --git a/go.mod b/go.mod index 31df01f..7c01486 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.19 require ( filippo.io/edwards25519 v1.1.0 + filippo.io/nistec v0.0.3 golang.org/x/crypto v0.24.0 golang.org/x/sys v0.21.0 golang.org/x/term v0.21.0 diff --git a/go.sum b/go.sum index fd0f776..b3b4202 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ c2sp.org/CCTV/age v0.0.0-20240306222714-3ec4d716e805 h1:u2qwJeEvnypw+OCPUHmoZE3I c2sp.org/CCTV/age v0.0.0-20240306222714-3ec4d716e805/go.mod h1:FomMrUJ2Lxt5jCLmZkG3FHa72zUprnhd3v/Z18Snm4w= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/nistec v0.0.3 h1:h336Je2jRDZdBCLy2fLDUd9E2unG32JLwcJi0JQE9Cw= +filippo.io/nistec v0.0.3/go.mod h1:84fxC9mi+MhC2AERXI4LSa8cmSVOzrFikg6hZ4IfCyw= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= diff --git a/tag/internal/hpke/hpke.go b/tag/internal/hpke/hpke.go new file mode 100644 index 0000000..6d2bed7 --- /dev/null +++ b/tag/internal/hpke/hpke.go @@ -0,0 +1,346 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package hpke + +import ( + "crypto/cipher" + "crypto/ecdh" + "crypto/hkdf" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "errors" + "hash" + "math/bits" + + "golang.org/x/crypto/chacha20poly1305" +) + +type KEMSender interface { + Encap() (sharedSecret, enc []byte, err error) + ID() uint16 +} + +type KEMRecipient interface { + Decap(enc []byte) (sharedSecret []byte, err error) + ID() uint16 +} + +type dhKEM struct { + kdf KDF + id uint16 + nSecret uint16 +} + +func (dh *dhKEM) extractAndExpand(dhKey, kemContext []byte) ([]byte, error) { + suiteID := binary.BigEndian.AppendUint16([]byte("KEM"), dh.id) + eaePRK, err := dh.kdf.LabeledExtract(suiteID, nil, "eae_prk", dhKey) + if err != nil { + return nil, err + } + return dh.kdf.LabeledExpand(suiteID, eaePRK, "shared_secret", kemContext, dh.nSecret) +} + +func (dh *dhKEM) ID() uint16 { + return dh.id +} + +type dhkemSender struct { + dhKEM + pub *ecdh.PublicKey +} + +// DHKEMSender returns a KEMSender implementing DHKEM(P-256, HKDF-SHA256). +func DHKEMSender(pub *ecdh.PublicKey) (KEMSender, error) { + switch pub.Curve() { + case ecdh.P256(): + return &dhkemSender{ + pub: pub, + dhKEM: dhKEM{ + kdf: HKDFSHA256(), + id: 0x0010, + nSecret: 32, + }, + }, nil + default: + return nil, errors.New("unsupported curve") + } +} + +// testingOnlyGenerateKey is only used during testing, to provide +// a fixed test key to use when checking the RFC 9180 vectors. +var testingOnlyGenerateKey func() *ecdh.PrivateKey + +func (dh *dhkemSender) Encap() (sharedSecret []byte, encapPub []byte, err error) { + privEph, err := dh.pub.Curve().GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + if testingOnlyGenerateKey != nil { + privEph = testingOnlyGenerateKey() + } + dhVal, err := privEph.ECDH(dh.pub) + if err != nil { + return nil, nil, err + } + encPubEph := privEph.PublicKey().Bytes() + + encPubRecip := dh.pub.Bytes() + kemContext := append(encPubEph, encPubRecip...) + sharedSecret, err = dh.extractAndExpand(dhVal, kemContext) + if err != nil { + return nil, nil, err + } + return sharedSecret, encPubEph, nil +} + +type dhkemRecipient struct { + dhKEM + priv *ecdh.PrivateKey +} + +// DHKEMRecipient returns a KEMRecipient implementing DHKEM(P-256, HKDF-SHA256). +func DHKEMRecipient(priv *ecdh.PrivateKey) (KEMRecipient, error) { + switch priv.Curve() { + case ecdh.P256(): + return &dhkemRecipient{ + priv: priv, + dhKEM: dhKEM{ + kdf: HKDFSHA256(), + id: 0x0010, + nSecret: 32, + }, + }, nil + default: + return nil, errors.New("unsupported curve") + } +} + +func (dh *dhkemRecipient) Decap(encPubEph []byte) ([]byte, error) { + pubEph, err := dh.priv.Curve().NewPublicKey(encPubEph) + if err != nil { + return nil, err + } + dhVal, err := dh.priv.ECDH(pubEph) + if err != nil { + return nil, err + } + kemContext := append(encPubEph, dh.priv.PublicKey().Bytes()...) + return dh.extractAndExpand(dhVal, kemContext) +} + +type KDF interface { + LabeledExtract(sid, salt []byte, label string, inputKey []byte) ([]byte, error) + LabeledExpand(suiteID, randomKey []byte, label string, info []byte, length uint16) ([]byte, error) + ID() uint16 +} + +type hkdfKDF struct { + hash func() hash.Hash + id uint16 +} + +func HKDFSHA256() KDF { + return &hkdfKDF{hash: sha256.New, id: 0x0001} +} + +func (kdf *hkdfKDF) ID() uint16 { + return kdf.id +} + +func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) ([]byte, error) { + labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey)) + labeledIKM = append(labeledIKM, []byte("HPKE-v1")...) + labeledIKM = append(labeledIKM, sid...) + labeledIKM = append(labeledIKM, label...) + labeledIKM = append(labeledIKM, inputKey...) + return hkdf.Extract(kdf.hash, labeledIKM, salt) +} + +func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) ([]byte, error) { + labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info)) + labeledInfo = binary.BigEndian.AppendUint16(labeledInfo, length) + labeledInfo = append(labeledInfo, []byte("HPKE-v1")...) + labeledInfo = append(labeledInfo, suiteID...) + labeledInfo = append(labeledInfo, label...) + labeledInfo = append(labeledInfo, info...) + return hkdf.Expand(kdf.hash, randomKey, string(labeledInfo), int(length)) +} + +type AEAD interface { + AEAD(key []byte) (cipher.AEAD, error) + KeySize() int + NonceSize() int + ID() uint16 +} + +type aead struct { + keySize int + nonceSize int + aead func([]byte) (cipher.AEAD, error) + id uint16 +} + +func ChaCha20Poly1305() AEAD { + return &aead{ + keySize: chacha20poly1305.KeySize, + nonceSize: chacha20poly1305.NonceSize, + aead: chacha20poly1305.New, + id: 0x0003, + } +} + +func (a *aead) ID() uint16 { + return a.id +} + +func (a *aead) AEAD(key []byte) (cipher.AEAD, error) { + if len(key) != a.keySize { + return nil, errors.New("invalid key size") + } + return a.aead(key) +} + +func (a *aead) KeySize() int { + return a.keySize +} + +func (a *aead) NonceSize() int { + return a.nonceSize +} + +type context struct { + aead cipher.AEAD + suiteID []byte + + key []byte + baseNonce []byte + + seqNum uint128 +} + +type Sender struct { + *context +} + +type Recipient struct { + *context +} + +func newContext(sharedSecret []byte, kemID uint16, kdf KDF, aead AEAD, info []byte) (*context, error) { + sid := suiteID(kemID, kdf.ID(), aead.ID()) + + pskIDHash, err := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil) + if err != nil { + return nil, err + } + infoHash, err := kdf.LabeledExtract(sid, nil, "info_hash", info) + if err != nil { + return nil, err + } + ksContext := append([]byte{0}, pskIDHash...) + ksContext = append(ksContext, infoHash...) + + secret, err := kdf.LabeledExtract(sid, sharedSecret, "secret", nil) + if err != nil { + return nil, err + } + key, err := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aead.KeySize())) + if err != nil { + return nil, err + } + baseNonce, err := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aead.NonceSize())) + if err != nil { + return nil, err + } + + a, err := aead.AEAD(key) + if err != nil { + return nil, err + } + + return &context{ + aead: a, + suiteID: sid, + key: key, + baseNonce: baseNonce, + }, nil +} + +func SetupSender(kem KEMSender, kdf KDF, aead AEAD, info []byte) ([]byte, *Sender, error) { + sharedSecret, encapsulatedKey, err := kem.Encap() + if err != nil { + return nil, nil, err + } + context, err := newContext(sharedSecret, kem.ID(), kdf, aead, info) + if err != nil { + return nil, nil, err + } + return encapsulatedKey, &Sender{context}, nil +} + +func SetupRecipient(kem KEMRecipient, kdf KDF, aead AEAD, info, enc []byte) (*Recipient, error) { + sharedSecret, err := kem.Decap(enc) + if err != nil { + return nil, err + } + context, err := newContext(sharedSecret, kem.ID(), kdf, aead, info) + if err != nil { + return nil, err + } + return &Recipient{context}, nil +} + +func (ctx *context) nextNonce() []byte { + nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():] + for i := range ctx.baseNonce { + nonce[i] ^= ctx.baseNonce[i] + } + return nonce +} + +func (ctx *context) incrementNonce() { + ctx.seqNum = ctx.seqNum.addOne() +} + +func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) { + ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad) + s.incrementNonce() + return ciphertext, nil +} + +func (r *Recipient) Open(aad, ciphertext []byte) ([]byte, error) { + plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad) + if err != nil { + return nil, err + } + r.incrementNonce() + return plaintext, nil +} + +func suiteID(kemID, kdfID, aeadID uint16) []byte { + suiteID := make([]byte, 0, 4+2+2+2) + suiteID = append(suiteID, []byte("HPKE")...) + suiteID = binary.BigEndian.AppendUint16(suiteID, kemID) + suiteID = binary.BigEndian.AppendUint16(suiteID, kdfID) + suiteID = binary.BigEndian.AppendUint16(suiteID, aeadID) + return suiteID +} + +type uint128 struct { + hi, lo uint64 +} + +func (u uint128) addOne() uint128 { + lo, carry := bits.Add64(u.lo, 1, 0) + return uint128{u.hi + carry, lo} +} + +func (u uint128) bytes() []byte { + b := make([]byte, 16) + binary.BigEndian.PutUint64(b[0:], u.hi) + binary.BigEndian.PutUint64(b[8:], u.lo) + return b +} diff --git a/tag/internal/hpke/hpke_test.go b/tag/internal/hpke/hpke_test.go new file mode 100644 index 0000000..d15a657 --- /dev/null +++ b/tag/internal/hpke/hpke_test.go @@ -0,0 +1,218 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package hpke + +import ( + "bytes" + "crypto/ecdh" + "encoding/hex" + "encoding/json" + "errors" + "os" + "strconv" + "strings" + "testing" +) + +func mustDecodeHex(t *testing.T, in string) []byte { + t.Helper() + b, err := hex.DecodeString(in) + if err != nil { + t.Fatal(err) + } + return b +} + +func parseVectorSetup(vector string) map[string]string { + vals := map[string]string{} + for _, l := range strings.Split(vector, "\n") { + fields := strings.Split(l, ": ") + vals[fields[0]] = fields[1] + } + return vals +} + +func parseVectorEncryptions(vector string) []map[string]string { + vals := []map[string]string{} + for _, section := range strings.Split(vector, "\n\n") { + e := map[string]string{} + for _, l := range strings.Split(section, "\n") { + fields := strings.Split(l, ": ") + e[fields[0]] = fields[1] + } + vals = append(vals, e) + } + return vals +} + +func TestRFC9180Vectors(t *testing.T) { + vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json") + if err != nil { + t.Fatal(err) + } + + var vectors []struct { + Name string + Setup string + Encryptions string + } + if err := json.Unmarshal(vectorsJSON, &vectors); err != nil { + t.Fatal(err) + } + + for _, vector := range vectors { + t.Run(vector.Name, func(t *testing.T) { + setup := parseVectorSetup(vector.Setup) + + kemID, err := strconv.Atoi(setup["kem_id"]) + if err != nil { + t.Fatal(err) + } + kdfID, err := strconv.Atoi(setup["kdf_id"]) + if err != nil { + t.Fatal(err) + } + aeadID, err := strconv.Atoi(setup["aead_id"]) + if err != nil { + t.Fatal(err) + } + info := mustDecodeHex(t, setup["info"]) + pubKeyBytes := mustDecodeHex(t, setup["pkRm"]) + pub, err := parsePublicKey(uint16(kemID), pubKeyBytes) + if err != nil { + t.Fatal(err) + } + + ephemeralPrivKey := mustDecodeHex(t, setup["skEm"]) + + testingOnlyGenerateKey = func() *ecdh.PrivateKey { + priv, err := parsePrivateKey(uint16(kemID), ephemeralPrivKey) + if err != nil { + t.Fatal(err) + } + return priv + } + t.Cleanup(func() { testingOnlyGenerateKey = nil }) + + kemSender, err := DHKEMSender(pub) + if err != nil { + t.Fatal(err) + } + kdf, err := getKDF(uint16(kdfID)) + if err != nil { + t.Fatal(err) + } + aead, err := getAEAD(uint16(aeadID)) + if err != nil { + t.Fatal(err) + } + encap, sender, err := SetupSender(kemSender, kdf, aead, info) + if err != nil { + t.Fatal(err) + } + + expectedEncap := mustDecodeHex(t, setup["enc"]) + if !bytes.Equal(encap, expectedEncap) { + t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap) + } + + privKeyBytes := mustDecodeHex(t, setup["skRm"]) + priv, err := parsePrivateKey(uint16(kemID), privKeyBytes) + if err != nil { + t.Fatal(err) + } + + kemRecipient, err := DHKEMRecipient(priv) + if err != nil { + t.Fatal(err) + } + recipient, err := SetupRecipient(kemRecipient, kdf, aead, info, encap) + if err != nil { + t.Fatal(err) + } + + for _, ctx := range []*context{sender.context, recipient.context} { + expectedKey := mustDecodeHex(t, setup["key"]) + if !bytes.Equal(ctx.key, expectedKey) { + t.Errorf("unexpected key, got: %x, want %x", ctx.key, expectedKey) + } + expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"]) + if !bytes.Equal(ctx.baseNonce, expectedBaseNonce) { + t.Errorf("unexpected base nonce, got: %x, want %x", ctx.baseNonce, expectedBaseNonce) + } + } + + for _, enc := range parseVectorEncryptions(vector.Encryptions) { + t.Run("seq num "+enc["sequence number"], func(t *testing.T) { + seqNum, err := strconv.Atoi(enc["sequence number"]) + if err != nil { + t.Fatal(err) + } + sender.seqNum = uint128{lo: uint64(seqNum)} + recipient.seqNum = uint128{lo: uint64(seqNum)} + expectedNonce := mustDecodeHex(t, enc["nonce"]) + computedNonce := sender.nextNonce() + if !bytes.Equal(computedNonce, expectedNonce) { + t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce) + } + + expectedCiphertext := mustDecodeHex(t, enc["ct"]) + ciphertext, err := sender.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"])) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(ciphertext, expectedCiphertext) { + t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext) + } + + expectedPlaintext := mustDecodeHex(t, enc["pt"]) + plaintext, err := recipient.Open(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["ct"])) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, expectedPlaintext) { + t.Errorf("unexpected plaintext: got %x want %x", plaintext, expectedPlaintext) + } + }) + } + }) + } +} + +func parsePublicKey(kemID uint16, keyBytes []byte) (*ecdh.PublicKey, error) { + switch kemID { + case 0x0010: // DHKEM(P-256, HKDF-SHA256) + return ecdh.P256().NewPublicKey(keyBytes) + default: + return nil, errors.New("unsupported KEM") + } +} + +func parsePrivateKey(kemID uint16, keyBytes []byte) (*ecdh.PrivateKey, error) { + switch kemID { + case 0x0010: // DHKEM(P-256, HKDF-SHA256) + return ecdh.P256().NewPrivateKey(keyBytes) + default: + return nil, errors.New("unsupported KEM") + } +} + +func getKDF(kdfID uint16) (KDF, error) { + switch kdfID { + case 0x0001: // HKDF-SHA256 + return HKDFSHA256(), nil + default: + return nil, errors.New("unsupported KDF") + } +} + +func getAEAD(aeadID uint16) (AEAD, error) { + switch aeadID { + case 0x0003: // ChaCha20Poly1305 + return ChaCha20Poly1305(), nil + default: + return nil, errors.New("unsupported AEAD") + } +} diff --git a/tag/internal/hpke/testdata/rfc9180-vectors.json b/tag/internal/hpke/testdata/rfc9180-vectors.json new file mode 100644 index 0000000..44dc418 --- /dev/null +++ b/tag/internal/hpke/testdata/rfc9180-vectors.json @@ -0,0 +1,7 @@ +[ + { + "Name": "DHKEM(P-256, HKDF-SHA256), HKDF-SHA256, ChaCha20Poly1305", + "Setup": "mode: 0\nkem_id: 16\nkdf_id: 1\naead_id: 3\ninfo: 4f6465206f6e2061204772656369616e2055726e\nikmE: f1f1a3bc95416871539ecb51c3a8f0cf608afb40fbbe305c0a72819d35c33f1f\npkEm: 04c07836a0206e04e31d8ae99bfd549380b072a1b1b82e563c935c095827824fc1559eac6fb9e3c70cd3193968994e7fe9781aa103f5b50e934b5b2f387e381291\nskEm: 7550253e1147aae48839c1f8af80d2770fb7a4c763afe7d0afa7e0f42a5b3689\nikmR: 61092f3f56994dd424405899154a9918353e3e008171517ad576b900ddb275e7\npkRm: 04a697bffde9405c992883c5c439d6cc358170b51af72812333b015621dc0f40bad9bb726f68a5c013806a790ec716ab8669f84f6b694596c2987cf35baba2a006\nskRm: a4d1c55836aa30f9b3fbb6ac98d338c877c2867dd3a77396d13f68d3ab150d3b\nenc: 04c07836a0206e04e31d8ae99bfd549380b072a1b1b82e563c935c095827824fc1559eac6fb9e3c70cd3193968994e7fe9781aa103f5b50e934b5b2f387e381291\nshared_secret: 806520f82ef0b03c823b7fc524b6b55a088f566b9751b89551c170f4113bd850\nkey_schedule_context: 00b738cd703db7b4106e93b4621e9a19c89c838e55964240e5d3f331aaf8b0d58b2e986ea1c671b61cf45eec134dac0bae58ec6f63e790b1400b47c33038b0269c\nsecret: fe891101629aa355aad68eff3cc5170d057eca0c7573f6575e91f9783e1d4506\nkey: a8f45490a92a3b04d1dbf6cf2c3939ad8bfc9bfcb97c04bffe116730c9dfe3fc\nbase_nonce: 726b4390ed2209809f58c693\nexporter_secret: 4f9bd9b3a8db7d7c3a5b9d44fdc1f6e37d5d77689ade5ec44a7242016e6aa205", + "Encryptions": "sequence number: 0\npt: 4265617574792069732074727574682c20747275746820626561757479\naad: 436f756e742d30\nnonce: 726b4390ed2209809f58c693\nct: 6469c41c5c81d3aa85432531ecf6460ec945bde1eb428cb2fedf7a29f5a685b4ccb0d057f03ea2952a27bb458b\n\nsequence number: 1\npt: 4265617574792069732074727574682c20747275746820626561757479\naad: 436f756e742d31\nnonce: 726b4390ed2209809f58c692\nct: f1564199f7e0e110ec9c1bcdde332177fc35c1adf6e57f8d1df24022227ffa8716862dbda2b1dc546c9d114374\n\nsequence number: 2\npt: 4265617574792069732074727574682c20747275746820626561757479\naad: 436f756e742d32\nnonce: 726b4390ed2209809f58c691\nct: 39de89728bcb774269f882af8dc5369e4f3d6322d986e872b3a8d074c7c18e8549ff3f85b6d6592ff87c3f310c\n\nsequence number: 4\npt: 4265617574792069732074727574682c20747275746820626561757479\naad: 436f756e742d34\nnonce: 726b4390ed2209809f58c697\nct: bc104a14fbede0cc79eeb826ea0476ce87b9c928c36e5e34dc9b6905d91473ec369a08b1a25d305dd45c6c5f80\n\nsequence number: 255\npt: 4265617574792069732074727574682c20747275746820626561757479\naad: 436f756e742d323535\nnonce: 726b4390ed2209809f58c66c\nct: 8f2814a2c548b3be50259713c6724009e092d37789f6856553d61df23ebc079235f710e6af3c3ca6eaba7c7c6c\n\nsequence number: 256\npt: 4265617574792069732074727574682c20747275746820626561757479\naad: 436f756e742d323536\nnonce: 726b4390ed2209809f58c793\nct: b45b69d419a9be7219d8c94365b89ad6951caf4576ea4774ea40e9b7047a09d6537d1aa2f7c12d6ae4b729b4d0" + } +] diff --git a/tag/tag.go b/tag/tag.go new file mode 100644 index 0000000..78326e4 --- /dev/null +++ b/tag/tag.go @@ -0,0 +1,103 @@ +// Copyright 2025 The age Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tag + +import ( + "crypto/ecdh" + "crypto/hkdf" + "crypto/sha256" + "fmt" + + "filippo.io/age" + "filippo.io/age/internal/format" + "filippo.io/age/plugin" + "filippo.io/age/tag/internal/hpke" + "filippo.io/nistec" +) + +type Recipient struct { + kem hpke.KEMSender + + compressed [33]byte + uncompressed [65]byte +} + +var _ age.Recipient = &Recipient{} + +// ParseRecipient returns a new [Recipient] from a Bech32 public key +// encoding with the "age1tag1" prefix. +func ParseRecipient(s string) (*Recipient, error) { + t, k, err := plugin.ParseRecipient(s) + if err != nil { + return nil, fmt.Errorf("malformed recipient %q: %v", s, err) + } + if t != "tag" { + return nil, fmt.Errorf("malformed recipient %q: invalid type %q", s, t) + } + r, err := NewRecipient(k) + if err != nil { + return nil, fmt.Errorf("malformed recipient %q: %v", s, err) + } + return r, nil +} + +// NewRecipient returns a new [Recipient] from a raw public key. +func NewRecipient(publicKey []byte) (*Recipient, error) { + if len(publicKey) != 1+32 { + return nil, fmt.Errorf("invalid tag recipient public key size %d", len(publicKey)) + } + p, err := nistec.NewP256Point().SetBytes(publicKey) + if err != nil { + return nil, fmt.Errorf("invalid tag recipient public key: %v", err) + } + k, err := ecdh.P256().NewPublicKey(p.Bytes()) + if err != nil { + return nil, fmt.Errorf("invalid tag recipient public key: %v", err) + } + kem, err := hpke.DHKEMSender(k) + if err != nil { + return nil, fmt.Errorf("failed to create DHKEM sender: %v", err) + } + r := &Recipient{kem: kem} + copy(r.compressed[:], publicKey) + copy(r.uncompressed[:], p.Bytes()) + return r, nil +} + +var p256TagLabel = []byte("age-encryption.org/p256tag") + +func (r *Recipient) Wrap(fileKey []byte) ([]*age.Stanza, error) { + enc, s, err := hpke.SetupSender(r.kem, + hpke.HKDFSHA256(), hpke.ChaCha20Poly1305(), + p256TagLabel) + if err != nil { + return nil, fmt.Errorf("failed to set up HPKE sender: %v", err) + } + ct, err := s.Seal(nil, fileKey) + if err != nil { + return nil, fmt.Errorf("failed to encrypt file key: %v", err) + } + + tag, err := hkdf.Extract(sha256.New, append(enc, r.uncompressed[:]...), p256TagLabel) + if err != nil { + return nil, fmt.Errorf("failed to compute tag: %v", err) + } + + l := &age.Stanza{ + Type: "p256tag", + Args: []string{ + format.EncodeToString(tag[:4]), + format.EncodeToString(enc), + }, + Body: ct, + } + + return []*age.Stanza{l}, nil +} + +// String returns the Bech32 public key encoding of r. +func (r *Recipient) String() string { + return plugin.EncodeRecipient("tag", r.compressed[:]) +}