diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 731c719..967cdc2 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -398,6 +398,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) { if len(p) == 0 { return 0, nil } + var cacheUpdate *cachedChunk chunk := make([]byte, encChunkSize) for len(p) > 0 && off < r.size { chunkIndex := off / ChunkSize @@ -409,6 +410,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) { var plaintext []byte if cached != nil && cached.off == chunkOff { plaintext = cached.data + cacheUpdate = nil } else { nn, err := r.src.ReadAt(chunk[:chunkSize], chunkOff) if err == io.EOF { @@ -429,7 +431,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) { if err != nil { return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err) } - r.cache.Store(&cachedChunk{off: chunkOff, data: plaintext}) + cacheUpdate = &cachedChunk{off: chunkOff, data: plaintext} } plainChunkOff := int(off - chunkIndex*ChunkSize) @@ -439,6 +441,9 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) { off += int64(copySize) n += copySize } + if cacheUpdate != nil { + r.cache.Store(cacheUpdate) + } if off == r.size { return n, io.EOF } diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go index 1bd999f..7d0f7b3 100644 --- a/internal/stream/stream_test.go +++ b/internal/stream/stream_test.go @@ -743,6 +743,168 @@ func TestDecryptReaderAtTruncatedChunk(t *testing.T) { } } +func TestDecryptReaderAtConcurrent(t *testing.T) { + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + // Create plaintext spanning 3 chunks: 2 full + partial + plaintextSize := 2*cs + 500 + plaintext := make([]byte, plaintextSize) + if _, err := rand.Read(plaintext); err != nil { + t.Fatal(err) + } + + // Encrypt + buf := &bytes.Buffer{} + w, err := stream.NewEncryptWriter(key, buf) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(plaintext); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + ciphertext := buf.Bytes() + + ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext))) + if err != nil { + t.Fatal(err) + } + + t.Run("same chunk", func(t *testing.T) { + t.Parallel() + const goroutines = 10 + const iterations = 100 + errc := make(chan error, goroutines) + + for g := range goroutines { + go func(id int) { + for i := range iterations { + off := int64((id*iterations + i) % 500) + p := make([]byte, 100) + n, err := ra.ReadAt(p, off) + if err != nil { + errc <- fmt.Errorf("goroutine %d iter %d: %v", id, i, err) + return + } + if n != 100 { + errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want 100", id, i, n) + return + } + if !bytes.Equal(p, plaintext[off:off+100]) { + errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i) + return + } + } + errc <- nil + }(g) + } + + for range goroutines { + if err := <-errc; err != nil { + t.Error(err) + } + } + }) + + t.Run("different chunks", func(t *testing.T) { + t.Parallel() + const goroutines = 10 + const iterations = 100 + errc := make(chan error, goroutines) + + for g := range goroutines { + go func(id int) { + for i := range iterations { + // Each goroutine reads from a different chunk based on id + chunkIdx := id % 3 + off := int64(chunkIdx*cs + (i % 400)) + size := 100 + if off+int64(size) > int64(plaintextSize) { + size = plaintextSize - int(off) + } + p := make([]byte, size) + n, err := ra.ReadAt(p, off) + if n == size && err == io.EOF { + err = nil // EOF at end is acceptable + } + if err != nil { + errc <- fmt.Errorf("goroutine %d iter %d: off=%d: %v", id, i, off, err) + return + } + if n != size { + errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size) + return + } + if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) { + errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i) + return + } + } + errc <- nil + }(g) + } + + for range goroutines { + if err := <-errc; err != nil { + t.Error(err) + } + } + }) + + t.Run("across chunks", func(t *testing.T) { + t.Parallel() + const goroutines = 10 + const iterations = 100 + errc := make(chan error, goroutines) + + for g := range goroutines { + go func(id int) { + for i := range iterations { + // Read across chunk boundaries + boundary := (id%2 + 1) * cs // either cs or 2*cs + off := int64(boundary - 50 + (i % 30)) + size := 100 + if off+int64(size) > int64(plaintextSize) { + size = plaintextSize - int(off) + } + if size <= 0 { + continue + } + p := make([]byte, size) + n, err := ra.ReadAt(p, off) + if n == size && err == io.EOF { + err = nil + } + if err != nil { + errc <- fmt.Errorf("goroutine %d iter %d: off=%d size=%d: %v", id, i, off, size, err) + return + } + if n != size { + errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size) + return + } + if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) { + errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i) + return + } + } + errc <- nil + }(g) + } + + for range goroutines { + if err := <-errc; err != nil { + t.Error(err) + } + } + }) +} + func TestDecryptReaderAtCorrupted(t *testing.T) { key := make([]byte, chacha20poly1305.KeySize) if _, err := rand.Read(key); err != nil {