age: add DecryptReaderAt

This commit is contained in:
Filippo Valsorda
2025-12-26 13:54:44 +01:00
parent abe371e157
commit 2ff5d341f6
6 changed files with 1110 additions and 128 deletions

61
age.go
View File

@@ -252,13 +252,12 @@ func (e *NoIdentityMatchError) Unwrap() []error {
}
// Decrypt decrypts a file encrypted to one or more identities.
//
// It returns a Reader reading the decrypted plaintext of the age file read
// from src. All identities will be tried until one successfully decrypts the file.
// All identities will be tried until one successfully decrypts the file.
// Native, non-interactive identities are tried before any other identities.
//
// If no identity matches the encrypted file, the returned error will be of type
// [NoIdentityMatchError].
// Decrypt returns a Reader reading the decrypted plaintext of the age file read
// from src. If no identity matches the encrypted file, the returned error will
// be of type [NoIdentityMatchError].
func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
hdr, payload, err := format.Parse(src)
if err != nil {
@@ -278,6 +277,58 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
return stream.NewDecryptReader(streamKey(fileKey, nonce), payload)
}
// DecryptReaderAt decrypts a file encrypted to one or more identities.
// All identities will be tried until one successfully decrypts the file.
// Native, non-interactive identities are tried before any other identities.
//
// DecryptReaderAt takes an underlying [io.ReaderAt] and its total encrypted
// size, and returns a ReaderAt of the decrypted plaintext and the plaintext
// size. These can be used for example to instantiate an [io.SectionReader],
// which implements [io.Reader] and [io.Seeker]. Note that ReaderAt by
// definition disregards the seek position of src.
//
// The ReadAt method of the returned ReaderAt can be called concurrently.
// The ReaderAt will internally cache the most recently decrypted chunk.
// DecryptReaderAt reads and decrypts the final chunk before returning,
// to authenticate the plaintext size.
//
// If no identity matches the encrypted file, the returned error will be of
// type [NoIdentityMatchError].
func DecryptReaderAt(src io.ReaderAt, encryptedSize int64, identities ...Identity) (io.ReaderAt, int64, error) {
	// Parse the header through a SectionReader so the src seek position
	// (if any) is irrelevant.
	hdr, payload, err := format.Parse(io.NewSectionReader(src, 0, encryptedSize))
	if err != nil {
		return nil, 0, fmt.Errorf("failed to read header: %w", err)
	}
	// Re-serialize the header to learn its exact encoded length, which is
	// where the STREAM nonce and payload begin.
	var hdrBuf bytes.Buffer
	if err := hdr.Marshal(&hdrBuf); err != nil {
		return nil, 0, fmt.Errorf("failed to serialize header: %w", err)
	}
	fileKey, err := decryptHdr(hdr, identities...)
	if err != nil {
		return nil, 0, err
	}
	nonce := make([]byte, streamNonceSize)
	if _, err := io.ReadFull(payload, nonce); err != nil {
		return nil, 0, fmt.Errorf("failed to read nonce: %w", err)
	}
	payloadStart := int64(hdrBuf.Len()) + int64(len(nonce))
	payloadSize := encryptedSize - payloadStart
	plaintextSize, err := stream.PlaintextSize(payloadSize)
	if err != nil {
		return nil, 0, err
	}
	r, err := stream.NewDecryptReaderAt(streamKey(fileKey, nonce),
		io.NewSectionReader(src, payloadStart, payloadSize), payloadSize)
	if err != nil {
		return nil, 0, err
	}
	return r, plaintextSize, nil
}
func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) {
if len(identities) == 0 {
return nil, errors.New("no identities specified")

View File

@@ -10,7 +10,6 @@ import (
"filippo.io/age/armor"
"filippo.io/age/internal/format"
"filippo.io/age/internal/stream"
"golang.org/x/crypto/chacha20poly1305"
)
type Metadata struct {
@@ -88,9 +87,9 @@ func Inspect(r io.Reader, fileSize int64) (*Metadata, error) {
}
data.Sizes.Armor = tr.count - fileSize
}
data.Sizes.Overhead = streamOverhead(fileSize - data.Sizes.Header)
if data.Sizes.Overhead > fileSize-data.Sizes.Header {
return nil, fmt.Errorf("payload too small to be a valid age file")
data.Sizes.Overhead, err = streamOverhead(fileSize - data.Sizes.Header)
if err != nil {
return nil, fmt.Errorf("failed to compute stream overhead: %w", err)
}
data.Sizes.MinPayload = fileSize - data.Sizes.Header - data.Sizes.Overhead
data.Sizes.MaxPayload = data.Sizes.MinPayload
@@ -114,13 +113,15 @@ func (tr *trackReader) Read(p []byte) (int, error) {
return n, err
}
func streamOverhead(payloadSize int64) int64 {
func streamOverhead(payloadSize int64) (int64, error) {
const streamNonceSize = 16
const encChunkSize = stream.ChunkSize + chacha20poly1305.Overhead
payloadSize -= streamNonceSize
if payloadSize <= 0 {
return streamNonceSize
if payloadSize < streamNonceSize {
return 0, fmt.Errorf("encrypted size too small: %d", payloadSize)
}
chunks := (payloadSize + encChunkSize - 1) / encChunkSize
return streamNonceSize + chunks*chacha20poly1305.Overhead
encryptedSize := payloadSize - streamNonceSize
plaintextSize, err := stream.PlaintextSize(encryptedSize)
if err != nil {
return 0, err
}
return payloadSize - plaintextSize, nil
}

View File

@@ -0,0 +1,46 @@
package inspect
import (
"fmt"
"testing"
"filippo.io/age/internal/stream"
)
// TestStreamOverhead checks streamOverhead against boundary payload sizes.
// In the table, 16 appears both as the STREAM nonce size and as the
// per-chunk AEAD tag size; sizes too small to hold a nonce plus an empty
// final chunk, or implying a truncated/empty non-first final chunk, must
// error.
func TestStreamOverhead(t *testing.T) {
	tests := []struct {
		payloadSize int64 // nonce + encrypted chunks, as stored in the file
		want        int64 // expected overhead: nonce + one tag per chunk
		wantErr     bool
	}{
		{payloadSize: 0, wantErr: true},
		{payloadSize: 15, wantErr: true},
		{payloadSize: 16, wantErr: true},
		{payloadSize: 16 + 15, wantErr: true},
		{payloadSize: 16 + 16, want: 16 + 16}, // empty plaintext
		{payloadSize: 16 + 1 + 16, want: 16 + 16},
		{payloadSize: 16 + stream.ChunkSize + 16, want: 16 + 16},
		{payloadSize: 16 + stream.ChunkSize + 16 + 1, wantErr: true},
		{payloadSize: 16 + stream.ChunkSize + 16 + 15, wantErr: true},
		{payloadSize: 16 + stream.ChunkSize + 16 + 16, wantErr: true}, // empty final chunk
		{payloadSize: 16 + stream.ChunkSize + 16 + 1 + 16, want: 16 + 16 + 16},
	}
	for _, tt := range tests {
		name := "payloadSize=" + fmt.Sprint(tt.payloadSize)
		t.Run(name, func(t *testing.T) {
			got, gotErr := streamOverhead(tt.payloadSize)
			if gotErr != nil {
				if !tt.wantErr {
					t.Errorf("streamOverhead() failed: %v", gotErr)
				}
				return
			}
			if tt.wantErr {
				t.Fatal("streamOverhead() succeeded unexpectedly")
			}
			if got != tt.want {
				t.Errorf("streamOverhead() = %v, want %v", got, tt.want)
			}
		})
	}
}

View File

@@ -8,15 +8,42 @@ package stream
import (
"bytes"
"crypto/cipher"
"encoding/binary"
"errors"
"fmt"
"io"
"sync/atomic"
"golang.org/x/crypto/chacha20poly1305"
)
const ChunkSize = 64 * 1024
// EncryptedChunkCount returns the number of STREAM chunks in an encrypted
// payload of encryptedSize bytes, and errors if encryptedSize is not a
// size that a well-formed sequence of chunks can produce (for example, a
// size that would imply an empty non-first final chunk).
func EncryptedChunkCount(encryptedSize int64) (int64, error) {
	chunks := (encryptedSize + encChunkSize - 1) / encChunkSize
	plaintextSize := encryptedSize - chunks*chacha20poly1305.Overhead
	// An empty plaintext is the only case that allows (and requires) a
	// single empty chunk.
	expected := int64(1)
	if plaintextSize != 0 {
		expected = (plaintextSize + ChunkSize - 1) / ChunkSize
	}
	if chunks != expected {
		return 0, fmt.Errorf("invalid encrypted payload size: %d", encryptedSize)
	}
	return chunks, nil
}
// PlaintextSize returns the size of the plaintext carried by an encrypted
// payload of encryptedSize bytes, erroring if encryptedSize is invalid
// (see EncryptedChunkCount).
func PlaintextSize(encryptedSize int64) (int64, error) {
	chunks, err := EncryptedChunkCount(encryptedSize)
	if err != nil {
		return 0, err
	}
	// Each chunk carries exactly one AEAD tag of overhead.
	return encryptedSize - chunks*chacha20poly1305.Overhead, nil
}
type DecryptReader struct {
a cipher.AEAD
src io.Reader
@@ -135,6 +162,12 @@ func incNonce(nonce *[chacha20poly1305.NonceSize]byte) {
panic("stream: chunk counter wrapped around")
}
// nonceForChunk builds the STREAM nonce for the chunk at chunkIndex: the
// big-endian counter occupies bytes [3:11] of the 12-byte nonce, leaving
// the final byte free for the last-chunk flag.
func nonceForChunk(chunkIndex int64) *[chacha20poly1305.NonceSize]byte {
	nonce := new([chacha20poly1305.NonceSize]byte)
	binary.BigEndian.PutUint64(nonce[3:11], uint64(chunkIndex))
	return nonce
}
// setLastChunkFlag marks nonce as belonging to the final chunk by setting
// its trailing byte to lastChunkFlag.
func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
	nonce[chacha20poly1305.NonceSize-1] = lastChunkFlag
}
@@ -312,3 +345,102 @@ func (r *EncryptReader) feedBuffer() error {
return nil
}
// DecryptReaderAt decrypts a STREAM payload at arbitrary offsets,
// decrypting whole chunks on demand. See NewDecryptReaderAt and ReadAt.
type DecryptReaderAt struct {
	a      cipher.AEAD // per-chunk AEAD
	src    io.ReaderAt // the encrypted payload (after the nonce)
	size   int64       // plaintext size, authenticated by the constructor
	chunks int64       // total number of chunks in the payload
	// cache holds the most recently decrypted chunk, shared across
	// concurrent ReadAt calls via an atomic pointer.
	cache atomic.Pointer[cachedChunk]
}

// cachedChunk is a decrypted chunk published through DecryptReaderAt.cache.
// It must be treated as immutable once stored.
type cachedChunk struct {
	off  int64  // offset of the chunk within the encrypted payload
	data []byte // decrypted chunk contents
}
// NewDecryptReaderAt returns a DecryptReaderAt decrypting the STREAM
// payload in src, which must be exactly size bytes. It validates size by
// reading and authenticating the final chunk before returning, and leaves
// that chunk in the cache.
func NewDecryptReaderAt(key []byte, src io.ReaderAt, size int64) (*DecryptReaderAt, error) {
	aead, err := chacha20poly1305.New(key)
	if err != nil {
		return nil, err
	}
	// Reject sizes that no valid chunk sequence can produce.
	chunks, err := EncryptedChunkCount(size)
	if err != nil {
		return nil, err
	}
	// Locate, read, and decrypt the last chunk: its AEAD tag (computed
	// over the last-chunk-flagged nonce) authenticates the payload size.
	lastIndex := chunks - 1
	lastOff := lastIndex * encChunkSize
	ciphertext := make([]byte, size-lastOff)
	if _, err := src.ReadAt(ciphertext, lastOff); err != nil {
		return nil, fmt.Errorf("failed to read final chunk: %w", err)
	}
	nonce := nonceForChunk(lastIndex)
	setLastChunkFlag(nonce)
	plaintext, err := aead.Open(ciphertext[:0], nonce[:], ciphertext, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt and authenticate final chunk: %w", err)
	}
	r := &DecryptReaderAt{
		a:      aead,
		src:    src,
		size:   size - chunks*chacha20poly1305.Overhead,
		chunks: chunks,
	}
	r.cache.Store(&cachedChunk{off: lastOff, data: plaintext})
	return r, nil
}
// ReadAt implements [io.ReaderAt]. It decrypts whole chunks on demand,
// caching the most recently decrypted one, and returns io.EOF when the
// read reaches the end of the plaintext. It is safe for concurrent use.
func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
	if off < 0 || off > r.size {
		return 0, fmt.Errorf("offset out of range [0:%d]: %d", r.size, off)
	}
	if len(p) == 0 {
		return 0, nil
	}
	for len(p) > 0 && off < r.size {
		chunkIndex := off / ChunkSize
		chunkOff := chunkIndex * encChunkSize
		encSize := r.size + r.chunks*chacha20poly1305.Overhead
		chunkSize := min(encSize-chunkOff, encChunkSize)
		cached := r.cache.Load()
		var plaintext []byte
		if cached != nil && cached.off == chunkOff {
			plaintext = cached.data
		} else {
			// Allocate a fresh buffer for every decrypted chunk. The
			// plaintext returned by Open aliases this buffer and is
			// published through r.cache, so its backing array must never
			// be reused: a single buffer shared across loop iterations
			// would overwrite the cached entry after it was stored,
			// corrupting later cache hits and racing with concurrent
			// ReadAt calls that already loaded it.
			chunk := make([]byte, chunkSize)
			nn, err := r.src.ReadAt(chunk, chunkOff)
			if err == io.EOF {
				// A full read that ends exactly at EOF is fine; a short
				// one means src is smaller than the declared size.
				if int64(nn) != chunkSize {
					err = io.ErrUnexpectedEOF
				} else {
					err = nil
				}
			}
			if err != nil {
				return n, fmt.Errorf("failed to read chunk at offset %d: %w", chunkOff, err)
			}
			nonce := nonceForChunk(chunkIndex)
			if chunkIndex == r.chunks-1 {
				setLastChunkFlag(nonce)
			}
			// Decrypt in place: plaintext aliases chunk's backing array.
			plaintext, err = r.a.Open(chunk[:0], nonce[:], chunk, nil)
			if err != nil {
				return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err)
			}
			r.cache.Store(&cachedChunk{off: chunkOff, data: plaintext})
		}
		plainChunkOff := int(off - chunkIndex*ChunkSize)
		copySize := min(len(plaintext)-plainChunkOff, len(p))
		copy(p, plaintext[plainChunkOff:plainChunkOff+copySize])
		p = p[copySize:]
		off += int64(copySize)
		n += copySize
	}
	if off == r.size {
		return n, io.EOF
	}
	return n, nil
}

View File

@@ -10,6 +10,7 @@ import (
"fmt"
"io"
"testing"
"testing/iotest"
"filippo.io/age/internal/stream"
"golang.org/x/crypto/chacha20poly1305"
@@ -20,13 +21,16 @@ const cs = stream.ChunkSize
func TestRoundTrip(t *testing.T) {
for _, length := range []int{0, 1000, cs - 1, cs, cs + 1, cs + 100, 2 * cs, 2*cs + 500} {
for _, stepSize := range []int{512, 600, 1000, cs - 1, cs, cs + 1} {
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), func(t *testing.T) {
testRoundTrip(t, stepSize, length)
})
}
}
length, stepSize := 2*cs+500, 1
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), func(t *testing.T) {
testRoundTrip(t, stepSize, length)
})
}
func testRoundTrip(t *testing.T, stepSize, length int) {
@@ -34,85 +38,753 @@ func testRoundTrip(t *testing.T, stepSize, length int) {
if _, err := rand.Read(src); err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
var ciphertext []byte
t.Run("EncryptWriter", func(t *testing.T) {
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
var n int
for n < length {
b := min(length-n, stepSize)
nn, err := w.Write(src[n : n+b])
if err != nil {
t.Fatal(err)
}
if nn != b {
t.Errorf("Write returned %d, expected %d", nn, b)
}
n += nn
nn, err = w.Write(src[n:n])
if err != nil {
t.Fatal(err)
}
if nn != 0 {
t.Errorf("Write returned %d, expected 0", nn)
}
}
if err := w.Close(); err != nil {
t.Error("Close returned an error:", err)
}
ciphertext = buf.Bytes()
})
t.Run("DecryptReader", func(t *testing.T) {
r, err := stream.NewDecryptReader(key, bytes.NewReader(ciphertext))
if err != nil {
t.Fatal(err)
}
var n int
readBuf := make([]byte, stepSize)
for n < length {
nn, err := r.Read(readBuf)
if err != nil {
t.Fatalf("Read error at index %d: %v", n, err)
}
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
}
n += nn
}
t.Run("TestReader", func(t *testing.T) {
if length > 1000 && testing.Short() {
t.Skip("skipping slow iotest.TestReader on long input")
}
r, _ := stream.NewDecryptReader(key, bytes.NewReader(ciphertext))
if err := iotest.TestReader(r, src); err != nil {
t.Error("iotest.TestReader error on DecryptReader:", err)
}
})
})
t.Run("DecryptReaderAt", func(t *testing.T) {
rAt, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)))
if err != nil {
t.Fatal(err)
}
rr := io.NewSectionReader(rAt, 0, int64(len(ciphertext)))
var n int
readBuf := make([]byte, stepSize)
for n < length {
nn, err := rr.Read(readBuf)
if n+nn == length && err == io.EOF {
err = nil
}
if err != nil {
t.Fatalf("ReadAt error at index %d: %v", n, err)
}
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
}
n += nn
}
t.Run("TestReader", func(t *testing.T) {
if length > 1000 && testing.Short() {
t.Skip("skipping slow iotest.TestReader on long input")
}
rr := io.NewSectionReader(rAt, 0, int64(len(src)))
if err := iotest.TestReader(rr, src); err != nil {
t.Error("iotest.TestReader error on DecryptReaderAt:", err)
}
})
})
t.Run("EncryptReader", func(t *testing.T) {
er, err := stream.NewEncryptReader(key, bytes.NewReader(src))
if err != nil {
t.Fatal(err)
}
var n int
readBuf := make([]byte, stepSize)
for {
nn, err := er.Read(readBuf)
if nn == 0 && err == io.EOF {
break
} else if err != nil {
t.Fatalf("EncryptReader Read error at index %d: %v", n, err)
}
if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) {
t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn)
}
n += nn
}
if n != len(ciphertext) {
t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext))
}
t.Run("TestReader", func(t *testing.T) {
if length > 1000 && testing.Short() {
t.Skip("skipping slow iotest.TestReader on long input")
}
er, _ := stream.NewEncryptReader(key, bytes.NewReader(src))
if err := iotest.TestReader(er, ciphertext); err != nil {
t.Error("iotest.TestReader error on EncryptReader:", err)
}
})
})
}
// trackingReaderAt wraps an io.ReaderAt and tracks whether ReadAt was called.
// Tests use it to check which DecryptReaderAt reads are served from the
// single-chunk cache (no underlying read) and which hit the source.
type trackingReaderAt struct {
	r      io.ReaderAt
	called bool // set by ReadAt, cleared by reset
}

