internal/stream: fix DecryptReaderAt concurrency

This commit is contained in:
Filippo Valsorda
2025-12-26 21:56:55 +01:00
parent da2191789a
commit 420273952a
2 changed files with 168 additions and 1 deletions

View File

@@ -398,6 +398,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
var cacheUpdate *cachedChunk
chunk := make([]byte, encChunkSize)
for len(p) > 0 && off < r.size {
chunkIndex := off / ChunkSize
@@ -409,6 +410,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
var plaintext []byte
if cached != nil && cached.off == chunkOff {
plaintext = cached.data
cacheUpdate = nil
} else {
nn, err := r.src.ReadAt(chunk[:chunkSize], chunkOff)
if err == io.EOF {
@@ -429,7 +431,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
if err != nil {
return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err)
}
r.cache.Store(&cachedChunk{off: chunkOff, data: plaintext})
cacheUpdate = &cachedChunk{off: chunkOff, data: plaintext}
}
plainChunkOff := int(off - chunkIndex*ChunkSize)
@@ -439,6 +441,9 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
off += int64(copySize)
n += copySize
}
if cacheUpdate != nil {
r.cache.Store(cacheUpdate)
}
if off == r.size {
return n, io.EOF
}

View File

@@ -743,6 +743,168 @@ func TestDecryptReaderAtTruncatedChunk(t *testing.T) {
}
}
func TestDecryptReaderAtConcurrent(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
// Create plaintext spanning 3 chunks: 2 full + partial
plaintextSize := 2*cs + 500
plaintext := make([]byte, plaintextSize)
if _, err := rand.Read(plaintext); err != nil {
t.Fatal(err)
}
// Encrypt
buf := &bytes.Buffer{}
w, err := stream.NewEncryptWriter(key, buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(plaintext); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
ciphertext := buf.Bytes()
ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)))
if err != nil {
t.Fatal(err)
}
t.Run("same chunk", func(t *testing.T) {
t.Parallel()
const goroutines = 10
const iterations = 100
errc := make(chan error, goroutines)
for g := range goroutines {
go func(id int) {
for i := range iterations {
off := int64((id*iterations + i) % 500)
p := make([]byte, 100)
n, err := ra.ReadAt(p, off)
if err != nil {
errc <- fmt.Errorf("goroutine %d iter %d: %v", id, i, err)
return
}
if n != 100 {
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want 100", id, i, n)
return
}
if !bytes.Equal(p, plaintext[off:off+100]) {
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
return
}
}
errc <- nil
}(g)
}
for range goroutines {
if err := <-errc; err != nil {
t.Error(err)
}
}
})
t.Run("different chunks", func(t *testing.T) {
t.Parallel()
const goroutines = 10
const iterations = 100
errc := make(chan error, goroutines)
for g := range goroutines {
go func(id int) {
for i := range iterations {
// Each goroutine reads from a different chunk based on id
chunkIdx := id % 3
off := int64(chunkIdx*cs + (i % 400))
size := 100
if off+int64(size) > int64(plaintextSize) {
size = plaintextSize - int(off)
}
p := make([]byte, size)
n, err := ra.ReadAt(p, off)
if n == size && err == io.EOF {
err = nil // EOF at end is acceptable
}
if err != nil {
errc <- fmt.Errorf("goroutine %d iter %d: off=%d: %v", id, i, off, err)
return
}
if n != size {
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size)
return
}
if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) {
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
return
}
}
errc <- nil
}(g)
}
for range goroutines {
if err := <-errc; err != nil {
t.Error(err)
}
}
})
t.Run("across chunks", func(t *testing.T) {
t.Parallel()
const goroutines = 10
const iterations = 100
errc := make(chan error, goroutines)
for g := range goroutines {
go func(id int) {
for i := range iterations {
// Read across chunk boundaries
boundary := (id%2 + 1) * cs // either cs or 2*cs
off := int64(boundary - 50 + (i % 30))
size := 100
if off+int64(size) > int64(plaintextSize) {
size = plaintextSize - int(off)
}
if size <= 0 {
continue
}
p := make([]byte, size)
n, err := ra.ReadAt(p, off)
if n == size && err == io.EOF {
err = nil
}
if err != nil {
errc <- fmt.Errorf("goroutine %d iter %d: off=%d size=%d: %v", id, i, off, size, err)
return
}
if n != size {
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size)
return
}
if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) {
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
return
}
}
errc <- nil
}(g)
}
for range goroutines {
if err := <-errc; err != nil {
t.Error(err)
}
}
})
}
func TestDecryptReaderAtCorrupted(t *testing.T) {
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {