diff --git a/age.go b/age.go index f255a51..16d9bf6 100644 --- a/age.go +++ b/age.go @@ -252,13 +252,12 @@ func (e *NoIdentityMatchError) Unwrap() []error { } // Decrypt decrypts a file encrypted to one or more identities. -// -// It returns a Reader reading the decrypted plaintext of the age file read -// from src. All identities will be tried until one successfully decrypts the file. +// All identities will be tried until one successfully decrypts the file. // Native, non-interactive identities are tried before any other identities. // -// If no identity matches the encrypted file, the returned error will be of type -// [NoIdentityMatchError]. +// Decrypt returns a Reader reading the decrypted plaintext of the age file read +// from src. If no identity matches the encrypted file, the returned error will +// be of type [NoIdentityMatchError]. func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) { hdr, payload, err := format.Parse(src) if err != nil { @@ -278,6 +277,58 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) { return stream.NewDecryptReader(streamKey(fileKey, nonce), payload) } +// DecryptReaderAt decrypts a file encrypted to one or more identities. +// All identities will be tried until one successfully decrypts the file. +// Native, non-interactive identities are tried before any other identities. +// +// DecryptReaderAt takes an underlying [io.ReaderAt] and its total encrypted +// size, and returns a ReaderAt of the decrypted plaintext and the plaintext +// size. These can be used for example to instantiate an [io.SectionReader], +// which implements [io.Reader] and [io.Seeker]. Note that ReaderAt by +// definition disregards the seek position of src. +// +// The ReadAt method of the returned ReaderAt can be called concurrently. +// The ReaderAt will internally cache the most recently decrypted chunk. +// DecryptReaderAt reads and decrypts the final chunk before returning, +// to authenticate the plaintext size. +// +// If no identity matches the encrypted file, the returned error will be of +// type [NoIdentityMatchError]. +func DecryptReaderAt(src io.ReaderAt, encryptedSize int64, identities ...Identity) (io.ReaderAt, int64, error) { + srcReader := io.NewSectionReader(src, 0, encryptedSize) + hdr, payload, err := format.Parse(srcReader) + if err != nil { + return nil, 0, fmt.Errorf("failed to read header: %w", err) + } + buf := &bytes.Buffer{} + if err := hdr.Marshal(buf); err != nil { + return nil, 0, fmt.Errorf("failed to serialize header: %w", err) + } + + fileKey, err := decryptHdr(hdr, identities...) + if err != nil { + return nil, 0, err + } + + nonce := make([]byte, streamNonceSize) + if _, err := io.ReadFull(payload, nonce); err != nil { + return nil, 0, fmt.Errorf("failed to read nonce: %w", err) + } + + payloadOffset := int64(buf.Len()) + int64(len(nonce)) + payloadSize := encryptedSize - payloadOffset + plaintextSize, err := stream.PlaintextSize(payloadSize) + if err != nil { + return nil, 0, err + } + payloadReaderAt := io.NewSectionReader(src, payloadOffset, payloadSize) + r, err := stream.NewDecryptReaderAt(streamKey(fileKey, nonce), payloadReaderAt, payloadSize) + if err != nil { + return nil, 0, err + } + return r, plaintextSize, nil +} + func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) { if len(identities) == 0 { return nil, errors.New("no identities specified") diff --git a/internal/inspect/inspect.go b/internal/inspect/inspect.go index abcd1e3..b0c3d49 100644 --- a/internal/inspect/inspect.go +++ b/internal/inspect/inspect.go @@ -10,7 +10,6 @@ import ( "filippo.io/age/armor" "filippo.io/age/internal/format" "filippo.io/age/internal/stream" - "golang.org/x/crypto/chacha20poly1305" ) type Metadata struct { @@ -88,9 +87,9 @@ func Inspect(r io.Reader, fileSize int64) (*Metadata, error) { } data.Sizes.Armor = tr.count - fileSize } - data.Sizes.Overhead = streamOverhead(fileSize - data.Sizes.Header) - if data.Sizes.Overhead > fileSize-data.Sizes.Header { - return nil, fmt.Errorf("payload too small to be a valid age file") + data.Sizes.Overhead, err = streamOverhead(fileSize - data.Sizes.Header) + if err != nil { + return nil, fmt.Errorf("failed to compute stream overhead: %w", err) } data.Sizes.MinPayload = fileSize - data.Sizes.Header - data.Sizes.Overhead data.Sizes.MaxPayload = data.Sizes.MinPayload @@ -114,13 +113,15 @@ func (tr *trackReader) Read(p []byte) (int, error) { return n, err } -func streamOverhead(payloadSize int64) int64 { +func streamOverhead(payloadSize int64) (int64, error) { const streamNonceSize = 16 - const encChunkSize = stream.ChunkSize + chacha20poly1305.Overhead - payloadSize -= streamNonceSize - if payloadSize <= 0 { - return streamNonceSize + if payloadSize < streamNonceSize { + return 0, fmt.Errorf("encrypted size too small: %d", payloadSize) } - chunks := (payloadSize + encChunkSize - 1) / encChunkSize - return streamNonceSize + chunks*chacha20poly1305.Overhead + encryptedSize := payloadSize - streamNonceSize + plaintextSize, err := stream.PlaintextSize(encryptedSize) + if err != nil { + return 0, err + } + return payloadSize - plaintextSize, nil } diff --git a/internal/inspect/inspect_test.go b/internal/inspect/inspect_test.go new file mode 100644 index 0000000..331ff1e --- /dev/null +++ b/internal/inspect/inspect_test.go @@ -0,0 +1,46 @@ +package inspect + +import ( + "fmt" + "testing" + + "filippo.io/age/internal/stream" +) + +func TestStreamOverhead(t *testing.T) { + tests := []struct { + payloadSize int64 + want int64 + wantErr bool + }{ + {payloadSize: 0, wantErr: true}, + {payloadSize: 15, wantErr: true}, + {payloadSize: 16, wantErr: true}, + {payloadSize: 16 + 15, wantErr: true}, + {payloadSize: 16 + 16, want: 16 + 16}, // empty plaintext + {payloadSize: 16 + 1 + 16, want: 16 + 16}, + {payloadSize: 16 + stream.ChunkSize + 16, want: 16 + 16}, + {payloadSize: 16 + stream.ChunkSize + 16 + 1, wantErr: true}, + {payloadSize: 16 + stream.ChunkSize + 16 + 15, wantErr: true}, + {payloadSize: 16 + stream.ChunkSize + 16 + 16, wantErr: true}, // empty final chunk + {payloadSize: 16 + stream.ChunkSize + 16 + 1 + 16, want: 16 + 16 + 16}, + } + for _, tt := range tests { + name := "payloadSize=" + fmt.Sprint(tt.payloadSize) + t.Run(name, func(t *testing.T) { + got, gotErr := streamOverhead(tt.payloadSize) + if gotErr != nil { + if !tt.wantErr { + t.Errorf("streamOverhead() failed: %v", gotErr) + } + return + } + if tt.wantErr { + t.Fatal("streamOverhead() succeeded unexpectedly") + } + if got != tt.want { + t.Errorf("streamOverhead() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 8dfae38..731c719 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -8,15 +8,42 @@ package stream import ( "bytes" "crypto/cipher" + "encoding/binary" "errors" "fmt" "io" + "sync/atomic" "golang.org/x/crypto/chacha20poly1305" ) const ChunkSize = 64 * 1024 +func EncryptedChunkCount(encryptedSize int64) (int64, error) { + chunks := (encryptedSize + encChunkSize - 1) / encChunkSize + + plaintextSize := encryptedSize - chunks*chacha20poly1305.Overhead + expChunks := (plaintextSize + ChunkSize - 1) / ChunkSize + // Empty plaintext, the only case that allows (and requires) an empty chunk. + if plaintextSize == 0 { + expChunks = 1 + } + if expChunks != chunks { + return 0, fmt.Errorf("invalid encrypted payload size: %d", encryptedSize) + } + + return chunks, nil +} + +func PlaintextSize(encryptedSize int64) (int64, error) { + chunks, err := EncryptedChunkCount(encryptedSize) + if err != nil { + return 0, err + } + plaintextSize := encryptedSize - chunks*chacha20poly1305.Overhead + return plaintextSize, nil +} + type DecryptReader struct { a cipher.AEAD src io.Reader @@ -135,6 +162,12 @@ func incNonce(nonce *[chacha20poly1305.NonceSize]byte) { panic("stream: chunk counter wrapped around") } +func nonceForChunk(chunkIndex int64) *[chacha20poly1305.NonceSize]byte { + var nonce [chacha20poly1305.NonceSize]byte + binary.BigEndian.PutUint64(nonce[3:11], uint64(chunkIndex)) + return &nonce +} + func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) { nonce[len(nonce)-1] = lastChunkFlag } @@ -312,3 +345,102 @@ func (r *EncryptReader) feedBuffer() error { return nil } + +type DecryptReaderAt struct { + a cipher.AEAD + src io.ReaderAt + size int64 + chunks int64 + cache atomic.Pointer[cachedChunk] +} + +type cachedChunk struct { + off int64 + data []byte +} + +func NewDecryptReaderAt(key []byte, src io.ReaderAt, size int64) (*DecryptReaderAt, error) { + aead, err := chacha20poly1305.New(key) + if err != nil { + return nil, err + } + + // Check that size is valid by decrypting the final chunk. + chunks, err := EncryptedChunkCount(size) + if err != nil { + return nil, err + } + finalChunkIndex := chunks - 1 + finalChunkOff := finalChunkIndex * encChunkSize + finalChunkSize := size - finalChunkOff + finalChunk := make([]byte, finalChunkSize) + if _, err := src.ReadAt(finalChunk, finalChunkOff); err != nil { + return nil, fmt.Errorf("failed to read final chunk: %w", err) + } + nonce := nonceForChunk(finalChunkIndex) + setLastChunkFlag(nonce) + plaintext, err := aead.Open(finalChunk[:0], nonce[:], finalChunk, nil) + if err != nil { + return nil, fmt.Errorf("failed to decrypt and authenticate final chunk: %w", err) + } + cache := &cachedChunk{off: finalChunkOff, data: plaintext} + + plaintextSize := size - chunks*chacha20poly1305.Overhead + r := &DecryptReaderAt{a: aead, src: src, size: plaintextSize, chunks: chunks} + r.cache.Store(cache) + return r, nil +} + +func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) { + if off < 0 || off > r.size { + return 0, fmt.Errorf("offset out of range [0:%d]: %d", r.size, off) + } + if len(p) == 0 { + return 0, nil + } + chunk := make([]byte, encChunkSize) + for len(p) > 0 && off < r.size { + chunkIndex := off / ChunkSize + chunkOff := chunkIndex * encChunkSize + encSize := r.size + r.chunks*chacha20poly1305.Overhead + chunkSize := min(encSize-chunkOff, encChunkSize) + + cached := r.cache.Load() + var plaintext []byte + if cached != nil && cached.off == chunkOff { + plaintext = cached.data + } else { + nn, err := r.src.ReadAt(chunk[:chunkSize], chunkOff) + if err == io.EOF { + if int64(nn) != chunkSize { + err = io.ErrUnexpectedEOF + } else { + err = nil + } + } + if err != nil { + return n, fmt.Errorf("failed to read chunk at offset %d: %w", chunkOff, err) + } + nonce := nonceForChunk(chunkIndex) + if chunkIndex == r.chunks-1 { + setLastChunkFlag(nonce) + } + plaintext, err = r.a.Open(chunk[:0], nonce[:], chunk[:chunkSize], nil) + if err != nil { + return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err) + } + r.cache.Store(&cachedChunk{off: chunkOff, data: plaintext}) + } + + plainChunkOff := int(off - chunkIndex*ChunkSize) + copySize := min(len(plaintext)-plainChunkOff, len(p)) + copy(p, plaintext[plainChunkOff:plainChunkOff+copySize]) + p = p[copySize:] + off += int64(copySize) + n += copySize + } + if off == r.size { + return n, io.EOF + } + return n, nil +} diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go index e2f6961..1bd999f 100644 --- a/internal/stream/stream_test.go +++ b/internal/stream/stream_test.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "testing" + "testing/iotest" "filippo.io/age/internal/stream" "golang.org/x/crypto/chacha20poly1305" @@ -20,13 +21,16 @@ const cs = stream.ChunkSize func TestRoundTrip(t *testing.T) { for _, length := range []int{0, 1000, cs - 1, cs, cs + 1, cs + 100, 2 * cs, 2*cs + 500} { for _, stepSize := range []int{512, 600, 1000, cs - 1, cs, cs + 1} { - t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), - func(t *testing.T) { testRoundTrip(t, stepSize, length) }) + t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), func(t *testing.T) { + testRoundTrip(t, stepSize, length) + }) } } + length, stepSize := 2*cs+500, 1 - t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), - func(t *testing.T) { testRoundTrip(t, stepSize, length) }) + t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), func(t *testing.T) { + testRoundTrip(t, stepSize, length) + }) } func testRoundTrip(t *testing.T, stepSize, length int) { @@ -34,85 +38,753 @@ func testRoundTrip(t *testing.T, stepSize, length int) { if _, err := rand.Read(src); err != nil { t.Fatal(err) } - buf := &bytes.Buffer{} + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + var ciphertext []byte + + t.Run("EncryptWriter", func(t *testing.T) { + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) + if err != nil { + t.Fatal(err) + } + + var n int + for n < length { + b := min(length-n, stepSize) + nn, err := w.Write(src[n : n+b]) + if err != nil { + t.Fatal(err) + } + if nn != b { + t.Errorf("Write returned %d, expected %d", nn, b) + } + n += nn + + nn, err = w.Write(src[n:n]) + if err != nil { + t.Fatal(err) + } + if nn != 0 { + t.Errorf("Write returned %d, expected 0", nn) + } + } + if err := w.Close(); err != nil { + t.Error("Close returned an error:", err) + } + + ciphertext = buf.Bytes() + }) + + t.Run("DecryptReader", func(t *testing.T) { + r, err := stream.NewDecryptReader(key, bytes.NewReader(ciphertext)) + if err != nil { + t.Fatal(err) + } + + var n int + readBuf := make([]byte, stepSize) + for n < length { + nn, err := r.Read(readBuf) + if err != nil { + t.Fatalf("Read error at index %d: %v", n, err) + } + + if !bytes.Equal(readBuf[:nn], src[n:n+nn]) { + t.Errorf("wrong data at indexes %d - %d", n, n+nn) + } + + n += nn + } + + t.Run("TestReader", func(t *testing.T) { + if length > 1000 && testing.Short() { + t.Skip("skipping slow iotest.TestReader on long input") + } + r, _ := stream.NewDecryptReader(key, bytes.NewReader(ciphertext)) + if err := iotest.TestReader(r, src); err != nil { + t.Error("iotest.TestReader error on DecryptReader:", err) + } + }) + }) + + t.Run("DecryptReaderAt", func(t *testing.T) { + rAt, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext))) + if err != nil { + t.Fatal(err) + } + rr := io.NewSectionReader(rAt, 0, int64(len(ciphertext))) + + var n int + readBuf := make([]byte, stepSize) + for n < length { + nn, err := rr.Read(readBuf) + if n+nn == length && err == io.EOF { + err = nil + } + if err != nil { + t.Fatalf("ReadAt error at index %d: %v", n, err) + } + + if !bytes.Equal(readBuf[:nn], src[n:n+nn]) { + t.Errorf("wrong data at indexes %d - %d", n, n+nn) + } + + n += nn + } + + t.Run("TestReader", func(t *testing.T) { + if length > 1000 && testing.Short() { + t.Skip("skipping slow iotest.TestReader on long input") + } + rr := io.NewSectionReader(rAt, 0, int64(len(src))) + if err := iotest.TestReader(rr, src); err != nil { + t.Error("iotest.TestReader error on DecryptReaderAt:", err) + } + }) + }) + + t.Run("EncryptReader", func(t *testing.T) { + er, err := stream.NewEncryptReader(key, bytes.NewReader(src)) + if err != nil { + t.Fatal(err) + } + + var n int + readBuf := make([]byte, stepSize) + for { + nn, err := er.Read(readBuf) + if nn == 0 && err == io.EOF { + break + } else if err != nil { + t.Fatalf("EncryptReader Read error at index %d: %v", n, err) + } + + if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) { + t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn) + } + + n += nn + } + if n != len(ciphertext) { + t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext)) + } + + t.Run("TestReader", func(t *testing.T) { + if length > 1000 && testing.Short() { + t.Skip("skipping slow iotest.TestReader on long input") + } + er, _ := stream.NewEncryptReader(key, bytes.NewReader(src)) + if err := iotest.TestReader(er, ciphertext); err != nil { + t.Error("iotest.TestReader error on EncryptReader:", err) + } + }) + }) +} + +// trackingReaderAt wraps an io.ReaderAt and tracks whether ReadAt was called. +type trackingReaderAt struct { + r io.ReaderAt + called bool +} + +func (t *trackingReaderAt) ReadAt(p []byte, off int64) (int, error) { + t.called = true + return t.r.ReadAt(p, off) +} + +func (t *trackingReaderAt) reset() { + t.called = false +} + +func TestDecryptReaderAt(t *testing.T) { key := make([]byte, chacha20poly1305.KeySize) if _, err := rand.Read(key); err != nil { t.Fatal(err) } + // Create plaintext spanning exactly 3 chunks: 2 full chunks + partial third + // Chunk 0: [0, cs) + // Chunk 1: [cs, 2*cs) + // Chunk 2: [2*cs, 2*cs+500) + plaintextSize := 2*cs + 500 + plaintext := make([]byte, plaintextSize) + if _, err := rand.Read(plaintext); err != nil { + t.Fatal(err) + } + + // Encrypt + buf := &bytes.Buffer{} w, err := stream.NewEncryptWriter(key, buf) if err != nil { t.Fatal(err) } - - var n int - for n < length { - b := min(length-n, stepSize) - nn, err := w.Write(src[n : n+b]) - if err != nil { - t.Fatal(err) - } - if nn != b { - t.Errorf("Write returned %d, expected %d", nn, b) - } - n += nn - - nn, err = w.Write(src[n:n]) - if err != nil { - t.Fatal(err) - } - if nn != 0 { - t.Errorf("Write returned %d, expected 0", nn) - } + if _, err := w.Write(plaintext); err != nil { + t.Fatal(err) } - if err := w.Close(); err != nil { - t.Error("Close returned an error:", err) + t.Fatal(err) } + ciphertext := buf.Bytes() - t.Logf("buffer size: %d", buf.Len()) - ciphertext := bytes.Clone(buf.Bytes()) + // Create tracking ReaderAt + tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)} - r, err := stream.NewDecryptReader(key, buf) + // Create DecryptReaderAt (this reads and caches the final chunk) + ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext))) if err != nil { t.Fatal(err) } + tracker.reset() - n = 0 - readBuf := make([]byte, stepSize) - for n < length { - nn, err := r.Read(readBuf) - if err != nil { - t.Fatalf("Read error at index %d: %v", n, err) + // Helper to check reads + checkRead := func(name string, off int64, size int, wantN int, wantEOF bool, wantSrcRead bool) { + t.Helper() + tracker.reset() + p := make([]byte, size) + n, err := ra.ReadAt(p, off) + + if wantEOF { + if err != io.EOF { + t.Errorf("%s: got err=%v, want EOF", name, err) + } + } else { + if err != nil { + t.Errorf("%s: got err=%v, want nil", name, err) + } } - if !bytes.Equal(readBuf[:nn], src[n:n+nn]) { - t.Errorf("wrong data at indexes %d - %d", n, n+nn) + if n != wantN { + t.Errorf("%s: got n=%d, want %d", name, n, wantN) } - n += nn + if tracker.called != wantSrcRead { + t.Errorf("%s: src.ReadAt called=%v, want %v", name, tracker.called, wantSrcRead) + } + + // Verify data correctness + if n > 0 && off >= 0 && off < int64(plaintextSize) { + end := int(off) + n + if end > plaintextSize { + end = plaintextSize + } + if !bytes.Equal(p[:n], plaintext[off:end]) { + t.Errorf("%s: data mismatch", name) + } + } } - er, err := stream.NewEncryptReader(key, bytes.NewReader(src)) + // Test 1: Read from final chunk (cached by constructor) + checkRead("final chunk (cached)", int64(2*cs+100), 100, 100, false, false) + + // Test 2: Read spanning second and third chunk + checkRead("span chunks 1-2", int64(cs+cs-50), 100, 100, false, true) + + // Test 3: Read from final chunk again (cached from test 2) + // When reading across chunks 1-2 in test 2, the loop processes chunk 1 then chunk 2, + // so chunk 2 ends up in the cache. + checkRead("final chunk after span", int64(2*cs+200), 100, 100, false, false) + + // Test 4: Read from final chunk again (now cached) + checkRead("final chunk (cached again)", int64(2*cs+50), 50, 50, false, false) + + // Test 5: Read from first chunk (not cached) + checkRead("first chunk", 0, 100, 100, false, true) + + // Test 6: Read from first chunk again (now cached) + checkRead("first chunk (cached)", 50, 100, 100, false, false) + + // Test 7: Read spanning all chunks + tracker.reset() + p := make([]byte, plaintextSize) + n, err := ra.ReadAt(p, 0) + if err != io.EOF { + t.Errorf("span all: got err=%v, want EOF", err) + } + if n != plaintextSize { + t.Errorf("span all: got n=%d, want %d", n, plaintextSize) + } + if !bytes.Equal(p, plaintext) { + t.Errorf("span all: data mismatch") + } + + // Test 8: Read beyond the end (offset > size) + tracker.reset() + p = make([]byte, 100) + n, err = ra.ReadAt(p, int64(plaintextSize+100)) + if err == nil { + t.Error("beyond end: expected error, got nil") + } + if n != 0 { + t.Errorf("beyond end: got n=%d, want 0", n) + } + + // Test 9: Read with off = size (should return 0, EOF) + tracker.reset() + p = make([]byte, 100) + n, err = ra.ReadAt(p, int64(plaintextSize)) + if err != io.EOF { + t.Errorf("off=size: got err=%v, want EOF", err) + } + if n != 0 { + t.Errorf("off=size: got n=%d, want 0", n) + } + + // Test 10: Read spanning last chunk and beyond + tracker.reset() + p = make([]byte, 1000) // request more than available + n, err = ra.ReadAt(p, int64(2*cs+400)) + if err != io.EOF { + t.Errorf("span last+beyond: got err=%v, want EOF", err) + } + wantN := 500 - 400 // only 100 bytes available from offset 2*cs+400 + if n != wantN { + t.Errorf("span last+beyond: got n=%d, want %d", n, wantN) + } + if !bytes.Equal(p[:n], plaintext[2*cs+400:]) { + t.Error("span last+beyond: data mismatch") + } + + // Test 11: Read spanning second+last chunk and beyond + tracker.reset() + p = make([]byte, cs+1000) // request more than available + n, err = ra.ReadAt(p, int64(cs+100)) + if err != io.EOF { + t.Errorf("span 1-2+beyond: got err=%v, want EOF", err) + } + wantN = plaintextSize - (cs + 100) + if n != wantN { + t.Errorf("span 1-2+beyond: got n=%d, want %d", n, wantN) + } + if !bytes.Equal(p[:n], plaintext[cs+100:]) { + t.Error("span 1-2+beyond: data mismatch") + } + + // Test 12: Negative offset + tracker.reset() + p = make([]byte, 100) + n, err = ra.ReadAt(p, -1) + if err == nil { + t.Error("negative offset: expected error, got nil") + } + if n != 0 { + t.Errorf("negative offset: got n=%d, want 0", n) + } + + // Test 13: Zero-length read in the middle + tracker.reset() + p = make([]byte, 0) + n, err = ra.ReadAt(p, 100) + if err != nil { + t.Errorf("zero-length middle: got err=%v, want nil", err) + } + if n != 0 { + t.Errorf("zero-length middle: got n=%d, want 0", n) + } + + // Test 14: Zero-length read at end + tracker.reset() + p = make([]byte, 0) + n, err = ra.ReadAt(p, int64(plaintextSize)) + if err != nil { + t.Errorf("zero-length end: got err=%v, want nil", err) + } + if n != 0 { + t.Errorf("zero-length end: got n=%d, want 0", n) + } + + // Test 15: Read exactly one chunk at chunk boundary + checkRead("exact chunk at boundary", int64(cs), cs, cs, false, true) + + // Test 16: Read one byte at each chunk boundary + checkRead("one byte at start", 0, 1, 1, false, true) + checkRead("one byte at cs-1", int64(cs-1), 1, 1, false, false) // cached from test 15 + checkRead("one byte at cs", int64(cs), 1, 1, false, true) + checkRead("one byte at 2*cs-1", int64(2*cs-1), 1, 1, false, false) // same chunk + checkRead("one byte at 2*cs", int64(2*cs), 1, 1, false, true) + checkRead("last byte", int64(plaintextSize-1), 1, 1, true, false) // same chunk, EOF because we reach end + + // Test 17: Read crossing exactly one chunk boundary + checkRead("cross boundary 0-1", int64(cs-50), 100, 100, false, true) + checkRead("cross boundary 1-2", int64(2*cs-50), 100, 100, false, true) +} + +func TestDecryptReaderAtEmpty(t *testing.T) { + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + // Create empty encrypted file + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) if err != nil { t.Fatal(err) } - n = 0 - for { - nn, err := er.Read(readBuf) - if nn == 0 && err == io.EOF { - break - } else if err != nil { - t.Fatalf("EncryptReader Read error at index %d: %v", n, err) - } - - if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) { - t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn) - } - - n += nn + if err := w.Close(); err != nil { + t.Fatal(err) } - if n != len(ciphertext) { - t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext)) + ciphertext := buf.Bytes() + + tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)} + ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext))) + if err != nil { + t.Fatal(err) + } + tracker.reset() + + // Test 1: Read from empty file at offset 0 + p := make([]byte, 100) + n, err := ra.ReadAt(p, 0) + if err != io.EOF { + t.Errorf("empty read: got err=%v, want EOF", err) + } + if n != 0 { + t.Errorf("empty read: got n=%d, want 0", n) + } + + // Test 2: Zero-length read from empty file + p = make([]byte, 0) + n, err = ra.ReadAt(p, 0) + if err != nil { + t.Errorf("empty zero-length: got err=%v, want nil", err) + } + if n != 0 { + t.Errorf("empty zero-length: got n=%d, want 0", n) + } + + // Test 3: Read beyond empty file + p = make([]byte, 100) + n, err = ra.ReadAt(p, 1) + if err == nil { + t.Error("empty beyond: expected error, got nil") + } + if n != 0 { + t.Errorf("empty beyond: got n=%d, want 0", n) + } +} + +func TestDecryptReaderAtSingleChunk(t *testing.T) { + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + // Single chunk, not full + plaintext := make([]byte, 1000) + if _, err := rand.Read(plaintext); err != nil { + t.Fatal(err) + } + + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(plaintext); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + ciphertext := buf.Bytes() + + tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)} + ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext))) + if err != nil { + t.Fatal(err) + } + tracker.reset() + + // All reads should use cache (final chunk = only chunk) + p := make([]byte, 100) + n, err := ra.ReadAt(p, 0) + if err != nil { + t.Errorf("single chunk start: got err=%v, want nil", err) + } + if n != 100 { + t.Errorf("single chunk start: got n=%d, want 100", n) + } + if tracker.called { + t.Error("single chunk start: unexpected src.ReadAt call") + } + if !bytes.Equal(p[:n], plaintext[:100]) { + t.Error("single chunk start: data mismatch") + } + + // Read at end + n, err = ra.ReadAt(p, 900) + if err != io.EOF { + t.Errorf("single chunk end: got err=%v, want EOF", err) + } + if n != 100 { + t.Errorf("single chunk end: got n=%d, want 100", n) + } + if tracker.called { + t.Error("single chunk end: unexpected src.ReadAt call") + } + if !bytes.Equal(p[:n], plaintext[900:]) { + t.Error("single chunk end: data mismatch") + } +} + +func TestDecryptReaderAtFullChunks(t *testing.T) { + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + // Exactly 2 full chunks + plaintext := make([]byte, 2*cs) + if _, err := rand.Read(plaintext); err != nil { + t.Fatal(err) + } + + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(plaintext); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + ciphertext := buf.Bytes() + + tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)} + ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext))) + if err != nil { + t.Fatal(err) + } + tracker.reset() + + // Read last byte of second chunk (cached) + p := make([]byte, 1) + n, err := ra.ReadAt(p, int64(2*cs-1)) + if err != io.EOF { + t.Errorf("last byte: got err=%v, want EOF", err) + } + if n != 1 { + t.Errorf("last byte: got n=%d, want 1", n) + } + if tracker.called { + t.Error("last byte: unexpected src.ReadAt call (should be cached)") + } + if p[0] != plaintext[2*cs-1] { + t.Error("last byte: data mismatch") + } + + // Read at exactly the boundary between chunks + p = make([]byte, 100) + n, err = ra.ReadAt(p, int64(cs-50)) + if err != nil { + t.Errorf("boundary: got err=%v, want nil", err) + } + if n != 100 { + t.Errorf("boundary: got n=%d, want 100", n) + } + if !bytes.Equal(p, plaintext[cs-50:cs+50]) { + t.Error("boundary: data mismatch") + } +} + +func TestDecryptReaderAtWrongKey(t *testing.T) { + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + plaintext := make([]byte, 1000) + if _, err := rand.Read(plaintext); err != nil { + t.Fatal(err) + } + + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(plaintext); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + ciphertext := buf.Bytes() + + // Try to decrypt with wrong key + wrongKey := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(wrongKey); err != nil { + t.Fatal(err) + } + + _, err = stream.NewDecryptReaderAt(wrongKey, bytes.NewReader(ciphertext), int64(len(ciphertext))) + if err == nil { + t.Error("wrong key: expected error, got nil") + } +} + +func TestDecryptReaderAtInvalidSize(t *testing.T) { + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + plaintext := make([]byte, 1000) + if _, err := rand.Read(plaintext); err != nil { + t.Fatal(err) + } + + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(plaintext); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + ciphertext := buf.Bytes() + + // Wrong size (too small) + _, err = stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)-1)) + if err == nil { + t.Error("wrong size (small): expected error, got nil") + } + + // Wrong size (too large) + _, err = stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)+1)) + if err == nil { + t.Error("wrong size (large): expected error, got nil") + } + + // Size that would imply empty final chunk (invalid) + // This would be: one full encrypted chunk + just overhead + invalidSize := int64(cs + chacha20poly1305.Overhead + chacha20poly1305.Overhead) + _, err = stream.NewDecryptReaderAt(key, bytes.NewReader(make([]byte, invalidSize)), invalidSize) + if err == nil { + t.Error("invalid size (empty final chunk): expected error, got nil") + } +} + +func TestDecryptReaderAtTruncated(t *testing.T) { + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + plaintext := make([]byte, 2*cs+500) + if _, err := rand.Read(plaintext); err != nil { + t.Fatal(err) + } + + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(plaintext); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + ciphertext := buf.Bytes() + + // Truncate ciphertext but lie about size + truncated := ciphertext[:len(ciphertext)-100] + _, err = stream.NewDecryptReaderAt(key, bytes.NewReader(truncated), int64(len(ciphertext))) + if err == nil { + t.Error("truncated: expected error, got nil") + } +} + +func TestDecryptReaderAtTruncatedChunk(t *testing.T) { + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + // Create 4 chunks: 3 full + 1 partial + plaintext := make([]byte, 3*cs+500) + if _, err := rand.Read(plaintext); err != nil { + t.Fatal(err) + } + + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(plaintext); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + ciphertext := buf.Bytes() + + // Truncate to 3 chunks (remove the actual final chunk) + // The third chunk was NOT encrypted with the last chunk flag, + // so decryption should fail when we try to use it as the final chunk. + encChunkSize := cs + 16 // ChunkSize + Overhead + truncatedSize := int64(3 * encChunkSize) + truncated := ciphertext[:truncatedSize] + + _, err = stream.NewDecryptReaderAt(key, bytes.NewReader(truncated), truncatedSize) + if err == nil { + t.Error("truncated at chunk boundary: expected error, got nil") + } +} + +func TestDecryptReaderAtCorrupted(t *testing.T) { + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + plaintext := make([]byte, 2*cs+500) + if _, err := rand.Read(plaintext); err != nil { + t.Fatal(err) + } + + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(plaintext); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + ciphertext := bytes.Clone(buf.Bytes()) + + // Corrupt final chunk - should fail in constructor + corruptedFinal := bytes.Clone(ciphertext) + corruptedFinal[len(corruptedFinal)-10] ^= 0xFF + _, err = stream.NewDecryptReaderAt(key, bytes.NewReader(corruptedFinal), int64(len(corruptedFinal))) + if err == nil { + t.Error("corrupted final: expected error, got nil") + } + + // Corrupt first chunk - should fail on read + corruptedFirst := bytes.Clone(ciphertext) + corruptedFirst[10] ^= 0xFF + ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(corruptedFirst), int64(len(corruptedFirst))) + if err != nil { + t.Fatalf("corrupted first constructor: unexpected error: %v", err) + } + p := make([]byte, 100) + _, err = ra.ReadAt(p, 0) + if err == nil { + t.Error("corrupted first read: expected error, got nil") } } diff --git a/testkit_test.go b/testkit_test.go index 292350e..b992db0 100644 --- a/testkit_test.go +++ b/testkit_test.go @@ -140,10 +140,16 @@ func parseVector(t *testing.T, test []byte) *vector { } func TestVectors(t *testing.T) { - forEachVector(t, testVector) + forEachVector(t, func(t *testing.T, v *vector) { + var plaintext []byte + t.Run("Decrypt", func(t *testing.T) { plaintext = testDecrypt(t, v) }) + t.Run("DecryptReaderAt", func(t *testing.T) { testDecryptReaderAt(t, v, plaintext) }) + t.Run("Inspect", func(t *testing.T) { testInspect(t, v, plaintext) }) + t.Run("RoundTrip", func(t *testing.T) { testVectorRoundTrip(t, v) }) + }) } -func testVector(t *testing.T, v *vector) { +func testDecrypt(t *testing.T, v *vector) []byte { var in io.Reader = bytes.NewReader(v.file) if v.armored { in = armor.NewReader(in) @@ -152,25 +158,25 @@ func testVector(t *testing.T, v *vector) { if err != nil && strings.HasSuffix(err.Error(), "bad header MAC") { if v.expect == "HMAC failure" { t.Log(err) - return + return nil } t.Fatalf("expected %s, got HMAC error", v.expect) } else if e := new(armor.Error); errors.As(err, &e) { if v.expect == "armor failure" { t.Log(err) - return + return nil } t.Fatalf("expected %s, got: %v", v.expect, err) } else if _, ok := err.(*age.NoIdentityMatchError); ok { if v.expect == "no match" { t.Log(err) - return + return nil } t.Fatalf("expected %s, got: %v", v.expect, err) } else if err != nil { if v.expect == "header failure" { t.Log(err) - return + return nil } t.Fatalf("expected %s, got: %v", v.expect, err) } else if v.expect != "success" && v.expect != "payload failure" && @@ -188,15 +194,77 @@ func testVector(t *testing.T, v *vector) { } } if v.payloadHash != nil && sha256.Sum256(out) != *v.payloadHash { - t.Error("partial payload hash mismatch") + t.Errorf("partial payload hash mismatch, read %d bytes", len(out)) } - return + return out } else if v.expect != "success" { t.Fatalf("expected %s, got success", v.expect) } if sha256.Sum256(out) != *v.payloadHash { t.Error("payload hash mismatch") } + return out +} + +func testDecryptReaderAt(t *testing.T, v *vector, plaintext []byte) { + if v.armored { + t.Skip("armor.NewReader does not implement ReaderAt") + } + rAt, s, err := age.DecryptReaderAt(bytes.NewReader(v.file), int64(len(v.file)), v.identities...) + switch v.expect { + case "success": + if err != nil { + t.Fatalf("expected success, got: %v", err) + } + if int64(len(plaintext)) != s { + t.Errorf("unexpected size: got %d, want %d", s, len(plaintext)) + } + case "payload failure": + // DecryptReaderAt detects some (but not all) payload failures upfront, + // either from the size of the payload, or by decrypting the last chunk + // to authenticate its size. + if err != nil { + t.Log(err) + return + } + default: + if err != nil { + t.Log(err) + return + } + t.Fatalf("expected %s, got success", v.expect) + } + out, err := io.ReadAll(io.NewSectionReader(rAt, 0, s)) + if v.expect == "success" { + if err != nil { + t.Fatalf("expected success, got: %v", err) + } + } else { + if err == nil { + t.Fatalf("expected %s, got success", v.expect) + } + t.Log(err) + // We can't check the partial payload hash, because the ReaderAt will + // notice errors that a linearly scanning Reader could not. For example, + // if there are two final chunks, the linear Reader will decrypt the + // first one and then error out on the second, while the ReaderAt will + // decrypt the second one to check the size, and then know that the + // first chunk could not be the last one. Instead, check that the + // prefix, if any, matches. + if !bytes.HasPrefix(plaintext, out) { + t.Errorf("partial payload prefix mismatch, read %d bytes", len(out)) + } + return + } + if sha256.Sum256(out) != *v.payloadHash { + t.Error("payload hash mismatch") + } +} + +func testInspect(t *testing.T, v *vector, plaintext []byte) { + if v.expect != "success" { + t.Skip("invalid file, can't inspect") + } for _, fileSize := range []int64{int64(len(v.file)), -1} { metadata, err := inspect.Inspect(bytes.NewReader(v.file), fileSize) if err != nil { @@ -211,8 +279,8 @@ func testVector(t *testing.T, v *vector) { if metadata.Sizes.Armor+metadata.Sizes.Header+metadata.Sizes.Overhead+metadata.Sizes.MinPayload != int64(len(v.file)) { t.Errorf("size breakdown does not add up to file size") } - if metadata.Sizes.MinPayload != int64(len(out)) { - t.Errorf("unexpected payload size: got %d, want %d", metadata.Sizes.MinPayload, len(out)) + if metadata.Sizes.MinPayload != int64(len(plaintext)) { + t.Errorf("unexpected payload size: got %d, want %d", metadata.Sizes.MinPayload, len(plaintext)) } if metadata.Sizes.MaxPayload != metadata.Sizes.MinPayload { t.Errorf("unexpected max payload size: got %d, want %d", metadata.Sizes.MaxPayload, metadata.Sizes.MinPayload) @@ -223,16 +291,12 @@ func testVector(t *testing.T, v *vector) { } } -// TestVectorsRoundTrip checks that any (valid) armor, header, and/or STREAM +// testVectorsRoundTrip checks that any (valid) armor, header, and/or STREAM // payload in the test vectors re-encodes identically. -func TestVectorsRoundTrip(t *testing.T) { - forEachVector(t, testVectorRoundTrip) -} - func testVectorRoundTrip(t *testing.T, v *vector) { if v.armored { if v.expect == "armor failure" { - t.SkipNow() + t.Skip("invalid armor, nothing to round-trip") } t.Run("armor", func(t *testing.T) { payload, err := io.ReadAll(armor.NewReader(bytes.NewReader(v.file))) @@ -261,7 +325,7 @@ func testVectorRoundTrip(t *testing.T, v *vector) { } if v.expect == "header failure" { - t.SkipNow() + t.Skip("invalid header, nothing to round-trip") } hdr, p, err := format.Parse(bytes.NewReader(v.file)) if err != nil { @@ -283,46 +347,62 @@ func testVectorRoundTrip(t *testing.T, v *vector) { } }) - if v.expect == "success" { - t.Run("STREAM", func(t *testing.T) { - nonce, payload := payload[:16], payload[16:] - key := streamKey(v.fileKey[:], nonce) - r, err := stream.NewDecryptReader(key, bytes.NewReader(payload)) - if err != nil { - t.Fatal(err) - } - plaintext, err := io.ReadAll(r) - if err != nil { - t.Fatal(err) - } - buf := &bytes.Buffer{} - w, err := stream.NewEncryptWriter(key, buf) - if err != nil { - t.Fatal(err) - } - if _, err := w.Write(plaintext); err != nil { - t.Fatal(err) - } - if err := w.Close(); err != nil { - t.Fatal(err) - } - if !bytes.Equal(buf.Bytes(), payload) { - t.Error("got a different STREAM ciphertext") - } - buf.Reset() - er, err := stream.NewEncryptReader(key, bytes.NewReader(plaintext)) - if err != nil { - t.Fatal(err) - } - ciphertext, err := io.ReadAll(er) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(ciphertext, payload) { - t.Error("got a different STREAM ciphertext from EncryptReader") - } - }) + if v.expect != "success" { + return } + + t.Run("STREAM", func(t *testing.T) { + nonce, payload := payload[:16], payload[16:] + key := streamKey(v.fileKey[:], nonce) + + r, err := stream.NewDecryptReader(key, bytes.NewReader(payload)) + if err != nil { + t.Fatal(err) + } + plaintext, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + + rAt, err := stream.NewDecryptReaderAt(key, bytes.NewReader(payload), int64(len(payload))) + if err != nil { + t.Fatal(err) + } + plaintextAt, err := io.ReadAll(io.NewSectionReader(rAt, 0, int64(len(plaintext)))) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintextAt, plaintext) { + t.Errorf("got a different plaintext from DecryptReaderAt") + } + + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(plaintext); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf.Bytes(), payload) { + t.Error("got a different STREAM ciphertext") + } + + er, err := stream.NewEncryptReader(key, bytes.NewReader(plaintext)) + if err != nil { + t.Fatal(err) + } + ciphertext, err := io.ReadAll(er) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(ciphertext, payload) { + t.Error("got a different STREAM ciphertext from EncryptReader") + } + }) } func streamKey(fileKey, nonce []byte) []byte {