age: add EncryptReader pull-based encryption API

Fixes #644
Fixes #654
Updates #638
This commit is contained in:
Filippo Valsorda
2025-12-25 19:50:48 +01:00
parent 92ac13f51c
commit ec92694aad
5 changed files with 245 additions and 66 deletions

67
age.go
View File

@@ -115,23 +115,11 @@ type Stanza struct {
const fileKeySize = 16
const streamNonceSize = 16
// Encrypt encrypts a file to one or more recipients.
//
// Writes to the returned WriteCloser are encrypted and written to dst as an age
// file. Every recipient will be able to decrypt the file.
//
// The caller must call Close on the WriteCloser when done for the last chunk to
// be encrypted and flushed to dst.
func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {
func encryptHdr(fileKey []byte, recipients ...Recipient) (*format.Header, error) {
if len(recipients) == 0 {
return nil, errors.New("no recipients specified")
}
fileKey := make([]byte, fileKeySize)
if _, err := rand.Read(fileKey); err != nil {
return nil, err
}
hdr := &format.Header{}
var labels []string
for i, r := range recipients {
@@ -154,19 +142,62 @@ func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {
} else {
hdr.MAC = mac
}
return hdr, nil
}
// Encrypt encrypts a file to one or more recipients. Every recipient will be
// able to decrypt the file.
//
// Writes to the returned WriteCloser are encrypted and written to dst as an age
// file. The caller must call Close on the WriteCloser when done for the last
// chunk to be encrypted and flushed to dst.
func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {
fileKey := make([]byte, fileKeySize)
rand.Read(fileKey)
hdr, err := encryptHdr(fileKey, recipients...)
if err != nil {
return nil, err
}
if err := hdr.Marshal(dst); err != nil {
return nil, fmt.Errorf("failed to write header: %w", err)
}
nonce := make([]byte, streamNonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
rand.Read(nonce)
if _, err := dst.Write(nonce); err != nil {
return nil, fmt.Errorf("failed to write nonce: %w", err)
}
return stream.NewWriter(streamKey(fileKey, nonce), dst)
return stream.NewEncryptWriter(streamKey(fileKey, nonce), dst)
}
// EncryptReader encrypts a file to one or more recipients. Every recipient will be
// able to decrypt the file.
//
// Reads from the returned Reader produce the encrypted file, where the plaintext
// is read from src.
func EncryptReader(src io.Reader, recipients ...Recipient) (io.Reader, error) {
fileKey := make([]byte, fileKeySize)
rand.Read(fileKey)
hdr, err := encryptHdr(fileKey, recipients...)
if err != nil {
return nil, err
}
buf := &bytes.Buffer{}
if err := hdr.Marshal(buf); err != nil {
return nil, fmt.Errorf("failed to prepare header: %w", err)
}
nonce := make([]byte, streamNonceSize)
rand.Read(nonce)
r, err := stream.NewEncryptReader(streamKey(fileKey, nonce), src)
if err != nil {
return nil, err
}
return io.MultiReader(buf, bytes.NewReader(nonce), r), nil
}
func wrapWithLabels(r Recipient, fileKey []byte) (s []*Stanza, labels []string, err error) {
@@ -244,7 +275,7 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
return nil, fmt.Errorf("failed to read nonce: %w", err)
}
return stream.NewReader(streamKey(fileKey, nonce), payload)
return stream.NewDecryptReader(streamKey(fileKey, nonce), payload)
}
func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) {

View File

@@ -486,3 +486,30 @@ func TestDetachedHeader(t *testing.T) {
t.Errorf("wrong data: %q, expected %q", outBytes, helloWorld)
}
}
func TestEncryptReader(t *testing.T) {
a, err := age.GenerateX25519Identity()
if err != nil {
t.Fatal(err)
}
r, err := age.EncryptReader(strings.NewReader(helloWorld), a.Recipient())
if err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
if _, err := io.Copy(buf, r); err != nil {
t.Fatal(err)
}
out, err := age.Decrypt(buf, a)
if err != nil {
t.Fatal(err)
}
outBytes, err := io.ReadAll(out)
if err != nil {
t.Fatal(err)
}
if string(outBytes) != helloWorld {
t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld)
}
}

View File

@@ -6,6 +6,7 @@
package stream
import (
"bytes"
"crypto/cipher"
"errors"
"fmt"
@@ -16,7 +17,7 @@ import (
const ChunkSize = 64 * 1024
type Reader struct {
type DecryptReader struct {
a cipher.AEAD
src io.Reader
@@ -32,18 +33,15 @@ const (
lastChunkFlag = 0x01
)
func NewReader(key []byte, src io.Reader) (*Reader, error) {
func NewDecryptReader(key []byte, src io.Reader) (*DecryptReader, error) {
aead, err := chacha20poly1305.New(key)
if err != nil {
return nil, err
}
return &Reader{
a: aead,
src: src,
}, nil
return &DecryptReader{a: aead, src: src}, nil
}
func (r *Reader) Read(p []byte) (int, error) {
func (r *DecryptReader) Read(p []byte) (int, error) {
if len(r.unread) > 0 {
n := copy(p, r.unread)
r.unread = r.unread[n:]
@@ -85,7 +83,7 @@ func (r *Reader) Read(p []byte) (int, error) {
// readChunk reads the next chunk of ciphertext from r.src and makes it available
// in r.unread. last is true if the chunk was marked as the end of the message.
// readChunk must not be called again after returning a last chunk or an error.
func (r *Reader) readChunk() (last bool, err error) {
func (r *DecryptReader) readChunk() (last bool, err error) {
if len(r.unread) != 0 {
panic("stream: internal error: readChunk called with dirty buffer")
}
@@ -130,12 +128,11 @@ func incNonce(nonce *[chacha20poly1305.NonceSize]byte) {
for i := len(nonce) - 2; i >= 0; i-- {
nonce[i]++
if nonce[i] != 0 {
break
} else if i == 0 {
// The counter is 88 bits, this is unreachable.
panic("stream: chunk counter wrapped around")
return
}
}
// The counter is 88 bits, this is unreachable.
panic("stream: chunk counter wrapped around")
}
func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
@@ -146,30 +143,23 @@ func nonceIsZero(nonce *[chacha20poly1305.NonceSize]byte) bool {
return *nonce == [chacha20poly1305.NonceSize]byte{}
}
type Writer struct {
a cipher.AEAD
dst io.Writer
unwritten []byte // backed by buf
buf [encChunkSize]byte
nonce [chacha20poly1305.NonceSize]byte
err error
type EncryptWriter struct {
a cipher.AEAD
dst io.Writer
buf bytes.Buffer
nonce [chacha20poly1305.NonceSize]byte
err error
}
func NewWriter(key []byte, dst io.Writer) (*Writer, error) {
func NewEncryptWriter(key []byte, dst io.Writer) (*EncryptWriter, error) {
aead, err := chacha20poly1305.New(key)
if err != nil {
return nil, err
}
w := &Writer{
a: aead,
dst: dst,
}
w.unwritten = w.buf[:0]
return w, nil
return &EncryptWriter{a: aead, dst: dst}, nil
}
func (w *Writer) Write(p []byte) (n int, err error) {
// TODO: consider refactoring with a bytes.Buffer.
func (w *EncryptWriter) Write(p []byte) (n int, err error) {
if w.err != nil {
return 0, w.err
}
@@ -179,12 +169,13 @@ func (w *Writer) Write(p []byte) (n int, err error) {
total := len(p)
for len(p) > 0 {
freeBuf := w.buf[len(w.unwritten):ChunkSize]
n := copy(freeBuf, p)
n := min(len(p), ChunkSize-w.buf.Len())
w.buf.Write(p[:n])
p = p[n:]
w.unwritten = w.unwritten[:len(w.unwritten)+n]
if len(w.unwritten) == ChunkSize && len(p) > 0 {
// Only flush if there's a full chunk with bytes still to write, or we
// can't know if this is the last chunk yet.
if w.buf.Len() == ChunkSize && len(p) > 0 {
if err := w.flushChunk(notLastChunk); err != nil {
w.err = err
return 0, err
@@ -195,7 +186,7 @@ func (w *Writer) Write(p []byte) (n int, err error) {
}
// Close flushes the last chunk. It does not close the underlying Writer.
func (w *Writer) Close() error {
func (w *EncryptWriter) Close() error {
if w.err != nil {
return w.err
}
@@ -214,17 +205,110 @@ const (
notLastChunk = false
)
func (w *Writer) flushChunk(last bool) error {
if !last && len(w.unwritten) != ChunkSize {
func (w *EncryptWriter) flushChunk(last bool) error {
if !last && w.buf.Len() != ChunkSize {
panic("stream: internal error: flush called with partial chunk")
}
if last {
setLastChunkFlag(&w.nonce)
}
buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil)
_, err := w.dst.Write(buf)
w.unwritten = w.buf[:0]
w.buf.Grow(chacha20poly1305.Overhead)
ciphertext := w.a.Seal(w.buf.Bytes()[:0], w.nonce[:], w.buf.Bytes(), nil)
_, err := w.dst.Write(ciphertext)
incNonce(&w.nonce)
w.buf.Reset()
return err
}
type EncryptReader struct {
a cipher.AEAD
src io.Reader
// The first ready bytes of buf are already encrypted. This may be less than
// buf.Len(), because we need to over-read to know if a chunk is the last.
ready int
buf bytes.Buffer
nonce [chacha20poly1305.NonceSize]byte
err error
}
func NewEncryptReader(key []byte, src io.Reader) (*EncryptReader, error) {
aead, err := chacha20poly1305.New(key)
if err != nil {
return nil, err
}
return &EncryptReader{a: aead, src: src}, nil
}
func (r *EncryptReader) Read(p []byte) (int, error) {
if r.ready > 0 {
n, err := r.buf.Read(p[:min(len(p), r.ready)])
r.ready -= n
return n, err
}
if r.err != nil {
return 0, r.err
}
if len(p) == 0 {
return 0, nil
}
if err := r.feedBuffer(); err != nil {
r.err = err
return 0, err
}
n, err := r.buf.Read(p[:min(len(p), r.ready)])
r.ready -= n
return n, err
}
// feedBuffer reads and encrypts the next chunk from r.src and appends it to
// r.buf. It sets r.ready to the number of newly available bytes in r.buf.
func (r *EncryptReader) feedBuffer() error {
if r.ready > 0 {
panic("stream: internal error: feedBuffer called with dirty buffer")
}
// CopyN will use r.buf.ReadFrom/WriteTo to fill the buffer directly.
// We need ChunkSize + 1 bytes to determine if this is the last chunk.
_, err := io.CopyN(&r.buf, r.src, int64(ChunkSize-r.buf.Len()+1))
if err != nil && err != io.EOF {
return err
}
if last := r.buf.Len() <= ChunkSize; last {
setLastChunkFlag(&r.nonce)
// After Grow, we know r.buf.Bytes() has enough capacity for the
// overhead. We encrypt in place and then do a Write to include the
// overhead in the buffer.
r.buf.Grow(chacha20poly1305.Overhead)
plaintext := r.buf.Bytes()
r.a.Seal(plaintext[:0], r.nonce[:], plaintext, nil)
incNonce(&r.nonce)
r.buf.Write(plaintext[len(plaintext) : len(plaintext)+chacha20poly1305.Overhead])
r.ready = r.buf.Len()
r.err = io.EOF
return nil
}
// Same, but accounting for the tail byte which will remain unencrypted and
// needs to be shifted past the overhead.
if r.buf.Len() != ChunkSize+1 {
panic("stream: internal error: unexpected buffer length")
}
tailByte := r.buf.Bytes()[ChunkSize]
r.buf.Grow(chacha20poly1305.Overhead)
plaintext := r.buf.Bytes()[:ChunkSize]
r.a.Seal(plaintext[:0], r.nonce[:], plaintext, nil)
incNonce(&r.nonce)
r.buf.Write(plaintext[len(plaintext)+1 : len(plaintext)+chacha20poly1305.Overhead])
r.buf.WriteByte(tailByte)
r.ready = ChunkSize + chacha20poly1305.Overhead
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"bytes"
"crypto/rand"
"fmt"
"io"
"testing"
"filippo.io/age/internal/stream"
@@ -17,12 +18,15 @@ import (
const cs = stream.ChunkSize
func TestRoundTrip(t *testing.T) {
for _, stepSize := range []int{512, 600, 1000, cs} {
for _, length := range []int{0, 1000, cs, cs + 100} {
for _, length := range []int{0, 1000, cs - 1, cs, cs + 1, cs + 100, 2 * cs, 2*cs + 500} {
for _, stepSize := range []int{512, 600, 1000, cs - 1, cs, cs + 1} {
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
}
}
length, stepSize := 2*cs+500, 1
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
}
func testRoundTrip(t *testing.T, stepSize, length int) {
@@ -36,17 +40,14 @@ func testRoundTrip(t *testing.T, stepSize, length int) {
t.Fatal(err)
}
w, err := stream.NewWriter(key, buf)
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
var n int
for n < length {
b := length - n
if b > stepSize {
b = stepSize
}
b := min(length-n, stepSize)
nn, err := w.Write(src[n : n+b])
if err != nil {
t.Fatal(err)
@@ -70,8 +71,9 @@ func testRoundTrip(t *testing.T, stepSize, length int) {
}
t.Logf("buffer size: %d", buf.Len())
ciphertext := bytes.Clone(buf.Bytes())
r, err := stream.NewReader(key, buf)
r, err := stream.NewDecryptReader(key, buf)
if err != nil {
t.Fatal(err)
}
@@ -90,4 +92,27 @@ func testRoundTrip(t *testing.T, stepSize, length int) {
n += nn
}
er, err := stream.NewEncryptReader(key, bytes.NewReader(src))
if err != nil {
t.Fatal(err)
}
n = 0
for {
nn, err := er.Read(readBuf)
if nn == 0 && err == io.EOF {
break
} else if err != nil {
t.Fatalf("EncryptReader Read error at index %d: %v", n, err)
}
if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) {
t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn)
}
n += nn
}
if n != len(ciphertext) {
t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext))
}
}

View File

@@ -288,7 +288,7 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
t.Run("STREAM", func(t *testing.T) {
nonce, payload := payload[:16], payload[16:]
key := streamKey(v.fileKey[:], nonce)
r, err := stream.NewReader(key, bytes.NewReader(payload))
r, err := stream.NewDecryptReader(key, bytes.NewReader(payload))
if err != nil {
t.Fatal(err)
}
@@ -297,7 +297,7 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
t.Fatal(err)
}
buf := &bytes.Buffer{}
w, err := stream.NewWriter(key, buf)
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
@@ -310,6 +310,18 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
if !bytes.Equal(buf.Bytes(), payload) {
t.Error("got a different STREAM ciphertext")
}
buf.Reset()
er, err := stream.NewEncryptReader(key, bytes.NewReader(plaintext))
if err != nil {
t.Fatal(err)
}
ciphertext, err := io.ReadAll(er)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(ciphertext, payload) {
t.Error("got a different STREAM ciphertext from EncryptReader")
}
})
}
}