Files
age/internal/stream/stream.go
2025-12-26 22:18:54 +01:00

452 lines
12 KiB
Go

// Copyright 2019 The age Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package stream implements a variant of the STREAM chunked encryption scheme.
package stream
import (
"bytes"
"crypto/cipher"
"encoding/binary"
"errors"
"fmt"
"io"
"sync/atomic"
"golang.org/x/crypto/chacha20poly1305"
)
// ChunkSize is the size in bytes of a plaintext chunk (64 KiB). Every chunk
// except the last holds exactly this many bytes; the last may be shorter.
const ChunkSize = 1 << 16
// EncryptedChunkCount returns the number of chunks in an encrypted payload of
// encryptedSize bytes.
//
// Every chunk carries one AEAD tag, every chunk but the last holds exactly
// ChunkSize bytes of plaintext, and the last chunk may only be empty if it is
// also the first. It returns an error if no payload satisfying those rules
// has encryptedSize bytes, including for negative sizes.
func EncryptedChunkCount(encryptedSize int64) (int64, error) {
	// Negative sizes are never valid. Without this check some of them (e.g.
	// -1) made the round-up divisions below agree and were accepted.
	if encryptedSize < 0 {
		return 0, fmt.Errorf("invalid encrypted payload size: %d", encryptedSize)
	}

	// Round up using quotient and remainder rather than (x+d-1)/d, which can
	// overflow for sizes close to math.MaxInt64.
	chunks := encryptedSize / encChunkSize
	if encryptedSize%encChunkSize != 0 {
		chunks++
	}
	plaintextSize := encryptedSize - chunks*chacha20poly1305.Overhead

	var expChunks int64
	switch {
	case plaintextSize == 0:
		// Empty plaintext, the only case that allows (and requires) an empty chunk.
		expChunks = 1
	case plaintextSize > 0:
		expChunks = plaintextSize / ChunkSize
		if plaintextSize%ChunkSize != 0 {
			expChunks++
		}
	}
	// A negative plaintextSize (a payload too short to hold even one tag)
	// leaves expChunks at 0, which cannot match the one chunk counted above.
	if expChunks != chunks {
		return 0, fmt.Errorf("invalid encrypted payload size: %d", encryptedSize)
	}
	return chunks, nil
}
// PlaintextSize returns the number of plaintext bytes in an encrypted payload
// of encryptedSize bytes: the encrypted size minus one authentication tag per
// chunk. Invalid sizes yield the error reported by EncryptedChunkCount.
func PlaintextSize(encryptedSize int64) (int64, error) {
	n, err := EncryptedChunkCount(encryptedSize)
	if err != nil {
		return 0, err
	}
	tags := n * chacha20poly1305.Overhead
	return encryptedSize - tags, nil
}
// DecryptReader is an io.Reader that decrypts and authenticates an encrypted
// payload read from src, one chunk at a time.
type DecryptReader struct {
	a   cipher.AEAD
	src io.Reader

	// unread is decrypted but not yet returned data; it is backed by buf.
	unread []byte

	// buf receives one encrypted chunk at a time and then holds its
	// plaintext, which unread points into.
	buf [encChunkSize]byte

	// err is the sticky error returned once unread has been drained.
	err error

	// nonce is the counter for the next chunk, with the last-chunk flag set
	// once the final chunk has been identified.
	nonce [chacha20poly1305.NonceSize]byte
}
const (
	// encChunkSize is the size of a full chunk on the wire: ChunkSize bytes
	// of ciphertext plus the AEAD authentication tag.
	encChunkSize = ChunkSize + chacha20poly1305.Overhead

	// lastChunkFlag is the value of the final nonce byte for the last chunk
	// of a message; that byte is zero for every other chunk.
	lastChunkFlag = 1
)
// NewDecryptReader returns a DecryptReader that decrypts the payload read
// from src using key. Any error comes from chacha20poly1305.New.
func NewDecryptReader(key []byte, src io.Reader) (*DecryptReader, error) {
	c, err := chacha20poly1305.New(key)
	if err != nil {
		return nil, err
	}
	return &DecryptReader{src: src, a: c}, nil
}
// Read implements io.Reader. It only returns plaintext from chunks that have
// been authenticated. After the final chunk it probes src once more. Once the
// remaining plaintext is drained, later calls return io.EOF if src is at EOF.
// Otherwise they return an error for trailing data or a non-EOF read error.
func (r *DecryptReader) Read(p []byte) (int, error) {
	// Decrypt a new chunk only once the previous one is fully consumed;
	// leftovers are returned even if a sticky error is already recorded.
	if len(r.unread) == 0 {
		if r.err != nil {
			return 0, r.err
		}
		if len(p) == 0 {
			return 0, nil
		}

		last, err := r.readChunk()
		if err != nil {
			r.err = err
			return 0, err
		}

		if last {
			// Ensure there is an EOF after the last chunk as expected. In other
			// words, check for trailing data after a full-length final chunk.
			// This relies on src returning EOF again even if it already
			// returned an EOF to io.ReadFull.
			_, probeErr := r.src.Read(make([]byte, 1))
			switch {
			case probeErr == nil:
				r.err = errors.New("trailing data after end of encrypted file")
			case probeErr == io.EOF:
				r.err = io.EOF
			default:
				r.err = fmt.Errorf("non-EOF error reading after end of encrypted file: %w", probeErr)
			}
		}
	}

	n := copy(p, r.unread)
	r.unread = r.unread[n:]
	return n, nil
}
// readChunk reads the next chunk of ciphertext from r.src, decrypts and
// authenticates it, and makes the plaintext available in r.unread. last is
// true if the chunk was marked as the end of the message. A short read always
// means a last chunk. A full-length chunk is treated as last only if it fails
// to open with the regular nonce but opens with the last-chunk flag set.
// readChunk must not be called again after returning a last chunk or an error.
func (r *DecryptReader) readChunk() (last bool, err error) {
	if len(r.unread) != 0 {
		panic("stream: internal error: readChunk called with dirty buffer")
	}

	in := r.buf[:]
	n, err := io.ReadFull(r.src, in)
	switch {
	case err == io.EOF:
		// A message can't end without a marked chunk. This message is truncated.
		return false, io.ErrUnexpectedEOF
	case err == io.ErrUnexpectedEOF:
		// A short read can only be the last chunk. The last chunk can be
		// short, but not empty unless it's the first and only chunk.
		if !nonceIsZero(&r.nonce) && n == r.a.Overhead() {
			return false, errors.New("last chunk is empty, try age v1.0.0, and please consider reporting this")
		}
		in = in[:n]
		last = true
		setLastChunkFlag(&r.nonce)
	case err != nil:
		return false, err
	}

	// Decrypt into a separate buffer instead of in place: per the cipher.AEAD
	// contract, Open may overwrite dst even when it fails, and the ciphertext
	// in r.buf is still needed for the retry below.
	outBuf := make([]byte, 0, ChunkSize)
	out, err := r.a.Open(outBuf, r.nonce[:], in, nil)
	if err != nil && !last {
		// Check if this was a full-length final chunk.
		last = true
		setLastChunkFlag(&r.nonce)
		out, err = r.a.Open(outBuf, r.nonce[:], in, nil)
	}
	if err != nil {
		return false, errors.New("failed to decrypt and authenticate payload chunk, file may be corrupted or tampered with")
	}

	incNonce(&r.nonce)
	// Copy the plaintext back into r.buf so r.unread always shares its
	// backing array.
	r.unread = r.buf[:copy(r.buf[:], out)]
	return last, nil
}
// incNonce advances the chunk counter stored big-endian in every byte of
// nonce except the last, leaving the last-chunk flag byte untouched. It
// panics if the counter wraps around.
func incNonce(nonce *[chacha20poly1305.NonceSize]byte) {
	i := len(nonce) - 2
	for ; i >= 0; i-- {
		nonce[i]++
		if nonce[i] != 0 {
			break
		}
	}
	// Only reached after carrying out of the most significant byte. The
	// counter is 88 bits, so this is unreachable in practice.
	if i < 0 {
		panic("stream: chunk counter wrapped around")
	}
}
// nonceForChunk returns a fresh nonce for chunk number chunkIndex, without
// the last-chunk flag. The big-endian counter ends just before the flag byte,
// so the 64-bit index fills the eight bytes preceding it.
func nonceForChunk(chunkIndex int64) *[chacha20poly1305.NonceSize]byte {
	nonce := new([chacha20poly1305.NonceSize]byte)
	counter := nonce[len(nonce)-9 : len(nonce)-1]
	binary.BigEndian.PutUint64(counter, uint64(chunkIndex))
	return nonce
}
// setLastChunkFlag marks nonce as belonging to the final chunk by setting its
// last byte, which is not part of the chunk counter.
func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
	flagIdx := len(nonce) - 1
	nonce[flagIdx] = lastChunkFlag
}
// nonceIsZero reports whether every byte of nonce is zero, i.e. the nonce of
// the first chunk before any last-chunk flag has been set.
func nonceIsZero(nonce *[chacha20poly1305.NonceSize]byte) bool {
	var zero [chacha20poly1305.NonceSize]byte
	return *nonce == zero
}
// EncryptWriter is an io.WriteCloser that encrypts everything written to it
// and writes the ciphertext to dst. Close must be called to write the final
// chunk.
type EncryptWriter struct {
	a   cipher.AEAD
	dst io.Writer

	// buf accumulates plaintext until a chunk can be sealed; the chunk is
	// then encrypted in place inside it.
	buf bytes.Buffer

	nonce [chacha20poly1305.NonceSize]byte
	err   error
}
// NewEncryptWriter returns an EncryptWriter that encrypts with key and writes
// the ciphertext to dst. Any error comes from chacha20poly1305.New.
func NewEncryptWriter(key []byte, dst io.Writer) (*EncryptWriter, error) {
	c, err := chacha20poly1305.New(key)
	if err != nil {
		return nil, err
	}
	return &EncryptWriter{dst: dst, a: c}, nil
}
// Write implements io.Writer. Data is buffered and encrypted in ChunkSize
// chunks. A full chunk is sealed only once more data arrives after it: until
// then it may still be the last chunk, which Close seals with the last-chunk
// flag. Errors are sticky.
func (w *EncryptWriter) Write(p []byte) (n int, err error) {
	if w.err != nil {
		return 0, w.err
	}
	if len(p) == 0 {
		return 0, nil
	}

	for rest := p; len(rest) > 0; {
		room := min(len(rest), ChunkSize-w.buf.Len())
		w.buf.Write(rest[:room])
		rest = rest[room:]

		// Keep a full chunk buffered while nothing else is pending.
		if len(rest) == 0 || w.buf.Len() != ChunkSize {
			continue
		}
		if flushErr := w.flushChunk(notLastChunk); flushErr != nil {
			w.err = flushErr
			return 0, flushErr
		}
	}
	return len(p), nil
}
// Close seals and writes whatever is buffered as the last chunk. It does not
// close the underlying Writer. After Close, further Write and Close calls
// return an error.
func (w *EncryptWriter) Close() error {
	if w.err != nil {
		return w.err
	}
	if err := w.flushChunk(lastChunk); err != nil {
		w.err = err
		return err
	}
	w.err = errors.New("stream.Writer is already closed")
	return nil
}
// Readable names for the last argument of flushChunk.
const (
	lastChunk    = true
	notLastChunk = !lastChunk
)
// flushChunk seals the buffered plaintext as one chunk, writes it to w.dst
// and advances the nonce. Unless last is set, the buffer must hold exactly
// ChunkSize bytes. The buffer is emptied even if the write fails.
func (w *EncryptWriter) flushChunk(last bool) error {
	switch {
	case last:
		setLastChunkFlag(&w.nonce)
	case w.buf.Len() != ChunkSize:
		panic("stream: internal error: flush called with partial chunk")
	}

	// Reserve room for the tag so Seal can encrypt in place.
	w.buf.Grow(chacha20poly1305.Overhead)
	plaintext := w.buf.Bytes()
	sealed := w.a.Seal(plaintext[:0], w.nonce[:], plaintext, nil)

	_, err := w.dst.Write(sealed)
	incNonce(&w.nonce)
	w.buf.Reset()
	return err
}
// EncryptReader is an io.Reader that reads plaintext from src and returns the
// corresponding encrypted payload.
type EncryptReader struct {
	a   cipher.AEAD
	src io.Reader

	// buf holds ciphertext not yet returned to the caller, possibly followed
	// by plaintext that was read ahead. Only the first ready bytes are
	// ciphertext: reading one byte past a full chunk is how the reader tells
	// whether that chunk is the last one.
	ready int
	buf   bytes.Buffer

	nonce [chacha20poly1305.NonceSize]byte
	err   error
}
// NewEncryptReader returns an EncryptReader that encrypts the plaintext read
// from src with key. Any error comes from chacha20poly1305.New.
func NewEncryptReader(key []byte, src io.Reader) (*EncryptReader, error) {
	c, err := chacha20poly1305.New(key)
	if err != nil {
		return nil, err
	}
	return &EncryptReader{src: src, a: c}, nil
}
// Read implements io.Reader and returns ciphertext. It encrypts a new chunk
// only after every previously encrypted byte has been returned, and returns
// io.EOF once the last chunk has been fully read.
func (r *EncryptReader) Read(p []byte) (int, error) {
	if r.ready <= 0 {
		switch {
		case r.err != nil:
			return 0, r.err
		case len(p) == 0:
			return 0, nil
		}
		if err := r.feedBuffer(); err != nil {
			r.err = err
			return 0, err
		}
	}

	want := min(len(p), r.ready)
	n, err := r.buf.Read(p[:want])
	r.ready -= n
	return n, err
}
// feedBuffer reads and encrypts the next chunk from r.src and appends it to
// r.buf. It sets r.ready to the number of newly available bytes in r.buf.
//
// To tell whether a chunk is the last one, feedBuffer reads up to ChunkSize+1
// plaintext bytes in total (r.buf may already hold one tail byte from the
// previous call). If the extra byte arrives, the first ChunkSize bytes are
// sealed as a regular chunk. The extra tail byte is then kept, unencrypted,
// after the ciphertext for the next call. Otherwise everything buffered is
// sealed as the last chunk and r.err is set to io.EOF. Sealing is done in
// place, using spare capacity reserved with Grow.
func (r *EncryptReader) feedBuffer() error {
	if r.ready > 0 {
		panic("stream: internal error: feedBuffer called with dirty buffer")
	}

	// io.CopyN wraps r.src in an io.LimitedReader, so io.Copy uses
	// r.buf.ReadFrom to fill the buffer directly.
	// We need ChunkSize + 1 bytes to determine if this is the last chunk.
	_, err := io.CopyN(&r.buf, r.src, int64(ChunkSize-r.buf.Len()+1))
	if err != nil && err != io.EOF {
		return err
	}

	if last := r.buf.Len() <= ChunkSize; last {
		setLastChunkFlag(&r.nonce)
		// After Grow, we know r.buf.Bytes() has enough capacity for the
		// overhead. We encrypt in place and then do a Write to include the
		// overhead in the buffer.
		r.buf.Grow(chacha20poly1305.Overhead)
		plaintext := r.buf.Bytes()
		r.a.Seal(plaintext[:0], r.nonce[:], plaintext, nil)
		incNonce(&r.nonce)
		// The tag now sits just past the buffer's length; this Write copies it
		// onto itself, extending the buffer to cover it without reallocating.
		r.buf.Write(plaintext[len(plaintext) : len(plaintext)+chacha20poly1305.Overhead])
		r.ready = r.buf.Len()
		r.err = io.EOF
		return nil
	}

	// Same, but accounting for the tail byte which will remain unencrypted and
	// needs to be shifted past the overhead.
	if r.buf.Len() != ChunkSize+1 {
		panic("stream: internal error: unexpected buffer length")
	}
	tailByte := r.buf.Bytes()[ChunkSize]
	r.buf.Grow(chacha20poly1305.Overhead)
	plaintext := r.buf.Bytes()[:ChunkSize]
	// Sealing writes the first tag byte over the tail byte's position, which
	// is why tailByte was saved above.
	r.a.Seal(plaintext[:0], r.nonce[:], plaintext, nil)
	incNonce(&r.nonce)
	// The buffer already covers the first tag byte: append the remaining
	// Overhead-1 tag bytes, then re-append the saved tail byte.
	r.buf.Write(plaintext[len(plaintext)+1 : len(plaintext)+chacha20poly1305.Overhead])
	r.buf.WriteByte(tailByte)
	r.ready = ChunkSize + chacha20poly1305.Overhead
	return nil
}
// DecryptReaderAt provides random access to the plaintext of an encrypted
// payload stored at the start of src. Each ReadAt decrypts and authenticates
// the chunks it needs, keeping the most recent one in a single-entry cache.
type DecryptReaderAt struct {
	a   cipher.AEAD
	src io.ReaderAt

	// size is the plaintext size; chunks is the number of encrypted chunks.
	size   int64
	chunks int64

	// cache is the most recently decrypted chunk. It is accessed atomically
	// so concurrent ReadAt calls can share it; stored data is never modified.
	cache atomic.Pointer[cachedChunk]
}
// cachedChunk is an authenticated, decrypted chunk together with the offset
// of its ciphertext in the encrypted payload.
type cachedChunk struct {
	off  int64  // offset of the encrypted chunk in src
	data []byte // decrypted plaintext of that chunk
}
// NewDecryptReaderAt returns a DecryptReaderAt for the encrypted payload that
// occupies bytes [0, size) of src, decrypted with key.
//
// It validates size and decrypts and authenticates the final chunk up front.
// A payload whose size does not match its contents is therefore rejected here
// rather than on a later ReadAt. The decrypted final chunk becomes the initial
// cache entry.
func NewDecryptReaderAt(key []byte, src io.ReaderAt, size int64) (*DecryptReaderAt, error) {
	aead, err := chacha20poly1305.New(key)
	if err != nil {
		return nil, err
	}

	// Check that size is valid by decrypting the final chunk.
	chunks, err := EncryptedChunkCount(size)
	if err != nil {
		return nil, err
	}
	finalChunkIndex := chunks - 1
	finalChunkOff := finalChunkIndex * encChunkSize
	finalChunkSize := size - finalChunkOff
	finalChunk := make([]byte, finalChunkSize)

	// The final chunk ends exactly at the end of the payload. There,
	// io.ReaderAt implementations may return io.EOF together with a complete
	// read, so only a short read is an error.
	n, err := src.ReadAt(finalChunk, finalChunkOff)
	if err == io.EOF && int64(n) == finalChunkSize {
		err = nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to read final chunk: %w", err)
	}

	// The final chunk must authenticate with the last-chunk flag set.
	nonce := nonceForChunk(finalChunkIndex)
	setLastChunkFlag(nonce)
	plaintext, err := aead.Open(finalChunk[:0], nonce[:], finalChunk, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt and authenticate final chunk: %w", err)
	}

	plaintextSize := size - chunks*chacha20poly1305.Overhead
	r := &DecryptReaderAt{a: aead, src: src, size: plaintextSize, chunks: chunks}
	// Seed the cache with the already-decrypted final chunk.
	r.cache.Store(&cachedChunk{off: finalChunkOff, data: plaintext})
	return r, nil
}
// ReadAt implements io.ReaderAt over the plaintext. It decrypts and
// authenticates every chunk overlapping [off, off+len(p)), reusing the cached
// chunk when it matches. It returns io.EOF when the read reaches the end of
// the plaintext. Offsets outside [0, size] are rejected with an error.
//
// The ciphertext buffer is local to each call and the shared cache is only
// accessed atomically, so calls share no other mutable state.
func (r *DecryptReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
	if off < 0 || off > r.size {
		return 0, fmt.Errorf("offset out of range [0:%d]: %d", r.size, off)
	}
	if len(p) == 0 {
		return 0, nil
	}

	// Total size of the encrypted payload; invariant across the loop.
	encSize := r.size + r.chunks*chacha20poly1305.Overhead

	// cacheUpdate is the chunk to publish once the read completes: the last
	// chunk this call decrypted, or nil if the last chunk touched came from
	// the cache (which then stays in place).
	var cacheUpdate *cachedChunk
	// chunk holds ciphertext that is decrypted in place. It is allocated only
	// on a cache miss, so reads served from the cache don't allocate. It can't
	// be reused across calls because it may be published via cacheUpdate.
	var chunk []byte
	for len(p) > 0 && off < r.size {
		chunkIndex := off / ChunkSize
		chunkOff := chunkIndex * encChunkSize
		chunkSize := min(encSize-chunkOff, encChunkSize)

		var plaintext []byte
		if cached := r.cache.Load(); cached != nil && cached.off == chunkOff {
			plaintext = cached.data
			cacheUpdate = nil
		} else {
			if chunk == nil {
				chunk = make([]byte, encChunkSize)
			}
			nn, err := r.src.ReadAt(chunk[:chunkSize], chunkOff)
			// io.ReaderAt may report io.EOF alongside a complete read that
			// ends at the end of the input.
			if err == io.EOF {
				if int64(nn) != chunkSize {
					err = io.ErrUnexpectedEOF
				} else {
					err = nil
				}
			}
			if err != nil {
				return n, fmt.Errorf("failed to read chunk at offset %d: %w", chunkOff, err)
			}
			nonce := nonceForChunk(chunkIndex)
			if chunkIndex == r.chunks-1 {
				setLastChunkFlag(nonce)
			}
			plaintext, err = r.a.Open(chunk[:0], nonce[:], chunk[:chunkSize], nil)
			if err != nil {
				return n, fmt.Errorf("failed to decrypt and authenticate chunk at offset %d: %w", chunkOff, err)
			}
			cacheUpdate = &cachedChunk{off: chunkOff, data: plaintext}
		}

		plainChunkOff := int(off - chunkIndex*ChunkSize)
		copied := copy(p, plaintext[plainChunkOff:])
		p = p[copied:]
		off += int64(copied)
		n += copied
	}

	if cacheUpdate != nil {
		r.cache.Store(cacheUpdate)
	}
	if off == r.size {
		return n, io.EOF
	}
	return n, nil
}