func (t *trackingReaderAt) ReadAt(p []byte, off int64) (int, error) {
	t.called = true
	return t.r.ReadAt(p, off)
}

// reset clears the called flag before the next tracked operation.
func (t *trackingReaderAt) reset() {
	t.called = false
}
func TestDecryptReaderAt(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
// Create plaintext spanning exactly 3 chunks: 2 full chunks + partial third
// Chunk 0: [0, cs)
// Chunk 1: [cs, 2*cs)
// Chunk 2: [2*cs, 2*cs+500)
plaintextSize := 2*cs + 500
plaintext := make([]byte, plaintextSize)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
// Encrypt
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
var n int
for n < length {
b := min(length-n, stepSize)
nn, err := w.Write(src[n : n+b])
if err != nil {
t.Fatal(err)
}
if nn != b {
t.Errorf("Write returned %d, expected %d", nn, b)
}
n += nn
nn, err = w.Write(src[n:n])
if err != nil {
t.Fatal(err)
}
if nn != 0 {
t.Errorf("Write returned %d, expected 0", nn)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Error("Close returned an error:", err)
t.Fatal(err)
}
ciphertext := buf.Bytes()
t.Logf("buffer size: %d", buf.Len())
ciphertext := bytes.Clone(buf.Bytes())
// Create tracking ReaderAt
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
r, err := stream.NewDecryptReader(key, buf)
// Create DecryptReaderAt (this reads and caches the final chunk)
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
if err != nil {
t.Fatal(err)
}
tracker.reset()
n = 0
readBuf := make([]byte, stepSize)
for n < length {
nn, err := r.Read(readBuf)
if err != nil {
t.Fatalf("Read error at index %d: %v", n, err)
// Helper to check reads
checkRead := func(name string, off int64, size int, wantN int, wantEOF bool, wantSrcRead bool) {
t.Helper()
tracker.reset()
p := make([]byte, size)
n, err := ra.ReadAt(p, off)
if wantEOF {
if err != io.EOF {
t.Errorf("%s: got err=%v, want EOF", name, err)
}
} else {
if err != nil {
t.Errorf("%s: got err=%v, want nil", name, err)
}
}
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
if n != wantN {
t.Errorf("%s: got n=%d, want %d", name, n, wantN)
}
n += nn
if tracker.called != wantSrcRead {
t.Errorf("%s: src.ReadAt called=%v, want %v", name, tracker.called, wantSrcRead)
}
// Verify data correctness
if n > 0 && off >= 0 && off < int64(plaintextSize) {
end := int(off) + n
if end > plaintextSize {
end = plaintextSize
}
if !bytes.Equal(p[:n], plaintext[off:end]) {
t.Errorf("%s: data mismatch", name)
}
}
}
er, err := stream.NewEncryptReader(key, bytes.NewReader(src))
// Test 1: Read from final chunk (cached by constructor)
checkRead("final chunk (cached)", int64(2*cs+100), 100, 100, false, false)
// Test 2: Read spanning second and third chunk
checkRead("span chunks 1-2", int64(cs+cs-50), 100, 100, false, true)
// Test 3: Read from final chunk again (cached from test 2)
// When reading across chunks 1-2 in test 2, the loop processes chunk 1 then chunk 2,
// so chunk 2 ends up in the cache.
checkRead("final chunk after span", int64(2*cs+200), 100, 100, false, false)
// Test 4: Read from final chunk again (now cached)
checkRead("final chunk (cached again)", int64(2*cs+50), 50, 50, false, false)
// Test 5: Read from first chunk (not cached)
checkRead("first chunk", 0, 100, 100, false, true)
// Test 6: Read from first chunk again (now cached)
checkRead("first chunk (cached)", 50, 100, 100, false, false)
// Test 7: Read spanning all chunks
tracker.reset()
p := make([]byte, plaintextSize)
n, err := ra.ReadAt(p, 0)
if err != io.EOF {
t.Errorf("span all: got err=%v, want EOF", err)
}
if n != plaintextSize {
t.Errorf("span all: got n=%d, want %d", n, plaintextSize)
}
if !bytes.Equal(p, plaintext) {
t.Errorf("span all: data mismatch")
}
// Test 8: Read beyond the end (offset > size)
tracker.reset()
p = make([]byte, 100)
n, err = ra.ReadAt(p, int64(plaintextSize+100))
if err == nil {
t.Error("beyond end: expected error, got nil")
}
if n != 0 {
t.Errorf("beyond end: got n=%d, want 0", n)
}
// Test 9: Read with off = size (should return 0, EOF)
tracker.reset()
p = make([]byte, 100)
n, err = ra.ReadAt(p, int64(plaintextSize))
if err != io.EOF {
t.Errorf("off=size: got err=%v, want EOF", err)
}
if n != 0 {
t.Errorf("off=size: got n=%d, want 0", n)
}
// Test 10: Read spanning last chunk and beyond
tracker.reset()
p = make([]byte, 1000) // request more than available
n, err = ra.ReadAt(p, int64(2*cs+400))
if err != io.EOF {
t.Errorf("span last+beyond: got err=%v, want EOF", err)
}
wantN := 500 - 400 // only 100 bytes available from offset 2*cs+400
if n != wantN {
t.Errorf("span last+beyond: got n=%d, want %d", n, wantN)
}
if !bytes.Equal(p[:n], plaintext[2*cs+400:]) {
t.Error("span last+beyond: data mismatch")
}
// Test 11: Read spanning second+last chunk and beyond
tracker.reset()
p = make([]byte, cs+1000) // request more than available
n, err = ra.ReadAt(p, int64(cs+100))
if err != io.EOF {
t.Errorf("span 1-2+beyond: got err=%v, want EOF", err)
}
wantN = plaintextSize - (cs + 100)
if n != wantN {
t.Errorf("span 1-2+beyond: got n=%d, want %d", n, wantN)
}
if !bytes.Equal(p[:n], plaintext[cs+100:]) {
t.Error("span 1-2+beyond: data mismatch")
}
// Test 12: Negative offset
tracker.reset()
p = make([]byte, 100)
n, err = ra.ReadAt(p, -1)
if err == nil {
t.Error("negative offset: expected error, got nil")
}
if n != 0 {
t.Errorf("negative offset: got n=%d, want 0", n)
}
// Test 13: Zero-length read in the middle
tracker.reset()
p = make([]byte, 0)
n, err = ra.ReadAt(p, 100)
if err != nil {
t.Errorf("zero-length middle: got err=%v, want nil", err)
}
if n != 0 {
t.Errorf("zero-length middle: got n=%d, want 0", n)
}
// Test 14: Zero-length read at end
tracker.reset()
p = make([]byte, 0)
n, err = ra.ReadAt(p, int64(plaintextSize))
if err != nil {
t.Errorf("zero-length end: got err=%v, want nil", err)
}
if n != 0 {
t.Errorf("zero-length end: got n=%d, want 0", n)
}
// Test 15: Read exactly one chunk at chunk boundary
checkRead("exact chunk at boundary", int64(cs), cs, cs, false, true)
// Test 16: Read one byte at each chunk boundary
checkRead("one byte at start", 0, 1, 1, false, true)
checkRead("one byte at cs-1", int64(cs-1), 1, 1, false, false) // cached from test 15
checkRead("one byte at cs", int64(cs), 1, 1, false, true)
checkRead("one byte at 2*cs-1", int64(2*cs-1), 1, 1, false, false) // same chunk
checkRead("one byte at 2*cs", int64(2*cs), 1, 1, false, true)
checkRead("last byte", int64(plaintextSize-1), 1, 1, true, false) // same chunk, EOF because we reach end
// Test 17: Read crossing exactly one chunk boundary
checkRead("cross boundary 0-1", int64(cs-50), 100, 100, false, true)
checkRead("cross boundary 1-2", int64(2*cs-50), 100, 100, false, true)
}
func TestDecryptReaderAtEmpty(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
// Create empty encrypted file
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
n = 0
for {
nn, err := er.Read(readBuf)
if nn == 0 && err == io.EOF {
break
} else if err != nil {
t.Fatalf("EncryptReader Read error at index %d: %v", n, err)
}
if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) {
t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn)
}
n += nn
if err := w.Close(); err != nil {
t.Fatal(err)
}
if n != len(ciphertext) {
t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext))
ciphertext := buf.Bytes()
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
if err != nil {
t.Fatal(err)
}
tracker.reset()
// Test 1: Read from empty file at offset 0
p := make([]byte, 100)
n, err := ra.ReadAt(p, 0)
if err != io.EOF {
t.Errorf("empty read: got err=%v, want EOF", err)
}
if n != 0 {
t.Errorf("empty read: got n=%d, want 0", n)
}
// Test 2: Zero-length read from empty file
p = make([]byte, 0)
n, err = ra.ReadAt(p, 0)
if err != nil {
t.Errorf("empty zero-length: got err=%v, want nil", err)
}
if n != 0 {
t.Errorf("empty zero-length: got n=%d, want 0", n)
}
// Test 3: Read beyond empty file
p = make([]byte, 100)
n, err = ra.ReadAt(p, 1)
if err == nil {
t.Error("empty beyond: expected error, got nil")
}
if n != 0 {
t.Errorf("empty beyond: got n=%d, want 0", n)
}
}
// TestDecryptReaderAtSingleChunk checks a payload that fits in one partial
// chunk: since the constructor decrypts and caches the final (here, only)
// chunk, no read should ever reach the underlying source.
func TestDecryptReaderAtSingleChunk(t *testing.T) {
	key := make([]byte, chacha20poly1305.KeySize)
	if _, err := rand.Read(key); err != nil {
		t.Fatal(err)
	}
	// Single chunk, not full
	plaintext := make([]byte, 1000)
	if _, err := rand.Read(plaintext); err != nil {
		t.Fatal(err)
	}
	buf := &bytes.Buffer{}
	w, err := stream.NewEncryptWriter(key, buf)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := w.Write(plaintext); err != nil {
		t.Fatal(err)
	}
	if err := w.Close(); err != nil {
		t.Fatal(err)
	}
	ciphertext := buf.Bytes()
	tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
	ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
	if err != nil {
		t.Fatal(err)
	}
	tracker.reset()
	// All reads should use cache (final chunk = only chunk)
	p := make([]byte, 100)
	n, err := ra.ReadAt(p, 0)
	if err != nil {
		t.Errorf("single chunk start: got err=%v, want nil", err)
	}
	if n != 100 {
		t.Errorf("single chunk start: got n=%d, want 100", n)
	}
	if tracker.called {
		t.Error("single chunk start: unexpected src.ReadAt call")
	}
	if !bytes.Equal(p[:n], plaintext[:100]) {
		t.Error("single chunk start: data mismatch")
	}
	// Read at end: reaching the end of the plaintext yields io.EOF.
	n, err = ra.ReadAt(p, 900)
	if err != io.EOF {
		t.Errorf("single chunk end: got err=%v, want EOF", err)
	}
	if n != 100 {
		t.Errorf("single chunk end: got n=%d, want 100", n)
	}
	if tracker.called {
		t.Error("single chunk end: unexpected src.ReadAt call")
	}
	if !bytes.Equal(p[:n], plaintext[900:]) {
		t.Error("single chunk end: data mismatch")
	}
}
// TestDecryptReaderAtFullChunks checks a plaintext of exactly two full
// chunks: the final chunk is cached by the constructor, and a read that
// straddles the chunk boundary must stitch both chunks together.
func TestDecryptReaderAtFullChunks(t *testing.T) {
	key := make([]byte, chacha20poly1305.KeySize)
	if _, err := rand.Read(key); err != nil {
		t.Fatal(err)
	}
	// Exactly 2 full chunks
	plaintext := make([]byte, 2*cs)
	if _, err := rand.Read(plaintext); err != nil {
		t.Fatal(err)
	}
	buf := &bytes.Buffer{}
	w, err := stream.NewEncryptWriter(key, buf)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := w.Write(plaintext); err != nil {
		t.Fatal(err)
	}
	if err := w.Close(); err != nil {
		t.Fatal(err)
	}
	ciphertext := buf.Bytes()
	tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
	ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
	if err != nil {
		t.Fatal(err)
	}
	tracker.reset()
	// Read last byte of second chunk (cached)
	p := make([]byte, 1)
	n, err := ra.ReadAt(p, int64(2*cs-1))
	if err != io.EOF {
		t.Errorf("last byte: got err=%v, want EOF", err)
	}
	if n != 1 {
		t.Errorf("last byte: got n=%d, want 1", n)
	}
	if tracker.called {
		t.Error("last byte: unexpected src.ReadAt call (should be cached)")
	}
	if p[0] != plaintext[2*cs-1] {
		t.Error("last byte: data mismatch")
	}
	// Read at exactly the boundary between chunks
	p = make([]byte, 100)
	n, err = ra.ReadAt(p, int64(cs-50))
	if err != nil {
		t.Errorf("boundary: got err=%v, want nil", err)
	}
	if n != 100 {
		t.Errorf("boundary: got n=%d, want 100", n)
	}
	if !bytes.Equal(p, plaintext[cs-50:cs+50]) {
		t.Error("boundary: data mismatch")
	}
}
// TestDecryptReaderAtWrongKey checks that NewDecryptReaderAt fails when
// given a key other than the one used to encrypt: the constructor
// authenticates the final chunk, which cannot succeed under a wrong key.
func TestDecryptReaderAtWrongKey(t *testing.T) {
	newKey := func() []byte {
		k := make([]byte, chacha20poly1305.KeySize)
		if _, err := rand.Read(k); err != nil {
			t.Fatal(err)
		}
		return k
	}
	encKey := newKey()
	msg := make([]byte, 1000)
	if _, err := rand.Read(msg); err != nil {
		t.Fatal(err)
	}
	var out bytes.Buffer
	w, err := stream.NewEncryptWriter(encKey, &out)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := w.Write(msg); err != nil {
		t.Fatal(err)
	}
	if err := w.Close(); err != nil {
		t.Fatal(err)
	}
	ct := out.Bytes()
	// Try to decrypt with an unrelated random key.
	if _, err := stream.NewDecryptReaderAt(newKey(), bytes.NewReader(ct), int64(len(ct))); err == nil {
		t.Error("wrong key: expected error, got nil")
	}
}
// TestDecryptReaderAtInvalidSize checks that NewDecryptReaderAt rejects
// declared sizes that do not match any valid chunk layout: one byte too
// small or too large, and a size implying an empty non-first final chunk.
func TestDecryptReaderAtInvalidSize(t *testing.T) {
	key := make([]byte, chacha20poly1305.KeySize)
	if _, err := rand.Read(key); err != nil {
		t.Fatal(err)
	}
	plaintext := make([]byte, 1000)
	if _, err := rand.Read(plaintext); err != nil {
		t.Fatal(err)
	}
	buf := &bytes.Buffer{}
	w, err := stream.NewEncryptWriter(key, buf)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := w.Write(plaintext); err != nil {
		t.Fatal(err)
	}
	if err := w.Close(); err != nil {
		t.Fatal(err)
	}
	ciphertext := buf.Bytes()
	// Wrong size (too small)
	_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)-1))
	if err == nil {
		t.Error("wrong size (small): expected error, got nil")
	}
	// Wrong size (too large)
	_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)+1))
	if err == nil {
		t.Error("wrong size (large): expected error, got nil")
	}
	// Size that would imply empty final chunk (invalid)
	// This would be: one full encrypted chunk + just overhead
	invalidSize := int64(cs + chacha20poly1305.Overhead + chacha20poly1305.Overhead)
	_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(make([]byte, invalidSize)), invalidSize)
	if err == nil {
		t.Error("invalid size (empty final chunk): expected error, got nil")
	}
}
// TestDecryptReaderAtTruncated checks that NewDecryptReaderAt fails when
// the source is shorter than the declared size: reading the final chunk
// comes up short, so construction cannot succeed.
func TestDecryptReaderAtTruncated(t *testing.T) {
	key := make([]byte, chacha20poly1305.KeySize)
	if _, err := rand.Read(key); err != nil {
		t.Fatal(err)
	}
	msg := make([]byte, 2*cs+500)
	if _, err := rand.Read(msg); err != nil {
		t.Fatal(err)
	}
	var out bytes.Buffer
	w, err := stream.NewEncryptWriter(key, &out)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := w.Write(msg); err != nil {
		t.Fatal(err)
	}
	if err := w.Close(); err != nil {
		t.Fatal(err)
	}
	full := out.Bytes()
	// Drop the trailing 100 bytes but still claim the full length.
	short := full[:len(full)-100]
	if _, err := stream.NewDecryptReaderAt(key, bytes.NewReader(short), int64(len(full))); err == nil {
		t.Error("truncated: expected error, got nil")
	}
}
// TestDecryptReaderAtTruncatedChunk checks truncation at an exact chunk
// boundary: the size is self-consistent, but the now-final chunk lacks the
// last-chunk flag in its tag, so authentication of the final chunk fails.
func TestDecryptReaderAtTruncatedChunk(t *testing.T) {
	key := make([]byte, chacha20poly1305.KeySize)
	if _, err := rand.Read(key); err != nil {
		t.Fatal(err)
	}
	// Create 4 chunks: 3 full + 1 partial
	plaintext := make([]byte, 3*cs+500)
	if _, err := rand.Read(plaintext); err != nil {
		t.Fatal(err)
	}
	buf := &bytes.Buffer{}
	w, err := stream.NewEncryptWriter(key, buf)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := w.Write(plaintext); err != nil {
		t.Fatal(err)
	}
	if err := w.Close(); err != nil {
		t.Fatal(err)
	}
	ciphertext := buf.Bytes()
	// Truncate to 3 chunks (remove the actual final chunk)
	// The third chunk was NOT encrypted with the last chunk flag,
	// so decryption should fail when we try to use it as the final chunk.
	encChunkSize := cs + 16 // ChunkSize + Overhead
	truncatedSize := int64(3 * encChunkSize)
	truncated := ciphertext[:truncatedSize]
	_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(truncated), truncatedSize)
	if err == nil {
		t.Error("truncated at chunk boundary: expected error, got nil")
	}
}
// TestDecryptReaderAtCorrupted checks where corruption is detected: a
// flipped bit in the final chunk fails in the constructor (which decrypts
// it to authenticate the size), while a flipped bit in the first chunk is
// only detected when ReadAt decrypts that chunk.
func TestDecryptReaderAtCorrupted(t *testing.T) {
	key := make([]byte, chacha20poly1305.KeySize)
	if _, err := rand.Read(key); err != nil {
		t.Fatal(err)
	}
	plaintext := make([]byte, 2*cs+500)
	if _, err := rand.Read(plaintext); err != nil {
		t.Fatal(err)
	}
	buf := &bytes.Buffer{}
	w, err := stream.NewEncryptWriter(key, buf)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := w.Write(plaintext); err != nil {
		t.Fatal(err)
	}
	if err := w.Close(); err != nil {
		t.Fatal(err)
	}
	ciphertext := bytes.Clone(buf.Bytes())
	// Corrupt final chunk - should fail in constructor
	corruptedFinal := bytes.Clone(ciphertext)
	corruptedFinal[len(corruptedFinal)-10] ^= 0xFF
	_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(corruptedFinal), int64(len(corruptedFinal)))
	if err == nil {
		t.Error("corrupted final: expected error, got nil")
	}
	// Corrupt first chunk - should fail on read
	corruptedFirst := bytes.Clone(ciphertext)
	corruptedFirst[10] ^= 0xFF
	ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(corruptedFirst), int64(len(corruptedFirst)))
	if err != nil {
		t.Fatalf("corrupted first constructor: unexpected error: %v", err)
	}
	p := make([]byte, 100)
	_, err = ra.ReadAt(p, 0)
	if err == nil {
		t.Error("corrupted first read: expected error, got nil")
	}
}

