mirror of
https://github.com/FiloSottile/age.git
synced 2026-01-08 13:01:09 +00:00
age: add EncryptReader pull-based encryption API
Fixes #644 Fixes #654 Updates #638
This commit is contained in:
67
age.go
67
age.go
@@ -115,23 +115,11 @@ type Stanza struct {
|
||||
const fileKeySize = 16
|
||||
const streamNonceSize = 16
|
||||
|
||||
// Encrypt encrypts a file to one or more recipients.
|
||||
//
|
||||
// Writes to the returned WriteCloser are encrypted and written to dst as an age
|
||||
// file. Every recipient will be able to decrypt the file.
|
||||
//
|
||||
// The caller must call Close on the WriteCloser when done for the last chunk to
|
||||
// be encrypted and flushed to dst.
|
||||
func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {
|
||||
func encryptHdr(fileKey []byte, recipients ...Recipient) (*format.Header, error) {
|
||||
if len(recipients) == 0 {
|
||||
return nil, errors.New("no recipients specified")
|
||||
}
|
||||
|
||||
fileKey := make([]byte, fileKeySize)
|
||||
if _, err := rand.Read(fileKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hdr := &format.Header{}
|
||||
var labels []string
|
||||
for i, r := range recipients {
|
||||
@@ -154,19 +142,62 @@ func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {
|
||||
} else {
|
||||
hdr.MAC = mac
|
||||
}
|
||||
return hdr, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts a file to one or more recipients. Every recipient will be
|
||||
// able to decrypt the file.
|
||||
//
|
||||
// Writes to the returned WriteCloser are encrypted and written to dst as an age
|
||||
// file. The caller must call Close on the WriteCloser when done for the last
|
||||
// chunk to be encrypted and flushed to dst.
|
||||
func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {
|
||||
fileKey := make([]byte, fileKeySize)
|
||||
rand.Read(fileKey)
|
||||
|
||||
hdr, err := encryptHdr(fileKey, recipients...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := hdr.Marshal(dst); err != nil {
|
||||
return nil, fmt.Errorf("failed to write header: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, streamNonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rand.Read(nonce)
|
||||
if _, err := dst.Write(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to write nonce: %w", err)
|
||||
}
|
||||
|
||||
return stream.NewWriter(streamKey(fileKey, nonce), dst)
|
||||
return stream.NewEncryptWriter(streamKey(fileKey, nonce), dst)
|
||||
}
|
||||
|
||||
// EncryptReader encrypts a file to one or more recipients. Every recipient will be
|
||||
// able to decrypt the file.
|
||||
//
|
||||
// Reads from the returned Reader produce the encrypted file, where the plaintext
|
||||
// is read from src.
|
||||
func EncryptReader(src io.Reader, recipients ...Recipient) (io.Reader, error) {
|
||||
fileKey := make([]byte, fileKeySize)
|
||||
rand.Read(fileKey)
|
||||
|
||||
hdr, err := encryptHdr(fileKey, recipients...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
if err := hdr.Marshal(buf); err != nil {
|
||||
return nil, fmt.Errorf("failed to prepare header: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, streamNonceSize)
|
||||
rand.Read(nonce)
|
||||
|
||||
r, err := stream.NewEncryptReader(streamKey(fileKey, nonce), src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return io.MultiReader(buf, bytes.NewReader(nonce), r), nil
|
||||
}
|
||||
|
||||
func wrapWithLabels(r Recipient, fileKey []byte) (s []*Stanza, labels []string, err error) {
|
||||
@@ -244,7 +275,7 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
|
||||
return nil, fmt.Errorf("failed to read nonce: %w", err)
|
||||
}
|
||||
|
||||
return stream.NewReader(streamKey(fileKey, nonce), payload)
|
||||
return stream.NewDecryptReader(streamKey(fileKey, nonce), payload)
|
||||
}
|
||||
|
||||
func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) {
|
||||
|
||||
27
age_test.go
27
age_test.go
@@ -486,3 +486,30 @@ func TestDetachedHeader(t *testing.T) {
|
||||
t.Errorf("wrong data: %q, expected %q", outBytes, helloWorld)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptReader(t *testing.T) {
|
||||
a, err := age.GenerateX25519Identity()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, err := age.EncryptReader(strings.NewReader(helloWorld), a.Recipient())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
if _, err := io.Copy(buf, r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
out, err := age.Decrypt(buf, a)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
outBytes, err := io.ReadAll(out)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(outBytes) != helloWorld {
|
||||
t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -16,7 +17,7 @@ import (
|
||||
|
||||
const ChunkSize = 64 * 1024
|
||||
|
||||
type Reader struct {
|
||||
type DecryptReader struct {
|
||||
a cipher.AEAD
|
||||
src io.Reader
|
||||
|
||||
@@ -32,18 +33,15 @@ const (
|
||||
lastChunkFlag = 0x01
|
||||
)
|
||||
|
||||
func NewReader(key []byte, src io.Reader) (*Reader, error) {
|
||||
func NewDecryptReader(key []byte, src io.Reader) (*DecryptReader, error) {
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Reader{
|
||||
a: aead,
|
||||
src: src,
|
||||
}, nil
|
||||
return &DecryptReader{a: aead, src: src}, nil
|
||||
}
|
||||
|
||||
func (r *Reader) Read(p []byte) (int, error) {
|
||||
func (r *DecryptReader) Read(p []byte) (int, error) {
|
||||
if len(r.unread) > 0 {
|
||||
n := copy(p, r.unread)
|
||||
r.unread = r.unread[n:]
|
||||
@@ -85,7 +83,7 @@ func (r *Reader) Read(p []byte) (int, error) {
|
||||
// readChunk reads the next chunk of ciphertext from r.src and makes it available
|
||||
// in r.unread. last is true if the chunk was marked as the end of the message.
|
||||
// readChunk must not be called again after returning a last chunk or an error.
|
||||
func (r *Reader) readChunk() (last bool, err error) {
|
||||
func (r *DecryptReader) readChunk() (last bool, err error) {
|
||||
if len(r.unread) != 0 {
|
||||
panic("stream: internal error: readChunk called with dirty buffer")
|
||||
}
|
||||
@@ -130,12 +128,11 @@ func incNonce(nonce *[chacha20poly1305.NonceSize]byte) {
|
||||
for i := len(nonce) - 2; i >= 0; i-- {
|
||||
nonce[i]++
|
||||
if nonce[i] != 0 {
|
||||
break
|
||||
} else if i == 0 {
|
||||
// The counter is 88 bits, this is unreachable.
|
||||
panic("stream: chunk counter wrapped around")
|
||||
return
|
||||
}
|
||||
}
|
||||
// The counter is 88 bits, this is unreachable.
|
||||
panic("stream: chunk counter wrapped around")
|
||||
}
|
||||
|
||||
func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
|
||||
@@ -146,30 +143,23 @@ func nonceIsZero(nonce *[chacha20poly1305.NonceSize]byte) bool {
|
||||
return *nonce == [chacha20poly1305.NonceSize]byte{}
|
||||
}
|
||||
|
||||
type Writer struct {
|
||||
a cipher.AEAD
|
||||
dst io.Writer
|
||||
unwritten []byte // backed by buf
|
||||
buf [encChunkSize]byte
|
||||
nonce [chacha20poly1305.NonceSize]byte
|
||||
err error
|
||||
type EncryptWriter struct {
|
||||
a cipher.AEAD
|
||||
dst io.Writer
|
||||
buf bytes.Buffer
|
||||
nonce [chacha20poly1305.NonceSize]byte
|
||||
err error
|
||||
}
|
||||
|
||||
func NewWriter(key []byte, dst io.Writer) (*Writer, error) {
|
||||
func NewEncryptWriter(key []byte, dst io.Writer) (*EncryptWriter, error) {
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w := &Writer{
|
||||
a: aead,
|
||||
dst: dst,
|
||||
}
|
||||
w.unwritten = w.buf[:0]
|
||||
return w, nil
|
||||
return &EncryptWriter{a: aead, dst: dst}, nil
|
||||
}
|
||||
|
||||
func (w *Writer) Write(p []byte) (n int, err error) {
|
||||
// TODO: consider refactoring with a bytes.Buffer.
|
||||
func (w *EncryptWriter) Write(p []byte) (n int, err error) {
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
@@ -179,12 +169,13 @@ func (w *Writer) Write(p []byte) (n int, err error) {
|
||||
|
||||
total := len(p)
|
||||
for len(p) > 0 {
|
||||
freeBuf := w.buf[len(w.unwritten):ChunkSize]
|
||||
n := copy(freeBuf, p)
|
||||
n := min(len(p), ChunkSize-w.buf.Len())
|
||||
w.buf.Write(p[:n])
|
||||
p = p[n:]
|
||||
w.unwritten = w.unwritten[:len(w.unwritten)+n]
|
||||
|
||||
if len(w.unwritten) == ChunkSize && len(p) > 0 {
|
||||
// Only flush if there's a full chunk with bytes still to write, or we
|
||||
// can't know if this is the last chunk yet.
|
||||
if w.buf.Len() == ChunkSize && len(p) > 0 {
|
||||
if err := w.flushChunk(notLastChunk); err != nil {
|
||||
w.err = err
|
||||
return 0, err
|
||||
@@ -195,7 +186,7 @@ func (w *Writer) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
// Close flushes the last chunk. It does not close the underlying Writer.
|
||||
func (w *Writer) Close() error {
|
||||
func (w *EncryptWriter) Close() error {
|
||||
if w.err != nil {
|
||||
return w.err
|
||||
}
|
||||
@@ -214,17 +205,110 @@ const (
|
||||
notLastChunk = false
|
||||
)
|
||||
|
||||
func (w *Writer) flushChunk(last bool) error {
|
||||
if !last && len(w.unwritten) != ChunkSize {
|
||||
func (w *EncryptWriter) flushChunk(last bool) error {
|
||||
if !last && w.buf.Len() != ChunkSize {
|
||||
panic("stream: internal error: flush called with partial chunk")
|
||||
}
|
||||
|
||||
if last {
|
||||
setLastChunkFlag(&w.nonce)
|
||||
}
|
||||
buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil)
|
||||
_, err := w.dst.Write(buf)
|
||||
w.unwritten = w.buf[:0]
|
||||
w.buf.Grow(chacha20poly1305.Overhead)
|
||||
ciphertext := w.a.Seal(w.buf.Bytes()[:0], w.nonce[:], w.buf.Bytes(), nil)
|
||||
_, err := w.dst.Write(ciphertext)
|
||||
incNonce(&w.nonce)
|
||||
w.buf.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
type EncryptReader struct {
|
||||
a cipher.AEAD
|
||||
src io.Reader
|
||||
|
||||
// The first ready bytes of buf are already encrypted. This may be less than
|
||||
// buf.Len(), because we need to over-read to know if a chunk is the last.
|
||||
ready int
|
||||
buf bytes.Buffer
|
||||
|
||||
nonce [chacha20poly1305.NonceSize]byte
|
||||
err error
|
||||
}
|
||||
|
||||
func NewEncryptReader(key []byte, src io.Reader) (*EncryptReader, error) {
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &EncryptReader{a: aead, src: src}, nil
|
||||
}
|
||||
|
||||
func (r *EncryptReader) Read(p []byte) (int, error) {
|
||||
if r.ready > 0 {
|
||||
n, err := r.buf.Read(p[:min(len(p), r.ready)])
|
||||
r.ready -= n
|
||||
return n, err
|
||||
}
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if err := r.feedBuffer(); err != nil {
|
||||
r.err = err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
n, err := r.buf.Read(p[:min(len(p), r.ready)])
|
||||
r.ready -= n
|
||||
return n, err
|
||||
}
|
||||
|
||||
// feedBuffer reads and encrypts the next chunk from r.src and appends it to
|
||||
// r.buf. It sets r.ready to the number of newly available bytes in r.buf.
|
||||
func (r *EncryptReader) feedBuffer() error {
|
||||
if r.ready > 0 {
|
||||
panic("stream: internal error: feedBuffer called with dirty buffer")
|
||||
}
|
||||
|
||||
// CopyN will use r.buf.ReadFrom/WriteTo to fill the buffer directly.
|
||||
// We need ChunkSize + 1 bytes to determine if this is the last chunk.
|
||||
_, err := io.CopyN(&r.buf, r.src, int64(ChunkSize-r.buf.Len()+1))
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
|
||||
if last := r.buf.Len() <= ChunkSize; last {
|
||||
setLastChunkFlag(&r.nonce)
|
||||
|
||||
// After Grow, we know r.buf.Bytes() has enough capacity for the
|
||||
// overhead. We encrypt in place and then do a Write to include the
|
||||
// overhead in the buffer.
|
||||
r.buf.Grow(chacha20poly1305.Overhead)
|
||||
plaintext := r.buf.Bytes()
|
||||
r.a.Seal(plaintext[:0], r.nonce[:], plaintext, nil)
|
||||
incNonce(&r.nonce)
|
||||
r.buf.Write(plaintext[len(plaintext) : len(plaintext)+chacha20poly1305.Overhead])
|
||||
r.ready = r.buf.Len()
|
||||
|
||||
r.err = io.EOF
|
||||
return nil
|
||||
}
|
||||
|
||||
// Same, but accounting for the tail byte which will remain unencrypted and
|
||||
// needs to be shifted past the overhead.
|
||||
if r.buf.Len() != ChunkSize+1 {
|
||||
panic("stream: internal error: unexpected buffer length")
|
||||
}
|
||||
tailByte := r.buf.Bytes()[ChunkSize]
|
||||
r.buf.Grow(chacha20poly1305.Overhead)
|
||||
plaintext := r.buf.Bytes()[:ChunkSize]
|
||||
r.a.Seal(plaintext[:0], r.nonce[:], plaintext, nil)
|
||||
incNonce(&r.nonce)
|
||||
r.buf.Write(plaintext[len(plaintext)+1 : len(plaintext)+chacha20poly1305.Overhead])
|
||||
r.buf.WriteByte(tailByte)
|
||||
r.ready = ChunkSize + chacha20poly1305.Overhead
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"filippo.io/age/internal/stream"
|
||||
@@ -17,12 +18,15 @@ import (
|
||||
const cs = stream.ChunkSize
|
||||
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
for _, stepSize := range []int{512, 600, 1000, cs} {
|
||||
for _, length := range []int{0, 1000, cs, cs + 100} {
|
||||
for _, length := range []int{0, 1000, cs - 1, cs, cs + 1, cs + 100, 2 * cs, 2*cs + 500} {
|
||||
for _, stepSize := range []int{512, 600, 1000, cs - 1, cs, cs + 1} {
|
||||
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
|
||||
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
|
||||
}
|
||||
}
|
||||
length, stepSize := 2*cs+500, 1
|
||||
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
|
||||
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
|
||||
}
|
||||
|
||||
func testRoundTrip(t *testing.T, stepSize, length int) {
|
||||
@@ -36,17 +40,14 @@ func testRoundTrip(t *testing.T, stepSize, length int) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w, err := stream.NewWriter(key, buf)
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var n int
|
||||
for n < length {
|
||||
b := length - n
|
||||
if b > stepSize {
|
||||
b = stepSize
|
||||
}
|
||||
b := min(length-n, stepSize)
|
||||
nn, err := w.Write(src[n : n+b])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -70,8 +71,9 @@ func testRoundTrip(t *testing.T, stepSize, length int) {
|
||||
}
|
||||
|
||||
t.Logf("buffer size: %d", buf.Len())
|
||||
ciphertext := bytes.Clone(buf.Bytes())
|
||||
|
||||
r, err := stream.NewReader(key, buf)
|
||||
r, err := stream.NewDecryptReader(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -90,4 +92,27 @@ func testRoundTrip(t *testing.T, stepSize, length int) {
|
||||
|
||||
n += nn
|
||||
}
|
||||
|
||||
er, err := stream.NewEncryptReader(key, bytes.NewReader(src))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
n = 0
|
||||
for {
|
||||
nn, err := er.Read(readBuf)
|
||||
if nn == 0 && err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
t.Fatalf("EncryptReader Read error at index %d: %v", n, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) {
|
||||
t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn)
|
||||
}
|
||||
|
||||
n += nn
|
||||
}
|
||||
if n != len(ciphertext) {
|
||||
t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,7 +288,7 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
|
||||
t.Run("STREAM", func(t *testing.T) {
|
||||
nonce, payload := payload[:16], payload[16:]
|
||||
key := streamKey(v.fileKey[:], nonce)
|
||||
r, err := stream.NewReader(key, bytes.NewReader(payload))
|
||||
r, err := stream.NewDecryptReader(key, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -297,7 +297,7 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewWriter(key, buf)
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -310,6 +310,18 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
|
||||
if !bytes.Equal(buf.Bytes(), payload) {
|
||||
t.Error("got a different STREAM ciphertext")
|
||||
}
|
||||
buf.Reset()
|
||||
er, err := stream.NewEncryptReader(key, bytes.NewReader(plaintext))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext, err := io.ReadAll(er)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(ciphertext, payload) {
|
||||
t.Error("got a different STREAM ciphertext from EncryptReader")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user