mirror of
https://github.com/FiloSottile/age.git
synced 2026-01-03 10:55:14 +00:00
internal/stream: fix DecryptReaderAt concurrency
This commit is contained in:
@@ -398,6 +398,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var cacheUpdate *cachedChunk
|
||||
chunk := make([]byte, encChunkSize)
|
||||
for len(p) > 0 && off < r.size {
|
||||
chunkIndex := off / ChunkSize
|
||||
@@ -409,6 +410,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
|
||||
var plaintext []byte
|
||||
if cached != nil && cached.off == chunkOff {
|
||||
plaintext = cached.data
|
||||
cacheUpdate = nil
|
||||
} else {
|
||||
nn, err := r.src.ReadAt(chunk[:chunkSize], chunkOff)
|
||||
if err == io.EOF {
|
||||
@@ -429,7 +431,7 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err)
|
||||
}
|
||||
r.cache.Store(&cachedChunk{off: chunkOff, data: plaintext})
|
||||
cacheUpdate = &cachedChunk{off: chunkOff, data: plaintext}
|
||||
}
|
||||
|
||||
plainChunkOff := int(off - chunkIndex*ChunkSize)
|
||||
@@ -439,6 +441,9 @@ func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
|
||||
off += int64(copySize)
|
||||
n += copySize
|
||||
}
|
||||
if cacheUpdate != nil {
|
||||
r.cache.Store(cacheUpdate)
|
||||
}
|
||||
if off == r.size {
|
||||
return n, io.EOF
|
||||
}
|
||||
|
||||
@@ -743,6 +743,168 @@ func TestDecryptReaderAtTruncatedChunk(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptReaderAtConcurrent(t *testing.T) {
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create plaintext spanning 3 chunks: 2 full + partial
|
||||
plaintextSize := 2*cs + 500
|
||||
plaintext := make([]byte, plaintextSize)
|
||||
if _, err := rand.Read(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext := buf.Bytes()
|
||||
|
||||
ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("same chunk", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
const goroutines = 10
|
||||
const iterations = 100
|
||||
errc := make(chan error, goroutines)
|
||||
|
||||
for g := range goroutines {
|
||||
go func(id int) {
|
||||
for i := range iterations {
|
||||
off := int64((id*iterations + i) % 500)
|
||||
p := make([]byte, 100)
|
||||
n, err := ra.ReadAt(p, off)
|
||||
if err != nil {
|
||||
errc <- fmt.Errorf("goroutine %d iter %d: %v", id, i, err)
|
||||
return
|
||||
}
|
||||
if n != 100 {
|
||||
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want 100", id, i, n)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(p, plaintext[off:off+100]) {
|
||||
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
|
||||
return
|
||||
}
|
||||
}
|
||||
errc <- nil
|
||||
}(g)
|
||||
}
|
||||
|
||||
for range goroutines {
|
||||
if err := <-errc; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different chunks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
const goroutines = 10
|
||||
const iterations = 100
|
||||
errc := make(chan error, goroutines)
|
||||
|
||||
for g := range goroutines {
|
||||
go func(id int) {
|
||||
for i := range iterations {
|
||||
// Each goroutine reads from a different chunk based on id
|
||||
chunkIdx := id % 3
|
||||
off := int64(chunkIdx*cs + (i % 400))
|
||||
size := 100
|
||||
if off+int64(size) > int64(plaintextSize) {
|
||||
size = plaintextSize - int(off)
|
||||
}
|
||||
p := make([]byte, size)
|
||||
n, err := ra.ReadAt(p, off)
|
||||
if n == size && err == io.EOF {
|
||||
err = nil // EOF at end is acceptable
|
||||
}
|
||||
if err != nil {
|
||||
errc <- fmt.Errorf("goroutine %d iter %d: off=%d: %v", id, i, off, err)
|
||||
return
|
||||
}
|
||||
if n != size {
|
||||
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) {
|
||||
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
|
||||
return
|
||||
}
|
||||
}
|
||||
errc <- nil
|
||||
}(g)
|
||||
}
|
||||
|
||||
for range goroutines {
|
||||
if err := <-errc; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("across chunks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
const goroutines = 10
|
||||
const iterations = 100
|
||||
errc := make(chan error, goroutines)
|
||||
|
||||
for g := range goroutines {
|
||||
go func(id int) {
|
||||
for i := range iterations {
|
||||
// Read across chunk boundaries
|
||||
boundary := (id%2 + 1) * cs // either cs or 2*cs
|
||||
off := int64(boundary - 50 + (i % 30))
|
||||
size := 100
|
||||
if off+int64(size) > int64(plaintextSize) {
|
||||
size = plaintextSize - int(off)
|
||||
}
|
||||
if size <= 0 {
|
||||
continue
|
||||
}
|
||||
p := make([]byte, size)
|
||||
n, err := ra.ReadAt(p, off)
|
||||
if n == size && err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
errc <- fmt.Errorf("goroutine %d iter %d: off=%d size=%d: %v", id, i, off, size, err)
|
||||
return
|
||||
}
|
||||
if n != size {
|
||||
errc <- fmt.Errorf("goroutine %d iter %d: n=%d, want %d", id, i, n, size)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(p[:n], plaintext[off:off+int64(n)]) {
|
||||
errc <- fmt.Errorf("goroutine %d iter %d: data mismatch", id, i)
|
||||
return
|
||||
}
|
||||
}
|
||||
errc <- nil
|
||||
}(g)
|
||||
}
|
||||
|
||||
for range goroutines {
|
||||
if err := <-errc; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDecryptReaderAtCorrupted(t *testing.T) {
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user