mirror of
https://github.com/FiloSottile/age.git
synced 2026-01-08 21:03:05 +00:00
age: add DecryptReaderAt
This commit is contained in:
61
age.go
61
age.go
@@ -252,13 +252,12 @@ func (e *NoIdentityMatchError) Unwrap() []error {
|
||||
}
|
||||
|
||||
// Decrypt decrypts a file encrypted to one or more identities.
|
||||
//
|
||||
// It returns a Reader reading the decrypted plaintext of the age file read
|
||||
// from src. All identities will be tried until one successfully decrypts the file.
|
||||
// All identities will be tried until one successfully decrypts the file.
|
||||
// Native, non-interactive identities are tried before any other identities.
|
||||
//
|
||||
// If no identity matches the encrypted file, the returned error will be of type
|
||||
// [NoIdentityMatchError].
|
||||
// Decrypt returns a Reader reading the decrypted plaintext of the age file read
|
||||
// from src. If no identity matches the encrypted file, the returned error will
|
||||
// be of type [NoIdentityMatchError].
|
||||
func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
|
||||
hdr, payload, err := format.Parse(src)
|
||||
if err != nil {
|
||||
@@ -278,6 +277,58 @@ func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
|
||||
return stream.NewDecryptReader(streamKey(fileKey, nonce), payload)
|
||||
}
|
||||
|
||||
// DecryptReaderAt decrypts a file encrypted to one or more identities.
|
||||
// All identities will be tried until one successfully decrypts the file.
|
||||
// Native, non-interactive identities are tried before any other identities.
|
||||
//
|
||||
// DecryptReaderAt takes an underlying [io.ReaderAt] and its total encrypted
|
||||
// size, and returns a ReaderAt of the decrypted plaintext and the plaintext
|
||||
// size. These can be used for example to instantiate an [io.SectionReader],
|
||||
// which implements [io.Reader] and [io.Seeker]. Note that ReaderAt by
|
||||
// definition disregards the seek position of src.
|
||||
//
|
||||
// The ReadAt method of the returned ReaderAt can be called concurrently.
|
||||
// The ReaderAt will internally cache the most recently decrypted chunk.
|
||||
// DecryptReaderAt reads and decrypts the final chunk before returning,
|
||||
// to authenticate the plaintext size.
|
||||
//
|
||||
// If no identity matches the encrypted file, the returned error will be of
|
||||
// type [NoIdentityMatchError].
|
||||
func DecryptReaderAt(src io.ReaderAt, encryptedSize int64, identities ...Identity) (io.ReaderAt, int64, error) {
|
||||
srcReader := io.NewSectionReader(src, 0, encryptedSize)
|
||||
hdr, payload, err := format.Parse(srcReader)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to read header: %w", err)
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
if err := hdr.Marshal(buf); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to serialize header: %w", err)
|
||||
}
|
||||
|
||||
fileKey, err := decryptHdr(hdr, identities...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, streamNonceSize)
|
||||
if _, err := io.ReadFull(payload, nonce); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to read nonce: %w", err)
|
||||
}
|
||||
|
||||
payloadOffset := int64(buf.Len()) + int64(len(nonce))
|
||||
payloadSize := encryptedSize - payloadOffset
|
||||
plaintextSize, err := stream.PlaintextSize(payloadSize)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
payloadReaderAt := io.NewSectionReader(src, payloadOffset, payloadSize)
|
||||
r, err := stream.NewDecryptReaderAt(streamKey(fileKey, nonce), payloadReaderAt, payloadSize)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return r, plaintextSize, nil
|
||||
}
|
||||
|
||||
func decryptHdr(hdr *format.Header, identities ...Identity) ([]byte, error) {
|
||||
if len(identities) == 0 {
|
||||
return nil, errors.New("no identities specified")
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"filippo.io/age/armor"
|
||||
"filippo.io/age/internal/format"
|
||||
"filippo.io/age/internal/stream"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
type Metadata struct {
|
||||
@@ -88,9 +87,9 @@ func Inspect(r io.Reader, fileSize int64) (*Metadata, error) {
|
||||
}
|
||||
data.Sizes.Armor = tr.count - fileSize
|
||||
}
|
||||
data.Sizes.Overhead = streamOverhead(fileSize - data.Sizes.Header)
|
||||
if data.Sizes.Overhead > fileSize-data.Sizes.Header {
|
||||
return nil, fmt.Errorf("payload too small to be a valid age file")
|
||||
data.Sizes.Overhead, err = streamOverhead(fileSize - data.Sizes.Header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compute stream overhead: %w", err)
|
||||
}
|
||||
data.Sizes.MinPayload = fileSize - data.Sizes.Header - data.Sizes.Overhead
|
||||
data.Sizes.MaxPayload = data.Sizes.MinPayload
|
||||
@@ -114,13 +113,15 @@ func (tr *trackReader) Read(p []byte) (int, error) {
|
||||
return n, err
|
||||
}
|
||||
|
||||
func streamOverhead(payloadSize int64) int64 {
|
||||
func streamOverhead(payloadSize int64) (int64, error) {
|
||||
const streamNonceSize = 16
|
||||
const encChunkSize = stream.ChunkSize + chacha20poly1305.Overhead
|
||||
payloadSize -= streamNonceSize
|
||||
if payloadSize <= 0 {
|
||||
return streamNonceSize
|
||||
if payloadSize < streamNonceSize {
|
||||
return 0, fmt.Errorf("encrypted size too small: %d", payloadSize)
|
||||
}
|
||||
chunks := (payloadSize + encChunkSize - 1) / encChunkSize
|
||||
return streamNonceSize + chunks*chacha20poly1305.Overhead
|
||||
encryptedSize := payloadSize - streamNonceSize
|
||||
plaintextSize, err := stream.PlaintextSize(encryptedSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return payloadSize - plaintextSize, nil
|
||||
}
|
||||
|
||||
46
internal/inspect/inspect_test.go
Normal file
46
internal/inspect/inspect_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"filippo.io/age/internal/stream"
|
||||
)
|
||||
|
||||
func TestStreamOverhead(t *testing.T) {
|
||||
tests := []struct {
|
||||
payloadSize int64
|
||||
want int64
|
||||
wantErr bool
|
||||
}{
|
||||
{payloadSize: 0, wantErr: true},
|
||||
{payloadSize: 15, wantErr: true},
|
||||
{payloadSize: 16, wantErr: true},
|
||||
{payloadSize: 16 + 15, wantErr: true},
|
||||
{payloadSize: 16 + 16, want: 16 + 16}, // empty plaintext
|
||||
{payloadSize: 16 + 1 + 16, want: 16 + 16},
|
||||
{payloadSize: 16 + stream.ChunkSize + 16, want: 16 + 16},
|
||||
{payloadSize: 16 + stream.ChunkSize + 16 + 1, wantErr: true},
|
||||
{payloadSize: 16 + stream.ChunkSize + 16 + 15, wantErr: true},
|
||||
{payloadSize: 16 + stream.ChunkSize + 16 + 16, wantErr: true}, // empty final chunk
|
||||
{payloadSize: 16 + stream.ChunkSize + 16 + 1 + 16, want: 16 + 16 + 16},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
name := "payloadSize=" + fmt.Sprint(tt.payloadSize)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
got, gotErr := streamOverhead(tt.payloadSize)
|
||||
if gotErr != nil {
|
||||
if !tt.wantErr {
|
||||
t.Errorf("streamOverhead() failed: %v", gotErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
t.Fatal("streamOverhead() succeeded unexpectedly")
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("streamOverhead() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,15 +8,42 @@ package stream
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
const ChunkSize = 64 * 1024
|
||||
|
||||
func EncryptedChunkCount(encryptedSize int64) (int64, error) {
|
||||
chunks := (encryptedSize + encChunkSize - 1) / encChunkSize
|
||||
|
||||
plaintextSize := encryptedSize - chunks*chacha20poly1305.Overhead
|
||||
expChunks := (plaintextSize + ChunkSize - 1) / ChunkSize
|
||||
// Empty plaintext, the only case that allows (and requires) an empty chunk.
|
||||
if plaintextSize == 0 {
|
||||
expChunks = 1
|
||||
}
|
||||
if expChunks != chunks {
|
||||
return 0, fmt.Errorf("invalid encrypted payload size: %d", encryptedSize)
|
||||
}
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
func PlaintextSize(encryptedSize int64) (int64, error) {
|
||||
chunks, err := EncryptedChunkCount(encryptedSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
plaintextSize := encryptedSize - chunks*chacha20poly1305.Overhead
|
||||
return plaintextSize, nil
|
||||
}
|
||||
|
||||
type DecryptReader struct {
|
||||
a cipher.AEAD
|
||||
src io.Reader
|
||||
@@ -135,6 +162,12 @@ func incNonce(nonce *[chacha20poly1305.NonceSize]byte) {
|
||||
panic("stream: chunk counter wrapped around")
|
||||
}
|
||||
|
||||
func nonceForChunk(chunkIndex int64) *[chacha20poly1305.NonceSize]byte {
|
||||
var nonce [chacha20poly1305.NonceSize]byte
|
||||
binary.BigEndian.PutUint64(nonce[3:11], uint64(chunkIndex))
|
||||
return &nonce
|
||||
}
|
||||
|
||||
func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
|
||||
nonce[len(nonce)-1] = lastChunkFlag
|
||||
}
|
||||
@@ -312,3 +345,102 @@ func (r *EncryptReader) feedBuffer() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type DecryptReaderAt struct {
|
||||
a cipher.AEAD
|
||||
src io.ReaderAt
|
||||
size int64
|
||||
chunks int64
|
||||
cache atomic.Pointer[cachedChunk]
|
||||
}
|
||||
|
||||
type cachedChunk struct {
|
||||
off int64
|
||||
data []byte
|
||||
}
|
||||
|
||||
func NewDecryptReaderAt(key []byte, src io.ReaderAt, size int64) (*DecryptReaderAt, error) {
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check that size is valid by decrypting the final chunk.
|
||||
chunks, err := EncryptedChunkCount(size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
finalChunkIndex := chunks - 1
|
||||
finalChunkOff := finalChunkIndex * encChunkSize
|
||||
finalChunkSize := size - finalChunkOff
|
||||
finalChunk := make([]byte, finalChunkSize)
|
||||
if _, err := src.ReadAt(finalChunk, finalChunkOff); err != nil {
|
||||
return nil, fmt.Errorf("failed to read final chunk: %w", err)
|
||||
}
|
||||
nonce := nonceForChunk(finalChunkIndex)
|
||||
setLastChunkFlag(nonce)
|
||||
plaintext, err := aead.Open(finalChunk[:0], nonce[:], finalChunk, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt and authenticate final chunk: %w", err)
|
||||
}
|
||||
cache := &cachedChunk{off: finalChunkOff, data: plaintext}
|
||||
|
||||
plaintextSize := size - chunks*chacha20poly1305.Overhead
|
||||
r := &DecryptReaderAt{a: aead, src: src, size: plaintextSize, chunks: chunks}
|
||||
r.cache.Store(cache)
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
|
||||
if off < 0 || off > r.size {
|
||||
return 0, fmt.Errorf("offset out of range [0:%d]: %d", r.size, off)
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
chunk := make([]byte, encChunkSize)
|
||||
for len(p) > 0 && off < r.size {
|
||||
chunkIndex := off / ChunkSize
|
||||
chunkOff := chunkIndex * encChunkSize
|
||||
encSize := r.size + r.chunks*chacha20poly1305.Overhead
|
||||
chunkSize := min(encSize-chunkOff, encChunkSize)
|
||||
|
||||
cached := r.cache.Load()
|
||||
var plaintext []byte
|
||||
if cached != nil && cached.off == chunkOff {
|
||||
plaintext = cached.data
|
||||
} else {
|
||||
nn, err := r.src.ReadAt(chunk[:chunkSize], chunkOff)
|
||||
if err == io.EOF {
|
||||
if int64(nn) != chunkSize {
|
||||
err = io.ErrUnexpectedEOF
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("failed to read chunk at offset %d: %w", chunkOff, err)
|
||||
}
|
||||
nonce := nonceForChunk(chunkIndex)
|
||||
if chunkIndex == r.chunks-1 {
|
||||
setLastChunkFlag(nonce)
|
||||
}
|
||||
plaintext, err = r.a.Open(chunk[:0], nonce[:], chunk[:chunkSize], nil)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err)
|
||||
}
|
||||
r.cache.Store(&cachedChunk{off: chunkOff, data: plaintext})
|
||||
}
|
||||
|
||||
plainChunkOff := int(off - chunkIndex*ChunkSize)
|
||||
copySize := min(len(plaintext)-plainChunkOff, len(p))
|
||||
copy(p, plaintext[plainChunkOff:plainChunkOff+copySize])
|
||||
p = p[copySize:]
|
||||
off += int64(copySize)
|
||||
n += copySize
|
||||
}
|
||||
if off == r.size {
|
||||
return n, io.EOF
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"filippo.io/age/internal/stream"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
@@ -20,13 +21,16 @@ const cs = stream.ChunkSize
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
for _, length := range []int{0, 1000, cs - 1, cs, cs + 1, cs + 100, 2 * cs, 2*cs + 500} {
|
||||
for _, stepSize := range []int{512, 600, 1000, cs - 1, cs, cs + 1} {
|
||||
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
|
||||
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
|
||||
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), func(t *testing.T) {
|
||||
testRoundTrip(t, stepSize, length)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
length, stepSize := 2*cs+500, 1
|
||||
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
|
||||
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
|
||||
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), func(t *testing.T) {
|
||||
testRoundTrip(t, stepSize, length)
|
||||
})
|
||||
}
|
||||
|
||||
func testRoundTrip(t *testing.T, stepSize, length int) {
|
||||
@@ -34,85 +38,753 @@ func testRoundTrip(t *testing.T, stepSize, length int) {
|
||||
if _, err := rand.Read(src); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var ciphertext []byte
|
||||
|
||||
t.Run("EncryptWriter", func(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var n int
|
||||
for n < length {
|
||||
b := min(length-n, stepSize)
|
||||
nn, err := w.Write(src[n : n+b])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if nn != b {
|
||||
t.Errorf("Write returned %d, expected %d", nn, b)
|
||||
}
|
||||
n += nn
|
||||
|
||||
nn, err = w.Write(src[n:n])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if nn != 0 {
|
||||
t.Errorf("Write returned %d, expected 0", nn)
|
||||
}
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Error("Close returned an error:", err)
|
||||
}
|
||||
|
||||
ciphertext = buf.Bytes()
|
||||
})
|
||||
|
||||
t.Run("DecryptReader", func(t *testing.T) {
|
||||
r, err := stream.NewDecryptReader(key, bytes.NewReader(ciphertext))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var n int
|
||||
readBuf := make([]byte, stepSize)
|
||||
for n < length {
|
||||
nn, err := r.Read(readBuf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read error at index %d: %v", n, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
|
||||
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
|
||||
}
|
||||
|
||||
n += nn
|
||||
}
|
||||
|
||||
t.Run("TestReader", func(t *testing.T) {
|
||||
if length > 1000 && testing.Short() {
|
||||
t.Skip("skipping slow iotest.TestReader on long input")
|
||||
}
|
||||
r, _ := stream.NewDecryptReader(key, bytes.NewReader(ciphertext))
|
||||
if err := iotest.TestReader(r, src); err != nil {
|
||||
t.Error("iotest.TestReader error on DecryptReader:", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("DecryptReaderAt", func(t *testing.T) {
|
||||
rAt, err := stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rr := io.NewSectionReader(rAt, 0, int64(len(ciphertext)))
|
||||
|
||||
var n int
|
||||
readBuf := make([]byte, stepSize)
|
||||
for n < length {
|
||||
nn, err := rr.Read(readBuf)
|
||||
if n+nn == length && err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAt error at index %d: %v", n, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
|
||||
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
|
||||
}
|
||||
|
||||
n += nn
|
||||
}
|
||||
|
||||
t.Run("TestReader", func(t *testing.T) {
|
||||
if length > 1000 && testing.Short() {
|
||||
t.Skip("skipping slow iotest.TestReader on long input")
|
||||
}
|
||||
rr := io.NewSectionReader(rAt, 0, int64(len(src)))
|
||||
if err := iotest.TestReader(rr, src); err != nil {
|
||||
t.Error("iotest.TestReader error on DecryptReaderAt:", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("EncryptReader", func(t *testing.T) {
|
||||
er, err := stream.NewEncryptReader(key, bytes.NewReader(src))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var n int
|
||||
readBuf := make([]byte, stepSize)
|
||||
for {
|
||||
nn, err := er.Read(readBuf)
|
||||
if nn == 0 && err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
t.Fatalf("EncryptReader Read error at index %d: %v", n, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) {
|
||||
t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn)
|
||||
}
|
||||
|
||||
n += nn
|
||||
}
|
||||
if n != len(ciphertext) {
|
||||
t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext))
|
||||
}
|
||||
|
||||
t.Run("TestReader", func(t *testing.T) {
|
||||
if length > 1000 && testing.Short() {
|
||||
t.Skip("skipping slow iotest.TestReader on long input")
|
||||
}
|
||||
er, _ := stream.NewEncryptReader(key, bytes.NewReader(src))
|
||||
if err := iotest.TestReader(er, ciphertext); err != nil {
|
||||
t.Error("iotest.TestReader error on EncryptReader:", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// trackingReaderAt wraps an io.ReaderAt and tracks whether ReadAt was called.
|
||||
type trackingReaderAt struct {
|
||||
r io.ReaderAt
|
||||
called bool
|
||||
}
|
||||
|
||||
func (t *trackingReaderAt) ReadAt(p []byte, off int64) (int, error) {
|
||||
t.called = true
|
||||
return t.r.ReadAt(p, off)
|
||||
}
|
||||
|
||||
func (t *trackingReaderAt) reset() {
|
||||
t.called = false
|
||||
}
|
||||
|
||||
func TestDecryptReaderAt(t *testing.T) {
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create plaintext spanning exactly 3 chunks: 2 full chunks + partial third
|
||||
// Chunk 0: [0, cs)
|
||||
// Chunk 1: [cs, 2*cs)
|
||||
// Chunk 2: [2*cs, 2*cs+500)
|
||||
plaintextSize := 2*cs + 500
|
||||
plaintext := make([]byte, plaintextSize)
|
||||
if _, err := rand.Read(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var n int
|
||||
for n < length {
|
||||
b := min(length-n, stepSize)
|
||||
nn, err := w.Write(src[n : n+b])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if nn != b {
|
||||
t.Errorf("Write returned %d, expected %d", nn, b)
|
||||
}
|
||||
n += nn
|
||||
|
||||
nn, err = w.Write(src[n:n])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if nn != 0 {
|
||||
t.Errorf("Write returned %d, expected 0", nn)
|
||||
}
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Error("Close returned an error:", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext := buf.Bytes()
|
||||
|
||||
t.Logf("buffer size: %d", buf.Len())
|
||||
ciphertext := bytes.Clone(buf.Bytes())
|
||||
// Create tracking ReaderAt
|
||||
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
|
||||
|
||||
r, err := stream.NewDecryptReader(key, buf)
|
||||
// Create DecryptReaderAt (this reads and caches the final chunk)
|
||||
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tracker.reset()
|
||||
|
||||
n = 0
|
||||
readBuf := make([]byte, stepSize)
|
||||
for n < length {
|
||||
nn, err := r.Read(readBuf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read error at index %d: %v", n, err)
|
||||
// Helper to check reads
|
||||
checkRead := func(name string, off int64, size int, wantN int, wantEOF bool, wantSrcRead bool) {
|
||||
t.Helper()
|
||||
tracker.reset()
|
||||
p := make([]byte, size)
|
||||
n, err := ra.ReadAt(p, off)
|
||||
|
||||
if wantEOF {
|
||||
if err != io.EOF {
|
||||
t.Errorf("%s: got err=%v, want EOF", name, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("%s: got err=%v, want nil", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
|
||||
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
|
||||
if n != wantN {
|
||||
t.Errorf("%s: got n=%d, want %d", name, n, wantN)
|
||||
}
|
||||
|
||||
n += nn
|
||||
if tracker.called != wantSrcRead {
|
||||
t.Errorf("%s: src.ReadAt called=%v, want %v", name, tracker.called, wantSrcRead)
|
||||
}
|
||||
|
||||
// Verify data correctness
|
||||
if n > 0 && off >= 0 && off < int64(plaintextSize) {
|
||||
end := int(off) + n
|
||||
if end > plaintextSize {
|
||||
end = plaintextSize
|
||||
}
|
||||
if !bytes.Equal(p[:n], plaintext[off:end]) {
|
||||
t.Errorf("%s: data mismatch", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
er, err := stream.NewEncryptReader(key, bytes.NewReader(src))
|
||||
// Test 1: Read from final chunk (cached by constructor)
|
||||
checkRead("final chunk (cached)", int64(2*cs+100), 100, 100, false, false)
|
||||
|
||||
// Test 2: Read spanning second and third chunk
|
||||
checkRead("span chunks 1-2", int64(cs+cs-50), 100, 100, false, true)
|
||||
|
||||
// Test 3: Read from final chunk again (cached from test 2)
|
||||
// When reading across chunks 1-2 in test 2, the loop processes chunk 1 then chunk 2,
|
||||
// so chunk 2 ends up in the cache.
|
||||
checkRead("final chunk after span", int64(2*cs+200), 100, 100, false, false)
|
||||
|
||||
// Test 4: Read from final chunk again (now cached)
|
||||
checkRead("final chunk (cached again)", int64(2*cs+50), 50, 50, false, false)
|
||||
|
||||
// Test 5: Read from first chunk (not cached)
|
||||
checkRead("first chunk", 0, 100, 100, false, true)
|
||||
|
||||
// Test 6: Read from first chunk again (now cached)
|
||||
checkRead("first chunk (cached)", 50, 100, 100, false, false)
|
||||
|
||||
// Test 7: Read spanning all chunks
|
||||
tracker.reset()
|
||||
p := make([]byte, plaintextSize)
|
||||
n, err := ra.ReadAt(p, 0)
|
||||
if err != io.EOF {
|
||||
t.Errorf("span all: got err=%v, want EOF", err)
|
||||
}
|
||||
if n != plaintextSize {
|
||||
t.Errorf("span all: got n=%d, want %d", n, plaintextSize)
|
||||
}
|
||||
if !bytes.Equal(p, plaintext) {
|
||||
t.Errorf("span all: data mismatch")
|
||||
}
|
||||
|
||||
// Test 8: Read beyond the end (offset > size)
|
||||
tracker.reset()
|
||||
p = make([]byte, 100)
|
||||
n, err = ra.ReadAt(p, int64(plaintextSize+100))
|
||||
if err == nil {
|
||||
t.Error("beyond end: expected error, got nil")
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("beyond end: got n=%d, want 0", n)
|
||||
}
|
||||
|
||||
// Test 9: Read with off = size (should return 0, EOF)
|
||||
tracker.reset()
|
||||
p = make([]byte, 100)
|
||||
n, err = ra.ReadAt(p, int64(plaintextSize))
|
||||
if err != io.EOF {
|
||||
t.Errorf("off=size: got err=%v, want EOF", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("off=size: got n=%d, want 0", n)
|
||||
}
|
||||
|
||||
// Test 10: Read spanning last chunk and beyond
|
||||
tracker.reset()
|
||||
p = make([]byte, 1000) // request more than available
|
||||
n, err = ra.ReadAt(p, int64(2*cs+400))
|
||||
if err != io.EOF {
|
||||
t.Errorf("span last+beyond: got err=%v, want EOF", err)
|
||||
}
|
||||
wantN := 500 - 400 // only 100 bytes available from offset 2*cs+400
|
||||
if n != wantN {
|
||||
t.Errorf("span last+beyond: got n=%d, want %d", n, wantN)
|
||||
}
|
||||
if !bytes.Equal(p[:n], plaintext[2*cs+400:]) {
|
||||
t.Error("span last+beyond: data mismatch")
|
||||
}
|
||||
|
||||
// Test 11: Read spanning second+last chunk and beyond
|
||||
tracker.reset()
|
||||
p = make([]byte, cs+1000) // request more than available
|
||||
n, err = ra.ReadAt(p, int64(cs+100))
|
||||
if err != io.EOF {
|
||||
t.Errorf("span 1-2+beyond: got err=%v, want EOF", err)
|
||||
}
|
||||
wantN = plaintextSize - (cs + 100)
|
||||
if n != wantN {
|
||||
t.Errorf("span 1-2+beyond: got n=%d, want %d", n, wantN)
|
||||
}
|
||||
if !bytes.Equal(p[:n], plaintext[cs+100:]) {
|
||||
t.Error("span 1-2+beyond: data mismatch")
|
||||
}
|
||||
|
||||
// Test 12: Negative offset
|
||||
tracker.reset()
|
||||
p = make([]byte, 100)
|
||||
n, err = ra.ReadAt(p, -1)
|
||||
if err == nil {
|
||||
t.Error("negative offset: expected error, got nil")
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("negative offset: got n=%d, want 0", n)
|
||||
}
|
||||
|
||||
// Test 13: Zero-length read in the middle
|
||||
tracker.reset()
|
||||
p = make([]byte, 0)
|
||||
n, err = ra.ReadAt(p, 100)
|
||||
if err != nil {
|
||||
t.Errorf("zero-length middle: got err=%v, want nil", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("zero-length middle: got n=%d, want 0", n)
|
||||
}
|
||||
|
||||
// Test 14: Zero-length read at end
|
||||
tracker.reset()
|
||||
p = make([]byte, 0)
|
||||
n, err = ra.ReadAt(p, int64(plaintextSize))
|
||||
if err != nil {
|
||||
t.Errorf("zero-length end: got err=%v, want nil", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("zero-length end: got n=%d, want 0", n)
|
||||
}
|
||||
|
||||
// Test 15: Read exactly one chunk at chunk boundary
|
||||
checkRead("exact chunk at boundary", int64(cs), cs, cs, false, true)
|
||||
|
||||
// Test 16: Read one byte at each chunk boundary
|
||||
checkRead("one byte at start", 0, 1, 1, false, true)
|
||||
checkRead("one byte at cs-1", int64(cs-1), 1, 1, false, false) // cached from test 15
|
||||
checkRead("one byte at cs", int64(cs), 1, 1, false, true)
|
||||
checkRead("one byte at 2*cs-1", int64(2*cs-1), 1, 1, false, false) // same chunk
|
||||
checkRead("one byte at 2*cs", int64(2*cs), 1, 1, false, true)
|
||||
checkRead("last byte", int64(plaintextSize-1), 1, 1, true, false) // same chunk, EOF because we reach end
|
||||
|
||||
// Test 17: Read crossing exactly one chunk boundary
|
||||
checkRead("cross boundary 0-1", int64(cs-50), 100, 100, false, true)
|
||||
checkRead("cross boundary 1-2", int64(2*cs-50), 100, 100, false, true)
|
||||
}
|
||||
|
||||
func TestDecryptReaderAtEmpty(t *testing.T) {
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create empty encrypted file
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
n = 0
|
||||
for {
|
||||
nn, err := er.Read(readBuf)
|
||||
if nn == 0 && err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
t.Fatalf("EncryptReader Read error at index %d: %v", n, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(readBuf[:nn], ciphertext[n:n+nn]) {
|
||||
t.Errorf("EncryptReader wrong data at indexes %d - %d", n, n+nn)
|
||||
}
|
||||
|
||||
n += nn
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != len(ciphertext) {
|
||||
t.Errorf("EncryptReader read %d bytes, expected %d", n, len(ciphertext))
|
||||
ciphertext := buf.Bytes()
|
||||
|
||||
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
|
||||
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tracker.reset()
|
||||
|
||||
// Test 1: Read from empty file at offset 0
|
||||
p := make([]byte, 100)
|
||||
n, err := ra.ReadAt(p, 0)
|
||||
if err != io.EOF {
|
||||
t.Errorf("empty read: got err=%v, want EOF", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("empty read: got n=%d, want 0", n)
|
||||
}
|
||||
|
||||
// Test 2: Zero-length read from empty file
|
||||
p = make([]byte, 0)
|
||||
n, err = ra.ReadAt(p, 0)
|
||||
if err != nil {
|
||||
t.Errorf("empty zero-length: got err=%v, want nil", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("empty zero-length: got n=%d, want 0", n)
|
||||
}
|
||||
|
||||
// Test 3: Read beyond empty file
|
||||
p = make([]byte, 100)
|
||||
n, err = ra.ReadAt(p, 1)
|
||||
if err == nil {
|
||||
t.Error("empty beyond: expected error, got nil")
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("empty beyond: got n=%d, want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptReaderAtSingleChunk(t *testing.T) {
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Single chunk, not full
|
||||
plaintext := make([]byte, 1000)
|
||||
if _, err := rand.Read(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext := buf.Bytes()
|
||||
|
||||
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
|
||||
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tracker.reset()
|
||||
|
||||
// All reads should use cache (final chunk = only chunk)
|
||||
p := make([]byte, 100)
|
||||
n, err := ra.ReadAt(p, 0)
|
||||
if err != nil {
|
||||
t.Errorf("single chunk start: got err=%v, want nil", err)
|
||||
}
|
||||
if n != 100 {
|
||||
t.Errorf("single chunk start: got n=%d, want 100", n)
|
||||
}
|
||||
if tracker.called {
|
||||
t.Error("single chunk start: unexpected src.ReadAt call")
|
||||
}
|
||||
if !bytes.Equal(p[:n], plaintext[:100]) {
|
||||
t.Error("single chunk start: data mismatch")
|
||||
}
|
||||
|
||||
// Read at end
|
||||
n, err = ra.ReadAt(p, 900)
|
||||
if err != io.EOF {
|
||||
t.Errorf("single chunk end: got err=%v, want EOF", err)
|
||||
}
|
||||
if n != 100 {
|
||||
t.Errorf("single chunk end: got n=%d, want 100", n)
|
||||
}
|
||||
if tracker.called {
|
||||
t.Error("single chunk end: unexpected src.ReadAt call")
|
||||
}
|
||||
if !bytes.Equal(p[:n], plaintext[900:]) {
|
||||
t.Error("single chunk end: data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptReaderAtFullChunks(t *testing.T) {
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Exactly 2 full chunks
|
||||
plaintext := make([]byte, 2*cs)
|
||||
if _, err := rand.Read(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext := buf.Bytes()
|
||||
|
||||
tracker := &trackingReaderAt{r: bytes.NewReader(ciphertext)}
|
||||
ra, err := stream.NewDecryptReaderAt(key, tracker, int64(len(ciphertext)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tracker.reset()
|
||||
|
||||
// Read last byte of second chunk (cached)
|
||||
p := make([]byte, 1)
|
||||
n, err := ra.ReadAt(p, int64(2*cs-1))
|
||||
if err != io.EOF {
|
||||
t.Errorf("last byte: got err=%v, want EOF", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Errorf("last byte: got n=%d, want 1", n)
|
||||
}
|
||||
if tracker.called {
|
||||
t.Error("last byte: unexpected src.ReadAt call (should be cached)")
|
||||
}
|
||||
if p[0] != plaintext[2*cs-1] {
|
||||
t.Error("last byte: data mismatch")
|
||||
}
|
||||
|
||||
// Read at exactly the boundary between chunks
|
||||
p = make([]byte, 100)
|
||||
n, err = ra.ReadAt(p, int64(cs-50))
|
||||
if err != nil {
|
||||
t.Errorf("boundary: got err=%v, want nil", err)
|
||||
}
|
||||
if n != 100 {
|
||||
t.Errorf("boundary: got n=%d, want 100", n)
|
||||
}
|
||||
if !bytes.Equal(p, plaintext[cs-50:cs+50]) {
|
||||
t.Error("boundary: data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptReaderAtWrongKey(t *testing.T) {
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plaintext := make([]byte, 1000)
|
||||
if _, err := rand.Read(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext := buf.Bytes()
|
||||
|
||||
// Try to decrypt with wrong key
|
||||
wrongKey := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(wrongKey); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = stream.NewDecryptReaderAt(wrongKey, bytes.NewReader(ciphertext), int64(len(ciphertext)))
|
||||
if err == nil {
|
||||
t.Error("wrong key: expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptReaderAtInvalidSize(t *testing.T) {
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plaintext := make([]byte, 1000)
|
||||
if _, err := rand.Read(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext := buf.Bytes()
|
||||
|
||||
// Wrong size (too small)
|
||||
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)-1))
|
||||
if err == nil {
|
||||
t.Error("wrong size (small): expected error, got nil")
|
||||
}
|
||||
|
||||
// Wrong size (too large)
|
||||
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(ciphertext), int64(len(ciphertext)+1))
|
||||
if err == nil {
|
||||
t.Error("wrong size (large): expected error, got nil")
|
||||
}
|
||||
|
||||
// Size that would imply empty final chunk (invalid)
|
||||
// This would be: one full encrypted chunk + just overhead
|
||||
invalidSize := int64(cs + chacha20poly1305.Overhead + chacha20poly1305.Overhead)
|
||||
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(make([]byte, invalidSize)), invalidSize)
|
||||
if err == nil {
|
||||
t.Error("invalid size (empty final chunk): expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptReaderAtTruncated(t *testing.T) {
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plaintext := make([]byte, 2*cs+500)
|
||||
if _, err := rand.Read(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext := buf.Bytes()
|
||||
|
||||
// Truncate ciphertext but lie about size
|
||||
truncated := ciphertext[:len(ciphertext)-100]
|
||||
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(truncated), int64(len(ciphertext)))
|
||||
if err == nil {
|
||||
t.Error("truncated: expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptReaderAtTruncatedChunk(t *testing.T) {
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create 4 chunks: 3 full + 1 partial
|
||||
plaintext := make([]byte, 3*cs+500)
|
||||
if _, err := rand.Read(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext := buf.Bytes()
|
||||
|
||||
// Truncate to 3 chunks (remove the actual final chunk)
|
||||
// The third chunk was NOT encrypted with the last chunk flag,
|
||||
// so decryption should fail when we try to use it as the final chunk.
|
||||
encChunkSize := cs + 16 // ChunkSize + Overhead
|
||||
truncatedSize := int64(3 * encChunkSize)
|
||||
truncated := ciphertext[:truncatedSize]
|
||||
|
||||
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(truncated), truncatedSize)
|
||||
if err == nil {
|
||||
t.Error("truncated at chunk boundary: expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptReaderAtCorrupted(t *testing.T) {
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plaintext := make([]byte, 2*cs+500)
|
||||
if _, err := rand.Read(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext := bytes.Clone(buf.Bytes())
|
||||
|
||||
// Corrupt final chunk - should fail in constructor
|
||||
corruptedFinal := bytes.Clone(ciphertext)
|
||||
corruptedFinal[len(corruptedFinal)-10] ^= 0xFF
|
||||
_, err = stream.NewDecryptReaderAt(key, bytes.NewReader(corruptedFinal), int64(len(corruptedFinal)))
|
||||
if err == nil {
|
||||
t.Error("corrupted final: expected error, got nil")
|
||||
}
|
||||
|
||||
// Corrupt first chunk - should fail on read
|
||||
corruptedFirst := bytes.Clone(ciphertext)
|
||||
corruptedFirst[10] ^= 0xFF
|
||||
ra, err := stream.NewDecryptReaderAt(key, bytes.NewReader(corruptedFirst), int64(len(corruptedFirst)))
|
||||
if err != nil {
|
||||
t.Fatalf("corrupted first constructor: unexpected error: %v", err)
|
||||
}
|
||||
p := make([]byte, 100)
|
||||
_, err = ra.ReadAt(p, 0)
|
||||
if err == nil {
|
||||
t.Error("corrupted first read: expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
192
testkit_test.go
192
testkit_test.go
@@ -140,10 +140,16 @@ func parseVector(t *testing.T, test []byte) *vector {
|
||||
}
|
||||
|
||||
func TestVectors(t *testing.T) {
|
||||
forEachVector(t, testVector)
|
||||
forEachVector(t, func(t *testing.T, v *vector) {
|
||||
var plaintext []byte
|
||||
t.Run("Decrypt", func(t *testing.T) { plaintext = testDecrypt(t, v) })
|
||||
t.Run("DecryptReaderAt", func(t *testing.T) { testDecryptReaderAt(t, v, plaintext) })
|
||||
t.Run("Inspect", func(t *testing.T) { testInspect(t, v, plaintext) })
|
||||
t.Run("RoundTrip", func(t *testing.T) { testVectorRoundTrip(t, v) })
|
||||
})
|
||||
}
|
||||
|
||||
func testVector(t *testing.T, v *vector) {
|
||||
func testDecrypt(t *testing.T, v *vector) []byte {
|
||||
var in io.Reader = bytes.NewReader(v.file)
|
||||
if v.armored {
|
||||
in = armor.NewReader(in)
|
||||
@@ -152,25 +158,25 @@ func testVector(t *testing.T, v *vector) {
|
||||
if err != nil && strings.HasSuffix(err.Error(), "bad header MAC") {
|
||||
if v.expect == "HMAC failure" {
|
||||
t.Log(err)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
t.Fatalf("expected %s, got HMAC error", v.expect)
|
||||
} else if e := new(armor.Error); errors.As(err, &e) {
|
||||
if v.expect == "armor failure" {
|
||||
t.Log(err)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
t.Fatalf("expected %s, got: %v", v.expect, err)
|
||||
} else if _, ok := err.(*age.NoIdentityMatchError); ok {
|
||||
if v.expect == "no match" {
|
||||
t.Log(err)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
t.Fatalf("expected %s, got: %v", v.expect, err)
|
||||
} else if err != nil {
|
||||
if v.expect == "header failure" {
|
||||
t.Log(err)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
t.Fatalf("expected %s, got: %v", v.expect, err)
|
||||
} else if v.expect != "success" && v.expect != "payload failure" &&
|
||||
@@ -188,15 +194,77 @@ func testVector(t *testing.T, v *vector) {
|
||||
}
|
||||
}
|
||||
if v.payloadHash != nil && sha256.Sum256(out) != *v.payloadHash {
|
||||
t.Error("partial payload hash mismatch")
|
||||
t.Errorf("partial payload hash mismatch, read %d bytes", len(out))
|
||||
}
|
||||
return
|
||||
return out
|
||||
} else if v.expect != "success" {
|
||||
t.Fatalf("expected %s, got success", v.expect)
|
||||
}
|
||||
if sha256.Sum256(out) != *v.payloadHash {
|
||||
t.Error("payload hash mismatch")
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func testDecryptReaderAt(t *testing.T, v *vector, plaintext []byte) {
|
||||
if v.armored {
|
||||
t.Skip("armor.NewReader does not implement ReaderAt")
|
||||
}
|
||||
rAt, s, err := age.DecryptReaderAt(bytes.NewReader(v.file), int64(len(v.file)), v.identities...)
|
||||
switch v.expect {
|
||||
case "success":
|
||||
if err != nil {
|
||||
t.Fatalf("expected success, got: %v", err)
|
||||
}
|
||||
if int64(len(plaintext)) != s {
|
||||
t.Errorf("unexpected size: got %d, want %d", s, len(plaintext))
|
||||
}
|
||||
case "payload failure":
|
||||
// DecryptReaderAt detects some (but not all) payload failures upfront,
|
||||
// either from the size of the payload, or by decrypting the last chunk
|
||||
// to authenticate its size.
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
}
|
||||
t.Fatalf("expected %s, got success", v.expect)
|
||||
}
|
||||
out, err := io.ReadAll(io.NewSectionReader(rAt, 0, s))
|
||||
if v.expect == "success" {
|
||||
if err != nil {
|
||||
t.Fatalf("expected success, got: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Fatalf("expected %s, got success", v.expect)
|
||||
}
|
||||
t.Log(err)
|
||||
// We can't check the partial payload hash, because the ReaderAt will
|
||||
// notice errors that a linearly scanning Reader could not. For example,
|
||||
// if there are two final chunks, the linear Reader will decrypt the
|
||||
// first one and then error out on the second, while the ReaderAt will
|
||||
// decrypt the second one to check the size, and then know that the
|
||||
// first chunk could not be the last one. Instead, check that the
|
||||
// prefix, if any, matches.
|
||||
if !bytes.HasPrefix(plaintext, out) {
|
||||
t.Errorf("partial payload prefix mismatch, read %d bytes", len(out))
|
||||
}
|
||||
return
|
||||
}
|
||||
if sha256.Sum256(out) != *v.payloadHash {
|
||||
t.Error("payload hash mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func testInspect(t *testing.T, v *vector, plaintext []byte) {
|
||||
if v.expect != "success" {
|
||||
t.Skip("invalid file, can't inspect")
|
||||
}
|
||||
for _, fileSize := range []int64{int64(len(v.file)), -1} {
|
||||
metadata, err := inspect.Inspect(bytes.NewReader(v.file), fileSize)
|
||||
if err != nil {
|
||||
@@ -211,8 +279,8 @@ func testVector(t *testing.T, v *vector) {
|
||||
if metadata.Sizes.Armor+metadata.Sizes.Header+metadata.Sizes.Overhead+metadata.Sizes.MinPayload != int64(len(v.file)) {
|
||||
t.Errorf("size breakdown does not add up to file size")
|
||||
}
|
||||
if metadata.Sizes.MinPayload != int64(len(out)) {
|
||||
t.Errorf("unexpected payload size: got %d, want %d", metadata.Sizes.MinPayload, len(out))
|
||||
if metadata.Sizes.MinPayload != int64(len(plaintext)) {
|
||||
t.Errorf("unexpected payload size: got %d, want %d", metadata.Sizes.MinPayload, len(plaintext))
|
||||
}
|
||||
if metadata.Sizes.MaxPayload != metadata.Sizes.MinPayload {
|
||||
t.Errorf("unexpected max payload size: got %d, want %d", metadata.Sizes.MaxPayload, metadata.Sizes.MinPayload)
|
||||
@@ -223,16 +291,12 @@ func testVector(t *testing.T, v *vector) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestVectorsRoundTrip checks that any (valid) armor, header, and/or STREAM
|
||||
// testVectorsRoundTrip checks that any (valid) armor, header, and/or STREAM
|
||||
// payload in the test vectors re-encodes identically.
|
||||
func TestVectorsRoundTrip(t *testing.T) {
|
||||
forEachVector(t, testVectorRoundTrip)
|
||||
}
|
||||
|
||||
func testVectorRoundTrip(t *testing.T, v *vector) {
|
||||
if v.armored {
|
||||
if v.expect == "armor failure" {
|
||||
t.SkipNow()
|
||||
t.Skip("invalid armor, nothing to round-trip")
|
||||
}
|
||||
t.Run("armor", func(t *testing.T) {
|
||||
payload, err := io.ReadAll(armor.NewReader(bytes.NewReader(v.file)))
|
||||
@@ -261,7 +325,7 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
|
||||
}
|
||||
|
||||
if v.expect == "header failure" {
|
||||
t.SkipNow()
|
||||
t.Skip("invalid header, nothing to round-trip")
|
||||
}
|
||||
hdr, p, err := format.Parse(bytes.NewReader(v.file))
|
||||
if err != nil {
|
||||
@@ -283,46 +347,62 @@ func testVectorRoundTrip(t *testing.T, v *vector) {
|
||||
}
|
||||
})
|
||||
|
||||
if v.expect == "success" {
|
||||
t.Run("STREAM", func(t *testing.T) {
|
||||
nonce, payload := payload[:16], payload[16:]
|
||||
key := streamKey(v.fileKey[:], nonce)
|
||||
r, err := stream.NewDecryptReader(key, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
plaintext, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(buf.Bytes(), payload) {
|
||||
t.Error("got a different STREAM ciphertext")
|
||||
}
|
||||
buf.Reset()
|
||||
er, err := stream.NewEncryptReader(key, bytes.NewReader(plaintext))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext, err := io.ReadAll(er)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(ciphertext, payload) {
|
||||
t.Error("got a different STREAM ciphertext from EncryptReader")
|
||||
}
|
||||
})
|
||||
if v.expect != "success" {
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("STREAM", func(t *testing.T) {
|
||||
nonce, payload := payload[:16], payload[16:]
|
||||
key := streamKey(v.fileKey[:], nonce)
|
||||
|
||||
r, err := stream.NewDecryptReader(key, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
plaintext, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rAt, err := stream.NewDecryptReaderAt(key, bytes.NewReader(payload), int64(len(payload)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
plaintextAt, err := io.ReadAll(io.NewSectionReader(rAt, 0, int64(len(plaintext))))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(plaintextAt, plaintext) {
|
||||
t.Errorf("got a different plaintext from DecryptReaderAt")
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := stream.NewEncryptWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(plaintext); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(buf.Bytes(), payload) {
|
||||
t.Error("got a different STREAM ciphertext")
|
||||
}
|
||||
|
||||
er, err := stream.NewEncryptReader(key, bytes.NewReader(plaintext))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ciphertext, err := io.ReadAll(er)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(ciphertext, payload) {
|
||||
t.Error("got a different STREAM ciphertext from EncryptReader")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func streamKey(fileKey, nonce []byte) []byte {
|
||||
|
||||
Reference in New Issue
Block a user