From ec92694aadbd74ef4aaca30e39c2727a45c7fd02 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Thu, 25 Dec 2025 19:50:48 +0100 Subject: [PATCH] age: add EncryptReader pull-based encryption API Fixes #644 Fixes #654 Updates #638 --- age.go | 67 ++++++++++---- age_test.go | 27 ++++++ internal/stream/stream.go | 160 +++++++++++++++++++++++++-------- internal/stream/stream_test.go | 41 +++++++-- testkit_test.go | 16 +++- 5 files changed, 245 insertions(+), 66 deletions(-) diff --git a/age.go b/age.go index d2a264d..f255a51 100644 --- a/age.go +++ b/age.go @@ -115,23 +115,11 @@ type Stanza struct { const fileKeySize = 16 const streamNonceSize = 16 -// Encrypt encrypts a file to one or more recipients. -// -// Writes to the returned WriteCloser are encrypted and written to dst as an age -// file. Every recipient will be able to decrypt the file. -// -// The caller must call Close on the WriteCloser when done for the last chunk to -// be encrypted and flushed to dst. -func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) { +func encryptHdr(fileKey []byte, recipients ...Recipient) (*format.Header, error) { if len(recipients) == 0 { return nil, errors.New("no recipients specified") } - fileKey := make([]byte, fileKeySize) - if _, err := rand.Read(fileKey); err != nil { - return nil, err - } - hdr := &format.Header{} var labels []string for i, r := range recipients { @@ -154,19 +142,62 @@ func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) { } else { hdr.MAC = mac } + return hdr, nil +} + +// Encrypt encrypts a file to one or more recipients. Every recipient will be +// able to decrypt the file. +// +// Writes to the returned WriteCloser are encrypted and written to dst as an age +// file. The caller must call Close on the WriteCloser when done for the last +// chunk to be encrypted and flushed to dst. +func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) { + fileKey := make([]byte, fileKeySize) + rand.Read(fileKey) + + hdr, err := encryptHdr(fileKey, recipients...) + if err != nil { + return nil, err + } if err := hdr.Marshal(dst); err != nil { return nil, fmt.Errorf("failed to write header: %w", err) } nonce := make([]byte, streamNonceSize) - if _, err := rand.Read(nonce); err != nil { - return nil, err - } + rand.Read(nonce) if _, err := dst.Write(nonce); err != nil { return nil, fmt.Errorf("failed to write nonce: %w", err) } - return stream.NewWriter(streamKey(fileKey, nonce), dst) + return stream.NewEncryptWriter(streamKey(fileKey, nonce), dst) +} + +// EncryptReader encrypts a file to one or more recipients. Every recipient will be +// able to decrypt the file. +// +// Reads from the returned Reader produce the encrypted file, where the plaintext +// is read from src. +func EncryptReader(src io.Reader, recipients ...Recipient) (io.Reader, error) { + fileKey := make([]byte, fileKeySize) + rand.Read(fileKey) + + hdr, err := encryptHdr(fileKey, recipients...) + if err != nil { + return nil, err + } + buf := &bytes.Buffer{} + if err := hdr.Marshal(buf); err != nil { + return nil, fmt.Errorf("failed to prepare header: %w", err) + } + + nonce := make([]byte, streamNonceSize) + rand.Read(nonce) + + r, err := stream.NewEncryptReader(streamKey(fileKey, nonce), src) + if err != nil { + return nil, err + } + return io.MultiReader(buf, bytes.NewReader(nonce), r), nil } func wrapWithLabels(r Recipient, fileKey []byte) (s []*Stanza, labels []string, err error) { @@ -244,7 +275,7 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) { return nil, fmt.Errorf("failed to read nonce: %w", err) } - return stream.NewReader(streamKey(fileKey, nonce), payload) + return stream.NewDecryptReader(streamKey(fileKey, nonce), payload) } func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) { diff --git a/age_test.go b/age_test.go index 48d6034..bbe9e26 100644 --- a/age_test.go +++ b/age_test.go @@ -486,3 +486,30 @@ func TestDetachedHeader(t *testing.T) { t.Errorf("wrong data: %q, expected %q", outBytes, helloWorld) } } + +func TestEncryptReader(t *testing.T) { + a, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + r, err := age.EncryptReader(strings.NewReader(helloWorld), a.Recipient()) + if err != nil { + t.Fatal(err) + } + buf := &bytes.Buffer{} + if _, err := io.Copy(buf, r); err != nil { + t.Fatal(err) + } + + out, err := age.Decrypt(buf, a) + if err != nil { + t.Fatal(err) + } + outBytes, err := io.ReadAll(out) + if err != nil { + t.Fatal(err) + } + if string(outBytes) != helloWorld { + t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld) + } +} diff --git a/internal/stream/stream.go b/internal/stream/stream.go index e781700..8dfae38 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -6,6 +6,7 @@ package stream import ( + "bytes" "crypto/cipher" "errors" "fmt" @@ -16,7 +17,7 @@ import ( const ChunkSize = 64 * 1024 -type Reader struct { +type DecryptReader struct { a cipher.AEAD src io.Reader @@ -32,18 +33,15 @@ const ( lastChunkFlag = 0x01 ) -func NewReader(key []byte, src io.Reader) (*Reader, error) { +func NewDecryptReader(key []byte, src io.Reader) (*DecryptReader, error) { aead, err := chacha20poly1305.New(key) if err != nil { return nil, err } - return &Reader{ - a: aead, - src: src, - }, nil + return &DecryptReader{a: aead, src: src}, nil } -func (r *Reader) Read(p []byte) (int, error) { +func (r *DecryptReader) Read(p []byte) (int, error) { if len(r.unread) > 0 { n := copy(p, r.unread) r.unread = r.unread[n:] @@ -85,7 +83,7 @@ func (r *Reader) Read(p []byte) (int, error) { // readChunk reads the next chunk of ciphertext from r.src and makes it available // in r.unread. last is true if the chunk was marked as the end of the message. // readChunk must not be called again after returning a last chunk or an error. -func (r *Reader) readChunk() (last bool, err error) { +func (r *DecryptReader) readChunk() (last bool, err error) { if len(r.unread) != 0 { panic("stream: internal error: readChunk called with dirty buffer") } @@ -130,12 +128,11 @@ func incNonce(nonce *[chacha20poly1305.NonceSize]byte) { for i := len(nonce) - 2; i >= 0; i-- { nonce[i]++ if nonce[i] != 0 { - break - } else if i == 0 { - // The counter is 88 bits, this is unreachable. - panic("stream: chunk counter wrapped around") + return } } + // The counter is 88 bits, this is unreachable. + panic("stream: chunk counter wrapped around") } func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) { @@ -146,30 +143,23 @@ func nonceIsZero(nonce *[chacha20poly1305.NonceSize]byte) bool { return *nonce == [chacha20poly1305.NonceSize]byte{} } -type Writer struct { - a cipher.AEAD - dst io.Writer - unwritten []byte // backed by buf - buf [encChunkSize]byte - nonce [chacha20poly1305.NonceSize]byte - err error +type EncryptWriter struct { + a cipher.AEAD + dst io.Writer + buf bytes.Buffer + nonce [chacha20poly1305.NonceSize]byte + err error } -func NewWriter(key []byte, dst io.Writer) (*Writer, error) { +func NewEncryptWriter(key []byte, dst io.Writer) (*EncryptWriter, error) { aead, err := chacha20poly1305.New(key) if err != nil { return nil, err } - w := &Writer{ - a: aead, - dst: dst, - } - w.unwritten = w.buf[:0] - return w, nil + return &EncryptWriter{a: aead, dst: dst}, nil } -func (w *Writer) Write(p []byte) (n int, err error) { - // TODO: consider refactoring with a bytes.Buffer. +func (w *EncryptWriter) Write(p []byte) (n int, err error) { if w.err != nil { return 0, w.err } @@ -179,12 +169,13 @@ func (w *Writer) Write(p []byte) (n int, err error) { total := len(p) for len(p) > 0 { - freeBuf := w.buf[len(w.unwritten):ChunkSize] - n := copy(freeBuf, p) + n := min(len(p), ChunkSize-w.buf.Len()) + w.buf.Write(p[:n]) p = p[n:] - w.unwritten = w.unwritten[:len(w.unwritten)+n] - if len(w.unwritten) == ChunkSize && len(p) > 0 { + // Only flush if there's a full chunk with bytes still to write, or we + // can't know if this is the last chunk yet. + if w.buf.Len() == ChunkSize && len(p) > 0 { if err := w.flushChunk(notLastChunk); err != nil { w.err = err return 0, err @@ -195,7 +186,7 @@ func (w *Writer) Write(p []byte) (n int, err error) { } // Close flushes the last chunk. It does not close the underlying Writer. -func (w *Writer) Close() error { +func (w *EncryptWriter) Close() error { if w.err != nil { return w.err } @@ -214,17 +205,110 @@ const ( notLastChunk = false ) -func (w *Writer) flushChunk(last bool) error { - if !last && len(w.unwritten) != ChunkSize { +func (w *EncryptWriter) flushChunk(last bool) error { + if !last && w.buf.Len() != ChunkSize { panic("stream: internal error: flush called with partial chunk") } if last { setLastChunkFlag(&w.nonce) } - buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil) - _, err := w.dst.Write(buf) - w.unwritten = w.buf[:0] + w.buf.Grow(chacha20poly1305.Overhead) + ciphertext := w.a.Seal(w.buf.Bytes()[:0], w.nonce[:], w.buf.Bytes(), nil) + _, err := w.dst.Write(ciphertext) incNonce(&w.nonce) + w.buf.Reset() return err } + +type EncryptReader struct { + a cipher.AEAD + src io.Reader + + // The first ready bytes of buf are already encrypted. This may be less than + // buf.Len(), because we need to over-read to know if a chunk is the last. + ready int + buf bytes.Buffer + + nonce [chacha20poly1305.NonceSize]byte + err error +} + +func NewEncryptReader(key []byte, src io.Reader) (*EncryptReader, error) { + aead, err := chacha20poly1305.New(key) + if err != nil { + return nil, err + } + return &EncryptReader{a: aead, src: src}, nil +} + +func (r *EncryptReader) Read(p []byte) (int, error) { + if r.ready > 0 { + n, err := r.buf.Read(p[:min(len(p), r.ready)]) + r.ready -= n + return n, err + } + if r.err != nil { + return 0, r.err + } + if len(p) == 0 { + return 0, nil + } + + if err := r.feedBuffer(); err != nil { + r.err = err + return 0, err + } + + n, err := r.buf.Read(p[:min(len(p), r.ready)]) + r.ready -= n + return n, err +} + +// feedBuffer reads and encrypts the next chunk from r.src and appends it to +// r.buf. It sets r.ready to the number of newly available bytes in r.buf. +func (r *EncryptReader) feedBuffer() error { + if r.ready > 0 { + panic("stream: internal error: feedBuffer called with dirty buffer") + } + + // CopyN will use r.buf.ReadFrom/WriteTo to fill the buffer directly. + // We need ChunkSize + 1 bytes to determine if this is the last chunk. + _, err := io.CopyN(&r.buf, r.src, int64(ChunkSize-r.buf.Len()+1)) + if err != nil && err != io.EOF { + return err + } + + if last := r.buf.Len() <= ChunkSize; last { + setLastChunkFlag(&r.nonce) + + // After Grow, we know r.buf.Bytes() has enough capacity for the + // overhead. We encrypt in place and then do a Write to include the + // overhead in the buffer. + r.buf.Grow(chacha20poly1305.Overhead) + plaintext := r.buf.Bytes() + r.a.Seal(plaintext[:0], r.nonce[:], plaintext, nil) + incNonce(&r.nonce) + r.buf.Write(plaintext[len(plaintext) : len(plaintext)+chacha20poly1305.Overhead]) + r.ready = r.buf.Len() + + r.err = io.EOF + return nil + } + + // Same, but accounting for the tail byte which will remain unencrypted and + // needs to be shifted past the overhead. + if r.buf.Len() != ChunkSize+1 { + panic("stream: internal error: unexpected buffer length") + } + tailByte := r.buf.Bytes()[ChunkSize] + r.buf.Grow(chacha20poly1305.Overhead) + plaintext := r.buf.Bytes()[:ChunkSize] + r.a.Seal(plaintext[:0], r.nonce[:], plaintext, nil) + incNonce(&r.nonce) + r.buf.Write(plaintext[len(plaintext)+1 : len(plaintext)+chacha20poly1305.Overhead]) + r.buf.WriteByte(tailByte) + r.ready = ChunkSize + chacha20poly1305.Overhead + + return nil +} diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go index 8cac967..e2f6961 100644 --- a/internal/stream/stream_test.go +++ b/internal/stream/stream_test.go @@ -8,6 +8,7 @@ import ( "bytes" "crypto/rand" "fmt" + "io" "testing" "filippo.io/age/internal/stream" @@ -17,12 +18,15 @@ import ( const cs = stream.ChunkSize func TestRoundTrip(t *testing.T) { - for _, stepSize := range []int{512, 600, 1000, cs} { - for _, length := range []int{0, 1000, cs, cs + 100} { + for _, length := range []int{0, 1000, cs - 1, cs, cs + 1, cs + 100, 2 * cs, 2*cs + 500} { + for _, stepSize := range []int{512, 600, 1000, cs - 1, cs, cs + 1} { t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), func(t *testing.T) { testRoundTrip(t, stepSize, length) }) } } + length, stepSize := 2*cs+500, 1 + t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), + func(t *testing.T) { testRoundTrip(t, stepSize, length) }) } func testRoundTrip(t *testing.T, stepSize, length int) { @@ -36,17 +40,14 @@ func testRoundTrip(t *testing.T, stepSize, length int) { t.Fatal(err) } - w, err := stream.NewWriter(key, buf) + w, err := stream.NewEncryptWriter(key, buf) if err != nil { t.Fatal(err) } var n int for n < length { - b := length - n - if b > stepSize { - b = stepSize - } + b := min(length-n, stepSize) nn, err := w.Write(src[n : n+b]) if err != nil { t.Fatal(err) @@ -70,8 +71,9 @@ func testRoundTrip(t *testing.T, stepSize, length int) { } t.Logf("buffer size: %d", buf.Len()) + ciphertext := bytes.Clone(buf.Bytes()) - r, err := stream.NewReader(key, buf) + r, err := stream.NewDecryptReader(key, buf) if err != nil { t.Fatal(err) } @@ -90,4 +92,27 @@ func testRoundTrip(t *testing.T, stepSize, length int) { n += nn } + + er, err := stream.NewEncryptReader(key, bytes.NewReader(src)) + if err != nil { + t.Fatal(err) + } + n = 0 + for { + nn, err := er.Read(readBuf) + if nn == 0 && err == io.EOF { + break + } else if err != nil { + t.Fatalf("EncryptReader Read error at index %d: %v", n, err) + } + + if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) { + t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn) + } + + n += nn + } + if n != len(ciphertext) { + t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext)) + } } diff --git a/testkit_test.go b/testkit_test.go index 03cf497..5e4f271 100644 --- a/testkit_test.go +++ b/testkit_test.go @@ -288,7 +288,7 @@ func testVectorRoundTrip(t *testing.T, v *vector) { t.Run("STREAM", func(t *testing.T) { nonce, payload := payload[:16], payload[16:] key := streamKey(v.fileKey[:], nonce) - r, err := stream.NewReader(key, bytes.NewReader(payload)) + r, err := stream.NewDecryptReader(key, bytes.NewReader(payload)) if err != nil { t.Fatal(err) } @@ -297,7 +297,7 @@ func testVectorRoundTrip(t *testing.T, v *vector) { t.Fatal(err) } buf := &bytes.Buffer{} - w, err := stream.NewWriter(key, buf) + w, err := stream.NewEncryptWriter(key, buf) if err != nil { t.Fatal(err) } @@ -310,6 +310,18 @@ func testVectorRoundTrip(t *testing.T, v *vector) { if !bytes.Equal(buf.Bytes(), payload) { t.Error("got a different STREAM ciphertext") } + buf.Reset() + er, err := stream.NewEncryptReader(key, bytes.NewReader(plaintext)) + if err != nil { + t.Fatal(err) + } + ciphertext, err := io.ReadAll(er) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(ciphertext, payload) { + t.Error("got a different STREAM ciphertext from EncryptReader") + } }) } }