// age/internal/stream/stream_test.go
// Copyright 2019 The age Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package stream_test

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"io"
	"testing"
	"testing/iotest"

	"filippo.io/age/internal/stream"
	"golang.org/x/crypto/chacha20poly1305"
)

const cs = stream.ChunkSize

func TestRoundTrip(t *testing.T) {
for _, length := range []int{0, 1000, cs - 1, cs, cs + 1, cs + 100, 2 * cs, 2*cs + 500} {
for _, stepSize := range []int{512, 600, 1000, cs - 1, cs, cs + 1} {
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), func(t *testing.T) {
testRoundTrip(t, stepSize, length)
})
}
}
length, stepSize := 2*cs+500, 1
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), func(t *testing.T) {
testRoundTrip(t, stepSize, length)
})
}

func testRoundTrip(t *testing.T, stepSize, length int) {
src := make([]byte, length)
if _, err := rand.Read(src); err != nil {
t.Fatal(err)
}
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
var ciphertext []byte
t.Run("EncryptWriter", func(t *testing.T) {
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
var n int
for n < length {
b := min(length-n, stepSize)
nn, err := w.Write(src[n : n+b])
if err != nil {
t.Fatal(err)
}
if nn != b {
t.Errorf("Write returned %d, expected %d", nn, b)
}
n += nn
nn, err = w.Write(src[n:n])
if err != nil {
t.Fatal(err)
}
if nn != 0 {
t.Errorf("Write returned %d, expected 0", nn)
}
}
if err := w.Close(); err != nil {
t.Error("Close returned an error:", err)
}
ciphertext = buf.Bytes()
})
t.Run("DecryptReader", func(t *testing.T) {
r, err := stream.NewDecryptReader(key, bytes.NewReader(ciphertext))
if err != nil {
t.Fatal(err)
}
var n int
readBuf := make([]byte, stepSize)
for n < length {
nn, err := r.Read(readBuf)
if err != nil {
t.Fatalf("Read error at index %d: %v", n, err)
}
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
}
n += nn
}
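		// iotest.TestReader re-reads the whole stream with a variety of read
		// sizes (and exercises Seek/ReadAt if implemented), so it gets a fresh
		// DecryptReader below rather than the one consumed above.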
t.Run("TestReader", func(t *testing.T) {
if length > 1000 && testing.Short() {
t.Skip("skipping slow iotest.TestReader on long input")
}
			r, err := stream.NewDecryptReader(key, bytes.NewReader(ciphertext))
			if err != nil {
				t.Fatal(err)
			}
if err := iotest.TestReader(r, src); err != nil {
t.Error("iotest.TestReader error on DecryptReader:", err)
}
})
})
t.Run("DecryptReaderAt", func(t *testing.T) {
rAt, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)))
if err != nil {
t.Fatal(err)
}
		rr := io.NewSectionReader(rAt, 0, int64(length)) // the ReaderAt exposes plaintext offsets, so the section is length bytes
var n int
readBuf := make([]byte, stepSize)
for n < length {
nn, err := rr.Read(readBuf)
if n+nn == length && err == io.EOF {
err = nil
}
if err != nil {
t.Fatalf("ReadAt error at index %d: %v", n, err)
}
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
}
n += nn
}
t.Run("TestReader", func(t *testing.T) {
if length > 1000 && testing.Short() {
t.Skip("skipping slow iotest.TestReader on long input")
}
rr := io.NewSectionReader(rAt, 0, int64(len(src)))
if err := iotest.TestReader(rr, src); err != nil {
t.Error("iotest.TestReader error on DecryptReaderAt:", err)
}
})
})
t.Run("EncryptReader", func(t *testing.T) {
er, err := stream.NewEncryptReader(key, bytes.NewReader(src))
if err != nil {
t.Fatal(err)
}
var n int
readBuf := make([]byte, stepSize)
for {
nn, err := er.Read(readBuf)
if nn == 0 && err == io.EOF {
break
} else if err != nil {
t.Fatalf("EncryptReader Read error at index %d: %v", n, err)
}
if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) {
t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn)
}
n += nn
}
if n != len(ciphertext) {
t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext))
}
t.Run("TestReader", func(t *testing.T) {
if length > 1000 && testing.Short() {
t.Skip("skipping slow iotest.TestReader on long input")
}
			er, err := stream.NewEncryptReader(key, bytes.NewReader(src))
			if err != nil {
				t.Fatal(err)
			}
if err := iotest.TestReader(er, ciphertext); err != nil {
t.Error("iotest.TestReader error on EncryptReader:", err)
}
})
})
}

// trackingReaderAt wraps an io.ReaderAt and tracks whether ReadAt was called.
type trackingReaderAt struct {
r io.ReaderAt
called bool
}

func (t *trackingReaderAt) ReadAt(p []byte, off int64) (int, error) {
t.called = true
return t.r.ReadAt(p, off)
}

func (t *trackingReaderAt) reset() {
t.called = false
}

func TestDecryptReaderAt(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
// Create plaintext spanning exactly 3 chunks: 2 full chunks + partial third
// Chunk 0: [0, cs)
// Chunk 1: [cs, 2*cs)
// Chunk 2: [2*cs, 2*cs+500)
plaintextSize := 2*cs + 500
plaintext := make([]byte, plaintextSize)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
// Encrypt
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
ciphertext := buf.Bytes()
// Create tracking ReaderAt
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
// Create DecryptReaderAt (this reads and caches the final chunk)
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
if err != nil {
t.Fatal(err)
}
tracker.reset()
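	// The wantSrcRead expectations below encode an assumed caching model: the
	// ReaderAt keeps a single decrypted chunk in memory (the most recently
	// decrypted one), and the constructor primes that cache with the final
	// chunk. If the caching strategy changes, it is these expectations, not
	// the data checks, that would need revisiting.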
// Helper to check reads
checkRead := func(name string, off int64, size int, wantN int, wantEOF bool, wantSrcRead bool) {
t.Helper()
tracker.reset()
p := make([]byte, size)
n, err := ra.ReadAt(p, off)
if wantEOF {
if err != io.EOF {
t.Errorf("%s: got err=%v, want EOF", name, err)
}
} else {
if err != nil {
t.Errorf("%s: got err=%v, want nil", name, err)
}
}
if n != wantN {
t.Errorf("%s: got n=%d, want %d", name, n, wantN)
}
if tracker.called != wantSrcRead {
t.Errorf("%s: src.ReadAt called=%v, want %v", name, tracker.called, wantSrcRead)
}
// Verify data correctness
if n > 0 && off >= 0 && off < int64(plaintextSize) {
end := int(off) + n
if end > plaintextSize {
end = plaintextSize
}
if !bytes.Equal(p[:n], plaintext[off:end]) {
t.Errorf("%s: data mismatch", name)
}
}
}
// Test 1: Read from final chunk (cached by constructor)
checkRead("final chunk (cached)", int64(2*cs+100), 100, 100, false, false)
// Test 2: Read spanning second and third chunk
checkRead("span chunks 1-2", int64(cs+cs-50), 100, 100, false, true)
// Test 3: Read from final chunk again (cached from test 2)
// When reading across chunks 1-2 in test 2, the loop processes chunk 1 then chunk 2,
// so chunk 2 ends up in the cache.
checkRead("final chunk after span", int64(2*cs+200), 100, 100, false, false)
// Test 4: Read from final chunk again (now cached)
checkRead("final chunk (cached again)", int64(2*cs+50), 50, 50, false, false)
// Test 5: Read from first chunk (not cached)
checkRead("first chunk", 0, 100, 100, false, true)
// Test 6: Read from first chunk again (now cached)
checkRead("first chunk (cached)", 50, 100, 100, false, false)
// Test 7: Read spanning all chunks
tracker.reset()
p := make([]byte, plaintextSize)
n, err := ra.ReadAt(p, 0)
if err != io.EOF {
t.Errorf("span all: got err=%v, want EOF", err)
}
if n != plaintextSize {
t.Errorf("span all: got n=%d, want %d", n, plaintextSize)
}
if !bytes.Equal(p, plaintext) {
t.Errorf("span all: data mismatch")
}
// Test 8: Read beyond the end (offset > size)
tracker.reset()
p = make([]byte, 100)
n, err = ra.ReadAt(p, int64(plaintextSize+100))
if err == nil {
t.Error("beyond end: expected error, got nil")
}
if n != 0 {
t.Errorf("beyond end: got n=%d, want 0", n)
}
// Test 9: Read with off = size (should return 0, EOF)
tracker.reset()
p = make([]byte, 100)
n, err = ra.ReadAt(p, int64(plaintextSize))
if err != io.EOF {
t.Errorf("off=size: got err=%v, want EOF", err)
}
if n != 0 {
t.Errorf("off=size: got n=%d, want 0", n)
}
// Test 10: Read spanning last chunk and beyond
tracker.reset()
p = make([]byte, 1000) // request more than available
n, err = ra.ReadAt(p, int64(2*cs+400))
if err != io.EOF {
t.Errorf("span last+beyond: got err=%v, want EOF", err)
}
wantN := 500 - 400 // only 100 bytes available from offset 2*cs+400
if n != wantN {
t.Errorf("span last+beyond: got n=%d, want %d", n, wantN)
}
if !bytes.Equal(p[:n], plaintext[2*cs+400:]) {
t.Error("span last+beyond: data mismatch")
}
// Test 11: Read spanning second+last chunk and beyond
tracker.reset()
p = make([]byte, cs+1000) // request more than available
n, err = ra.ReadAt(p, int64(cs+100))
if err != io.EOF {
t.Errorf("span 1-2+beyond: got err=%v, want EOF", err)
}
wantN = plaintextSize - (cs + 100)
if n != wantN {
t.Errorf("span 1-2+beyond: got n=%d, want %d", n, wantN)
}
if !bytes.Equal(p[:n], plaintext[cs+100:]) {
t.Error("span 1-2+beyond: data mismatch")
}
// Test 12: Negative offset
tracker.reset()
p = make([]byte, 100)
n, err = ra.ReadAt(p, -1)
if err == nil {
t.Error("negative offset: expected error, got nil")
}
if n != 0 {
t.Errorf("negative offset: got n=%d, want 0", n)
}
// Test 13: Zero-length read in the middle
tracker.reset()
p = make([]byte, 0)
n, err = ra.ReadAt(p, 100)
if err != nil {
t.Errorf("zero-length middle: got err=%v, want nil", err)
}
if n != 0 {
t.Errorf("zero-length middle: got n=%d, want 0", n)
}
// Test 14: Zero-length read at end
tracker.reset()
p = make([]byte, 0)
n, err = ra.ReadAt(p, int64(plaintextSize))
if err != nil {
t.Errorf("zero-length end: got err=%v, want nil", err)
}
if n != 0 {
t.Errorf("zero-length end: got n=%d, want 0", n)
}
// Test 15: Read exactly one chunk at chunk boundary
checkRead("exact chunk at boundary", int64(cs), cs, cs, false, true)
// Test 16: Read one byte at each chunk boundary
checkRead("one byte at start", 0, 1, 1, false, true)
checkRead("one byte at cs-1", int64(cs-1), 1, 1, false, false) // cached from test 15
checkRead("one byte at cs", int64(cs), 1, 1, false, true)
checkRead("one byte at 2*cs-1", int64(2*cs-1), 1, 1, false, false) // same chunk
checkRead("one byte at 2*cs", int64(2*cs), 1, 1, false, true)
checkRead("last byte", int64(plaintextSize-1), 1, 1, true, false) // same chunk, EOF because we reach end
// Test 17: Read crossing exactly one chunk boundary
checkRead("cross boundary 0-1", int64(cs-50), 100, 100, false, true)
checkRead("cross boundary 1-2", int64(2*cs-50), 100, 100, false, true)
}
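
// ExampleNewDecryptReaderAt is a minimal sketch of the random-access flow that
// TestDecryptReaderAt exercises: encrypt into a buffer, then decrypt 100 bytes
// straddling a chunk boundary without reading the rest of the stream. Error
// handling is collapsed into panics purely to keep the sketch short.
func ExampleNewDecryptReaderAt() {
	key := make([]byte, chacha20poly1305.KeySize)
	if _, err := rand.Read(key); err != nil {
		panic(err)
	}
	plaintext := make([]byte, 2*cs+500)
	if _, err := rand.Read(plaintext); err != nil {
		panic(err)
	}

	buf := &bytes.Buffer{}
	w, err := stream.NewEncryptWriter(key, buf)
	if err != nil {
		panic(err)
	}
	if _, err := w.Write(plaintext); err != nil {
		panic(err)
	}
	if err := w.Close(); err != nil {
		panic(err)
	}
	ciphertext := buf.Bytes()

	ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)))
	if err != nil {
		panic(err)
	}
	// Read 100 plaintext bytes crossing the first chunk boundary.
	p := make([]byte, 100)
	if _, err := ra.ReadAt(p, int64(cs-50)); err != nil {
		panic(err)
	}
	fmt.Println(bytes.Equal(p, plaintext[cs-50:cs+50]))
	// Output: true
}
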
func TestDecryptReaderAtEmpty(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
// Create empty encrypted file
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
ciphertext := buf.Bytes()
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
if err != nil {
t.Fatal(err)
}
tracker.reset()
// Test 1: Read from empty file at offset 0
p := make([]byte, 100)
n, err := ra.ReadAt(p, 0)
if err != io.EOF {
t.Errorf("empty read: got err=%v, want EOF", err)
}
if n != 0 {
t.Errorf("empty read: got n=%d, want 0", n)
}
// Test 2: Zero-length read from empty file
p = make([]byte, 0)
n, err = ra.ReadAt(p, 0)
if err != nil {
t.Errorf("empty zero-length: got err=%v, want nil", err)
}
if n != 0 {
t.Errorf("empty zero-length: got n=%d, want 0", n)
}
// Test 3: Read beyond empty file
p = make([]byte, 100)
n, err = ra.ReadAt(p, 1)
if err == nil {
t.Error("empty beyond: expected error, got nil")
}
if n != 0 {
t.Errorf("empty beyond: got n=%d, want 0", n)
}
}

func TestDecryptReaderAtSingleChunk(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
// Single chunk, not full
plaintext := make([]byte, 1000)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
ciphertext := buf.Bytes()
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
if err != nil {
t.Fatal(err)
}
tracker.reset()
// All reads should use cache (final chunk = only chunk)
p := make([]byte, 100)
n, err := ra.ReadAt(p, 0)
if err != nil {
t.Errorf("single chunk start: got err=%v, want nil", err)
}
if n != 100 {
t.Errorf("single chunk start: got n=%d, want 100", n)
}
if tracker.called {
t.Error("single chunk start: unexpected src.ReadAt call")
}
if !bytes.Equal(p[:n], plaintext[:100]) {
t.Error("single chunk start: data mismatch")
}
// Read at end
n, err = ra.ReadAt(p, 900)
if err != io.EOF {
t.Errorf("single chunk end: got err=%v, want EOF", err)
}
if n != 100 {
t.Errorf("single chunk end: got n=%d, want 100", n)
}
if tracker.called {
t.Error("single chunk end: unexpected src.ReadAt call")
}
if !bytes.Equal(p[:n], plaintext[900:]) {
t.Error("single chunk end: data mismatch")
}
}

func TestDecryptReaderAtFullChunks(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
// Exactly 2 full chunks
plaintext := make([]byte, 2*cs)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
ciphertext := buf.Bytes()
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
if err != nil {
t.Fatal(err)
}
tracker.reset()
// Read last byte of second chunk (cached)
p := make([]byte, 1)
n, err := ra.ReadAt(p, int64(2*cs-1))
if err != io.EOF {
t.Errorf("last byte: got err=%v, want EOF", err)
}
if n != 1 {
t.Errorf("last byte: got n=%d, want 1", n)
}
if tracker.called {
t.Error("last byte: unexpected src.ReadAt call (should be cached)")
}
if p[0] != plaintext[2*cs-1] {
t.Error("last byte: data mismatch")
}
// Read at exactly the boundary between chunks
p = make([]byte, 100)
n, err = ra.ReadAt(p, int64(cs-50))
if err != nil {
t.Errorf("boundary: got err=%v, want nil", err)
}
if n != 100 {
t.Errorf("boundary: got n=%d, want 100", n)
}
if !bytes.Equal(p, plaintext[cs-50:cs+50]) {
t.Error("boundary: data mismatch")
}
}

func TestDecryptReaderAtWrongKey(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
plaintext := make([]byte, 1000)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
ciphertext := buf.Bytes()
// Try to decrypt with wrong key
wrongKey := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(wrongKey); err != nil {
t.Fatal(err)
}
_, err = stream.NewDecryptReaderAt(wrongKey, bytes.NewReader(ciphertext), int64(len(ciphertext)))
if err == nil {
t.Error("wrong key: expected error, got nil")
}
}

func TestDecryptReaderAtInvalidSize(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
plaintext := make([]byte, 1000)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
ciphertext := buf.Bytes()
// Wrong size (too small)
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)-1))
if err == nil {
t.Error("wrong size (small): expected error, got nil")
}
// Wrong size (too large)
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)+1))
if err == nil {
t.Error("wrong size (large): expected error, got nil")
}
	// Size that would imply an empty final chunk, which is invalid unless the
	// whole plaintext is empty: one full encrypted chunk plus a bare
	// Overhead-sized final chunk (see expectedCiphertextSize below).
	invalidSize := int64(cs + chacha20poly1305.Overhead + chacha20poly1305.Overhead)
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(make([]byte, invalidSize)), invalidSize)
if err == nil {
t.Error("invalid size (empty final chunk): expected error, got nil")
}
}
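
// expectedCiphertextSize is an illustrative sketch, not part of the stream
// package API and not called by the tests: it spells out the size arithmetic
// that TestDecryptReaderAtInvalidSize and TestDecryptReaderAtTruncatedChunk
// compute by hand. It assumes each chunk (including the final, possibly
// partial one) adds chacha20poly1305.Overhead bytes of tag, and that an empty
// plaintext is still encoded as a single empty final chunk, which is what the
// non-empty ciphertext in TestDecryptReaderAtEmpty suggests.
func expectedCiphertextSize(plaintextSize int) int {
	fullChunks := plaintextSize / cs
	last := plaintextSize % cs
	if plaintextSize > 0 && last == 0 {
		// A plaintext that is an exact multiple of ChunkSize ends with a
		// full final chunk, not an empty one.
		fullChunks--
		last = cs
	}
	return fullChunks*(cs+chacha20poly1305.Overhead) + last + chacha20poly1305.Overhead
}
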
func TestDecryptReaderAtTruncated(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
plaintext := make([]byte, 2*cs+500)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
ciphertext := buf.Bytes()
// Truncate ciphertext but lie about size
truncated := ciphertext[:len(ciphertext)-100]
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(truncated), int64(len(ciphertext)))
if err == nil {
t.Error("truncated: expected error, got nil")
}
}

func TestDecryptReaderAtTruncatedChunk(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
// Create 4 chunks: 3 full + 1 partial
plaintext := make([]byte, 3*cs+500)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
ciphertext := buf.Bytes()
// Truncate to 3 chunks (remove the actual final chunk)
// The third chunk was NOT encrypted with the last chunk flag,
// so decryption should fail when we try to use it as the final chunk.
	encChunkSize := cs + chacha20poly1305.Overhead // encrypted chunk = ChunkSize plaintext + Poly1305 tag
truncatedSize := int64(3 * encChunkSize)
truncated := ciphertext[:truncatedSize]
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(truncated), truncatedSize)
if err == nil {
t.Error("truncated at chunk boundary: expected error, got nil")
}
}

func TestDecryptReaderAtConcurrent(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
// Create plaintext spanning 3 chunks: 2 full + partial
plaintextSize := 2*cs + 500
plaintext := make([]byte, plaintextSize)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
// Encrypt
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
ciphertext := buf.Bytes()
ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)))
if err != nil {
t.Fatal(err)
}
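	// io.ReaderAt promises that parallel ReadAt calls on the same source are
	// safe, so the subtests below hammer one DecryptReaderAt from several
	// goroutines with same-chunk, per-chunk, and boundary-crossing offsets.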
t.Run("same chunk", func(t *testing.T) {
t.Parallel()
const goroutines = 10
const iterations = 100
errc := make(chan error, goroutines)
for g := range goroutines {
go func(id int) {
for i := range iterations {
off := int64((id*iterations + i) % 500)
p := make([]byte, 100)
n, err := ra.ReadAt(p, off)
if err != nil {
errc <- fmt.Errorf("goroutine %d iter %d: %v", id, i, err)
return
}
if n != 100 {
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want 100", id, i, n)
return
}
if !bytes.Equal(p, plaintext[off:off+100]) {
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
return
}
}
errc <- nil
}(g)
}
for range goroutines {
if err := <-errc; err != nil {
t.Error(err)
}
}
})
t.Run("different chunks", func(t *testing.T) {
t.Parallel()
const goroutines = 10
const iterations = 100
errc := make(chan error, goroutines)
for g := range goroutines {
go func(id int) {
for i := range iterations {
// Each goroutine reads from a different chunk based on id
chunkIdx := id % 3
off := int64(chunkIdx*cs + (i % 400))
size := 100
if off+int64(size) > int64(plaintextSize) {
size = plaintextSize - int(off)
}
p := make([]byte, size)
n, err := ra.ReadAt(p, off)
if n == size && err == io.EOF {
err = nil // EOF at end is acceptable
}
if err != nil {
errc <- fmt.Errorf("goroutine %d iter %d: off=%d: %v", id, i, off, err)
return
}
if n != size {
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size)
return
}
if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) {
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
return
}
}
errc <- nil
}(g)
}
for range goroutines {
if err := <-errc; err != nil {
t.Error(err)
}
}
})
t.Run("across chunks", func(t *testing.T) {
t.Parallel()
const goroutines = 10
const iterations = 100
errc := make(chan error, goroutines)
for g := range goroutines {
go func(id int) {
for i := range iterations {
// Read across chunk boundaries
boundary := (id%2 + 1) * cs // either cs or 2*cs
off := int64(boundary - 50 + (i % 30))
size := 100
if off+int64(size) > int64(plaintextSize) {
size = plaintextSize - int(off)
}
if size <= 0 {
continue
}
p := make([]byte, size)
n, err := ra.ReadAt(p, off)
if n == size && err == io.EOF {
err = nil
}
if err != nil {
errc <- fmt.Errorf("goroutine %d iter %d: off=%d size=%d: %v", id, i, off, size, err)
return
}
if n != size {
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size)
return
}
if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) {
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
return
}
}
errc <- nil
}(g)
}
for range goroutines {
if err := <-errc; err != nil {
t.Error(err)
}
}
})
}

func TestDecryptReaderAtCorrupted(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
plaintext := make([]byte, 2*cs+500)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
ciphertext := bytes.Clone(buf.Bytes())
// Corrupt final chunk - should fail in constructor
corruptedFinal := bytes.Clone(ciphertext)
corruptedFinal[len(corruptedFinal)-10] ^= 0xFF
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(corruptedFinal), int64(len(corruptedFinal)))
if err == nil {
t.Error("corrupted final: expected error, got nil")
}
// Corrupt first chunk - should fail on read
corruptedFirst := bytes.Clone(ciphertext)
corruptedFirst[10] ^= 0xFF
ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(corruptedFirst), int64(len(corruptedFirst)))
if err != nil {
t.Fatalf("corrupted first constructor: unexpected error: %v", err)
}
p := make([]byte, 100)
_, err = ra.ReadAt(p, 0)
if err == nil {
t.Error("corrupted first read: expected error, got nil")
}
}