View File

@@ -140,10 +140,16 @@ func parseVector(t *testing.T, test []byte) *vector {
}
func TestVectors(t *testing.T) {
forEachVector(t, testVector)
forEachVector(t, func(t *testing.T, v *vector) {
var plaintext []byte
t.Run("Decrypt", func(t *testing.T) { plaintext = testDecrypt(t, v) })
t.Run("DecryptReaderAt", func(t *testing.T) { testDecryptReaderAt(t, v, plaintext) })
t.Run("Inspect", func(t *testing.T) { testInspect(t, v, plaintext) })
t.Run("RoundTrip", func(t *testing.T) { testVectorRoundTrip(t, v) })
})
}
func testVector(t *testing.T, v *vector) {
func testDecrypt(t *testing.T, v *vector) []byte {
var in io.Reader = bytes.NewReader(v.file)
if v.armored {
in = armor.NewReader(in)
@@ -152,25 +158,25 @@ func testVector(t *testing.T, v *vector) {
if err != nil && strings.HasSuffix(err.Error(), "bad header MAC") {
if v.expect == "HMAC failure" {
t.Log(err)
return
return nil
}
t.Fatalf("expected %s, got HMAC error", v.expect)
} else if e := new(armor.Error); errors.As(err, &e) {
if v.expect == "armor failure" {
t.Log(err)
return
return nil
}
t.Fatalf("expected %s, got: %v", v.expect, err)
} else if _, ok := err.(*age.NoIdentityMatchError); ok {
if v.expect == "no match" {
t.Log(err)
return
return nil
}
t.Fatalf("expected %s, got: %v", v.expect, err)
} else if err != nil {
if v.expect == "header failure" {
t.Log(err)
return
return nil
}
t.Fatalf("expected %s, got: %v", v.expect, err)
} else if v.expect != "success" && v.expect != "payload failure" &&
@@ -188,15 +194,77 @@ func testVector(t *testing.T, v *vector) {
}
}
if v.payloadHash != nil && sha256.Sum256(out) != *v.payloadHash {
t.Error("partial payload hash mismatch")
t.Errorf("partial payload hash mismatch, read %d bytes", len(out))
}
return
return out
} else if v.expect != "success" {
t.Fatalf("expected %s, got success", v.expect)
}
if sha256.Sum256(out) != *v.payloadHash {
t.Error("payload hash mismatch")
}
return out
}
// testDecryptReaderAt checks that age.DecryptReaderAt agrees with the linear
// Decrypt path on the same vector. plaintext is the output of the linear
// decryption for this vector (nil when that decryption was expected to fail).
func testDecryptReaderAt(t *testing.T, v *vector, plaintext []byte) {
	if v.armored {
		t.Skip("armor.NewReader does not implement ReaderAt")
	}
	rAt, size, err := age.DecryptReaderAt(bytes.NewReader(v.file), int64(len(v.file)), v.identities...)
	if v.expect == "success" {
		if err != nil {
			t.Fatalf("expected success, got: %v", err)
		}
		if size != int64(len(plaintext)) {
			t.Errorf("unexpected size: got %d, want %d", size, len(plaintext))
		}
	} else if v.expect == "payload failure" {
		// DecryptReaderAt detects some (but not all) payload failures upfront,
		// either from the size of the payload, or by decrypting the last chunk
		// to authenticate its size. If it did, there is nothing left to read.
		if err != nil {
			t.Log(err)
			return
		}
	} else {
		if err != nil {
			t.Log(err)
			return
		}
		t.Fatalf("expected %s, got success", v.expect)
	}
	out, err := io.ReadAll(io.NewSectionReader(rAt, 0, size))
	if v.expect != "success" {
		// Only "payload failure" vectors reach this point with err == nil
		// above, and the failure must surface while reading.
		if err == nil {
			t.Fatalf("expected %s, got success", v.expect)
		}
		t.Log(err)
		// We can't check the partial payload hash, because the ReaderAt will
		// notice errors that a linearly scanning Reader could not. For example,
		// if there are two final chunks, the linear Reader will decrypt the
		// first one and then error out on the second, while the ReaderAt will
		// decrypt the second one to check the size, and then know that the
		// first chunk could not be the last one. Instead, check that the
		// prefix, if any, matches.
		if !bytes.HasPrefix(plaintext, out) {
			t.Errorf("partial payload prefix mismatch, read %d bytes", len(out))
		}
		return
	}
	if err != nil {
		t.Fatalf("expected success, got: %v", err)
	}
	if sha256.Sum256(out) != *v.payloadHash {
		t.Error("payload hash mismatch")
	}
}
func testInspect(t *testing.T, v *vector, plaintext []byte) {
if v.expect != "success" {
t.Skip("invalid file, can't inspect")
}
for _, fileSize := range []int64{int64(len(v.file)), -1} {
metadata, err := inspect.Inspect(bytes.NewReader(v.file), fileSize)
if err != nil {
@@ -211,8 +279,8 @@ func testVector(t *testing.T, v *vector) {
if metadata.Sizes.Armor+metadata.Sizes.Header+metadata.Sizes.Overhead+metadata.Sizes.MinPayload != int64(len(v.file)) {
t.Errorf("size breakdown does not add up to file size")
}
if metadata.Sizes.MinPayload != int64(len(out)) {
t.Errorf("unexpected payload size: got %d, want %d", metadata.Sizes.MinPayload, len(out))
if metadata.Sizes.MinPayload != int64(len(plaintext)) {
t.Errorf("unexpected payload size: got %d, want %d", metadata.Sizes.MinPayload, len(plaintext))
}
if metadata.Sizes.MaxPayload != metadata.Sizes.MinPayload {
t.Errorf("unexpected max payload size: got %d, want %d", metadata.Sizes.MaxPayload, metadata.Sizes.MinPayload)
@@ -223,16 +291,12 @@ func testVector(t *testing.T, v *vector) {
}
}
// TestVectorsRoundTrip checks that any (valid) armor, header, and/or STREAM
// testVectorsRoundTrip checks that any (valid) armor, header, and/or STREAM
// payload in the test vectors re-encodes identically.
func TestVectorsRoundTrip(t *testing.T) {
forEachVector(t, testVectorRoundTrip)
}
func testVectorRoundTrip(t *testing.T, v *vector) {
if v.armored {
if v.expect == "armor failure" {
t.SkipNow()
t.Skip("invalid armor, nothing to round-trip")
}
t.Run("armor", func(t *testing.T) {
payload, err := io.ReadAll(armor.NewReader(bytes.NewReader(v.file)))
@@ -261,7 +325,7 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
}
if v.expect == "header failure" {
t.SkipNow()
t.Skip("invalid header, nothing to round-trip")
}
hdr, p, err := format.Parse(bytes.NewReader(v.file))
if err != nil {
@@ -283,46 +347,62 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
}
})
if v.expect == "success" {
t.Run("STREAM", func(t *testing.T) {
nonce, payload := payload[:16], payload[16:]
key := streamKey(v.fileKey[:], nonce)
r, err := stream.NewDecryptReader(key, bytes.NewReader(payload))
if err != nil {
t.Fatal(err)
}
plaintext, err := io.ReadAll(r)
if err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf.Bytes(), payload) {
t.Error("got a different STREAM ciphertext")
}
buf.Reset()
er, err := stream.NewEncryptReader(key, bytes.NewReader(plaintext))
if err != nil {
t.Fatal(err)
}
ciphertext, err := io.ReadAll(er)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(ciphertext, payload) {
t.Error("got a different STREAM ciphertext from EncryptReader")
}
})
if v.expect != "success" {
return
}
t.Run("STREAM", func(t *testing.T) {
nonce, payload := payload[:16], payload[16:]
key := streamKey(v.fileKey[:], nonce)
r, err := stream.NewDecryptReader(key, bytes.NewReader(payload))
if err != nil {
t.Fatal(err)
}
plaintext, err := io.ReadAll(r)
if err != nil {
t.Fatal(err)
}
rAt, err := stream.NewDecryptReaderAt(key, bytes.NewReader(payload), int64(len(payload)))
if err != nil {
t.Fatal(err)
}
plaintextAt, err := io.ReadAll(io.NewSectionReader(rAt, 0, int64(len(plaintext))))
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(plaintextAt, plaintext) {
t.Errorf("got a different plaintext from DecryptReaderAt")
}
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf.Bytes(), payload) {
t.Error("got a different STREAM ciphertext")
}
er, err := stream.NewEncryptReader(key, bytes.NewReader(plaintext))
if err != nil {
t.Fatal(err)
}
ciphertext, err := io.ReadAll(er)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(ciphertext, payload) {
t.Error("got a different STREAM ciphertext from EncryptReader")
}
})
}
func streamKey(fileKey, nonce []byte) []byte {