internal: implement STREAM, key exchange, encryption and decryption

Developed live over 6 hours of streaming on Twitch.

https://twitter.com/FiloSottile/status/1180875486911766528
This commit is contained in:
Filippo Valsorda
2019-10-06 21:19:04 -04:00
parent 52dbe9eecf
commit e9c118cea0
13 changed files with 1097 additions and 4 deletions

2
go.mod
View File

@@ -1,3 +1,5 @@
module github.com/FiloSottile/age
go 1.12
require golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc

8
go.sum Normal file
View File

@@ -0,0 +1,8 @@
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

118
internal/age/age.go Normal file
View File

@@ -0,0 +1,118 @@
package age
import (
"crypto/hmac"
"crypto/rand"
"errors"
"fmt"
"io"
"github.com/FiloSottile/age/internal/format"
"github.com/FiloSottile/age/internal/stream"
)
type Identity interface {
Type() string
Unwrap(block *format.Recipient) (fileKey []byte, err error)
}
type Recipient interface {
Type() string
Wrap(fileKey []byte) (*format.Recipient, error)
}
func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {
if len(recipients) == 0 {
return nil, errors.New("no recipients specified")
}
fileKey := make([]byte, 16)
if _, err := rand.Read(fileKey); err != nil {
return nil, err
}
hdr := &format.Header{}
// TODO: remove the AEAD marker from v1.
hdr.AEAD = "ChaChaPoly"
for i, r := range recipients {
if r.Type() == "scrypt" && len(recipients) != 1 {
return nil, errors.New("an scrypt recipient must be the only one")
}
block, err := r.Wrap(fileKey)
if err != nil {
return nil, fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err)
}
hdr.Recipients = append(hdr.Recipients, block)
}
if mac, err := headerMAC(fileKey, hdr); err != nil {
return nil, fmt.Errorf("failed to compute header MAC: %v", err)
} else {
hdr.MAC = mac
}
if err := hdr.Marshal(dst); err != nil {
return nil, fmt.Errorf("failed to write header: %v", err)
}
nonce := make([]byte, 16)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
if _, err := dst.Write(nonce); err != nil {
return nil, fmt.Errorf("failed to write nonce: %v", err)
}
return stream.NewWriter(streamKey(fileKey, nonce), dst)
}
func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
if len(identities) == 0 {
return nil, errors.New("no identities specified")
}
hdr, payload, err := format.Parse(src)
if err != nil {
return nil, fmt.Errorf("failed to read header: %v", err)
}
if hdr.AEAD != "ChaChaPoly" {
return nil, fmt.Errorf("unsupported AEAD: %v", hdr.AEAD)
}
if len(hdr.Recipients) > 20 {
return nil, errors.New("too many recipients")
}
var fileKey []byte
RecipientsLoop:
for _, r := range hdr.Recipients {
if r.Type == "scrypt" && len(hdr.Recipients) != 1 {
return nil, errors.New("an scrypt recipient must be the only one")
}
for _, i := range identities {
if i.Type() != r.Type {
continue
}
fileKey, err = i.Unwrap(r)
if err == nil {
break RecipientsLoop
}
}
}
if fileKey == nil {
return nil, errors.New("no identity matched a recipient")
}
if mac, err := headerMAC(fileKey, hdr); err != nil {
return nil, fmt.Errorf("failed to compute header MAC: %v", err)
} else if !hmac.Equal(mac, hdr.MAC) {
return nil, errors.New("bad header MAC")
}
nonce := make([]byte, 16)
if _, err := io.ReadFull(payload, nonce); err != nil {
return nil, fmt.Errorf("failed to read nonce: %v", err)
}
return stream.NewReader(streamKey(fileKey, nonce), payload)
}

103
internal/age/age_test.go Normal file
View File

@@ -0,0 +1,103 @@
package age_test
import (
"bytes"
"crypto/rand"
"io"
"io/ioutil"
"testing"
"github.com/FiloSottile/age/internal/age"
"golang.org/x/crypto/curve25519"
)
const helloWorld = "Hello, Twitch!"
func TestEncryptDecryptX25519(t *testing.T) {
var secretKeyA, publicKeyA, secretKeyB, publicKeyB [32]byte
if _, err := rand.Read(secretKeyA[:]); err != nil {
t.Fatal(err)
}
if _, err := rand.Read(secretKeyB[:]); err != nil {
t.Fatal(err)
}
curve25519.ScalarBaseMult(&publicKeyA, &secretKeyA)
curve25519.ScalarBaseMult(&publicKeyB, &secretKeyB)
rA, err := age.NewX25519Recipient(publicKeyA[:])
if err != nil {
t.Fatal(err)
}
rB, err := age.NewX25519Recipient(publicKeyB[:])
if err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
w, err := age.Encrypt(buf, rA, rB)
if err != nil {
t.Fatal(err)
}
if _, err := io.WriteString(w, helloWorld); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
t.Logf("%s", buf.Bytes())
i, err := age.NewX25519Identity(secretKeyB[:])
if err != nil {
t.Fatal(err)
}
out, err := age.Decrypt(buf, i)
if err != nil {
t.Fatal(err)
}
outBytes, err := ioutil.ReadAll(out)
if err != nil {
t.Fatal(err)
}
if string(outBytes) != helloWorld {
t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld)
}
}
func TestEncryptDecryptScrypt(t *testing.T) {
password := "twitch.tv/filosottile"
r, err := age.NewScryptRecipient(password)
if err != nil {
t.Fatal(err)
}
r.SetWorkFactor(1 << 15)
buf := &bytes.Buffer{}
w, err := age.Encrypt(buf, r)
if err != nil {
t.Fatal(err)
}
if _, err := io.WriteString(w, helloWorld); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
t.Logf("%s", buf.Bytes())
i, err := age.NewScryptIdentity(password)
if err != nil {
t.Fatal(err)
}
out, err := age.Decrypt(buf, i)
if err != nil {
t.Fatal(err)
}
outBytes, err := ioutil.ReadAll(out)
if err != nil {
t.Fatal(err)
}
if string(outBytes) != helloWorld {
t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld)
}
}

View File

@@ -0,0 +1,51 @@
package age
import (
"crypto/hmac"
"crypto/sha256"
"io"
"github.com/FiloSottile/age/internal/format"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/hkdf"
)
func aeadEncrypt(key, plaintext []byte) ([]byte, error) {
aead, err := chacha20poly1305.New(key)
if err != nil {
return nil, err
}
nonce := make([]byte, chacha20poly1305.NonceSize)
return aead.Seal(nil, nonce, plaintext, nil), nil
}
func aeadDecrypt(key, ciphertext []byte) ([]byte, error) {
aead, err := chacha20poly1305.New(key)
if err != nil {
return nil, err
}
nonce := make([]byte, chacha20poly1305.NonceSize)
return aead.Open(nil, nonce, ciphertext, nil)
}
func headerMAC(fileKey []byte, hdr *format.Header) ([]byte, error) {
h := hkdf.New(sha256.New, fileKey, nil, []byte("header"))
hmacKey := make([]byte, 32)
if _, err := io.ReadFull(h, hmacKey); err != nil {
return nil, err
}
hh := hmac.New(sha256.New, hmacKey)
if err := hdr.MarshalWithoutMAC(hh); err != nil {
return nil, err
}
return hh.Sum(nil), nil
}
func streamKey(fileKey, nonce []byte) []byte {
h := hkdf.New(sha256.New, fileKey, nonce, []byte("payload"))
streamKey := make([]byte, chacha20poly1305.KeySize)
if _, err := io.ReadFull(h, streamKey); err != nil {
panic("age: internal error: failed to read from HKDF: " + err.Error())
}
return streamKey
}

View File

@@ -0,0 +1,131 @@
package age_test
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"testing"
"github.com/FiloSottile/age/internal/age"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/ssh"
)
func TestX25519RoundTrip(t *testing.T) {
var secretKey, publicKey, fileKey [32]byte
if _, err := rand.Read(secretKey[:]); err != nil {
t.Fatal(err)
}
if _, err := rand.Read(fileKey[:]); err != nil {
t.Fatal(err)
}
curve25519.ScalarBaseMult(&publicKey, &secretKey)
r, err := age.NewX25519Recipient(publicKey[:])
if err != nil {
t.Fatal(err)
}
i, err := age.NewX25519Identity(secretKey[:])
if err != nil {
t.Fatal(err)
}
if r.Type() != i.Type() || r.Type() != "X25519" {
t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
}
block, err := r.Wrap(fileKey[:])
if err != nil {
t.Fatal(err)
}
t.Logf("%#v", block)
out, err := i.Unwrap(block)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(fileKey[:], out) {
t.Errorf("invalid output: %x, expected %x", out, fileKey[:])
}
}
func TestScryptRoundTrip(t *testing.T) {
password := "twitch.tv/filosottile"
r, err := age.NewScryptRecipient(password)
if err != nil {
t.Fatal(err)
}
r.SetWorkFactor(1 << 15)
i, err := age.NewScryptIdentity(password)
if err != nil {
t.Fatal(err)
}
if r.Type() != i.Type() || r.Type() != "scrypt" {
t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
}
fileKey := make([]byte, 16)
if _, err := rand.Read(fileKey[:]); err != nil {
t.Fatal(err)
}
block, err := r.Wrap(fileKey[:])
if err != nil {
t.Fatal(err)
}
t.Logf("%#v", block)
out, err := i.Unwrap(block)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(fileKey[:], out) {
t.Errorf("invalid output: %x, expected %x", out, fileKey[:])
}
}
func TestSSHRSARoundTrip(t *testing.T) {
pk, err := rsa.GenerateKey(rand.Reader, 768)
if err != nil {
t.Fatal(err)
}
pub, err := ssh.NewPublicKey(&pk.PublicKey)
if err != nil {
t.Fatal(err)
}
r, err := age.NewSSHRSARecipient(pub)
if err != nil {
t.Fatal(err)
}
i, err := age.NewSSHRSAIdentity(pk)
if err != nil {
t.Fatal(err)
}
if r.Type() != i.Type() || r.Type() != "ssh-rsa" {
t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
}
fileKey := make([]byte, 16)
if _, err := rand.Read(fileKey[:]); err != nil {
t.Fatal(err)
}
block, err := r.Wrap(fileKey[:])
if err != nil {
t.Fatal(err)
}
t.Logf("%#v", block)
out, err := i.Unwrap(block)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(fileKey[:], out) {
t.Errorf("invalid output: %x, expected %x", out, fileKey[:])
}
}

125
internal/age/scrypt.go Normal file
View File

@@ -0,0 +1,125 @@
package age
import (
"crypto/rand"
"errors"
"fmt"
"strconv"
"github.com/FiloSottile/age/internal/format"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/scrypt"
)
type ScryptRecipient struct {
password []byte
workFactor int
}
var _ Recipient = &ScryptRecipient{}
func (*ScryptRecipient) Type() string { return "scrypt" }
func NewScryptRecipient(password string) (*ScryptRecipient, error) {
if len(password) == 0 {
return nil, errors.New("empty scrypt password")
}
r := &ScryptRecipient{
password: []byte(password),
workFactor: 1 << 18, // 1s on a modern machine
}
return r, nil
}
func (r *ScryptRecipient) SetWorkFactor(N int) {
// TODO: automatically scale this to 1s (with a min) in the CLI.
r.workFactor = N
}
func (r *ScryptRecipient) Wrap(fileKey []byte) (*format.Recipient, error) {
salt := make([]byte, 16)
if _, err := rand.Read(salt[:]); err != nil {
return nil, err
}
N := r.workFactor
l := &format.Recipient{
Type: "scrypt",
Args: []string{format.EncodeToString(salt), strconv.Itoa(N)},
}
k, err := scrypt.Key(r.password, salt, N, 8, 1, chacha20poly1305.KeySize)
if err != nil {
return nil, fmt.Errorf("failed to generate scrypt hash: %v", err)
}
wrappedKey, err := aeadEncrypt(k, fileKey)
if err != nil {
return nil, err
}
l.Body = []byte(format.EncodeToString(wrappedKey) + "\n")
return l, nil
}
type ScryptIdentity struct {
password []byte
maxWorkFactor int
}
var _ Identity = &ScryptIdentity{}
func (*ScryptIdentity) Type() string { return "scrypt" }
func NewScryptIdentity(password string) (*ScryptIdentity, error) {
if len(password) == 0 {
return nil, errors.New("empty scrypt password")
}
i := &ScryptIdentity{
password: []byte(password),
maxWorkFactor: 1 << 22, // 15s on a modern machine
}
return i, nil
}
func (i *ScryptIdentity) SetMaxWorkFactor(N int) {
i.maxWorkFactor = N
}
func (i *ScryptIdentity) Unwrap(block *format.Recipient) ([]byte, error) {
if block.Type != "scrypt" {
return nil, errors.New("wrong recipient block type")
}
if len(block.Args) != 2 {
return nil, errors.New("invalid scrypt recipient block")
}
salt, err := format.DecodeString(block.Args[0])
if err != nil {
return nil, fmt.Errorf("failed to parse scrypt salt: %v", err)
}
if len(salt) != 16 {
return nil, errors.New("invalid scrypt recipient block")
}
N, err := strconv.Atoi(block.Args[1])
if err != nil {
return nil, fmt.Errorf("failed to parse scrypt work factor: %v", err)
}
if N > i.maxWorkFactor {
return nil, fmt.Errorf("scrypt work factor too large: %v", N)
}
wrappedKey, err := format.DecodeString(string(block.Body))
if err != nil {
return nil, fmt.Errorf("failed to parse scrypt recipient: %v", err)
}
k, err := scrypt.Key(i.password, salt, N, 8, 1, chacha20poly1305.KeySize)
if err != nil {
return nil, fmt.Errorf("failed to generate scrypt hash: %v", err)
}
fileKey, err := aeadDecrypt(k, wrappedKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt file key: %v", err)
}
return fileKey, nil
}

118
internal/age/ssh.go Normal file
View File

@@ -0,0 +1,118 @@
package age
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"errors"
"fmt"
"github.com/FiloSottile/age/internal/format"
"golang.org/x/crypto/ssh"
)
const oaepLabel = "age-tool.com ssh-rsa"
type SSHRSARecipient struct {
sshKey ssh.PublicKey
pubKey *rsa.PublicKey
}
var _ Recipient = &SSHRSARecipient{}
func (*SSHRSARecipient) Type() string { return "ssh-rsa" }
func NewSSHRSARecipient(pk ssh.PublicKey) (*SSHRSARecipient, error) {
if pk.Type() != "ssh-rsa" {
return nil, errors.New("SSH public key is not an RSA key")
}
r := &SSHRSARecipient{
sshKey: pk,
}
if pk, ok := pk.(ssh.CryptoPublicKey); ok {
if pk, ok := pk.CryptoPublicKey().(*rsa.PublicKey); ok {
r.pubKey = pk
} else {
return nil, errors.New("unexpected public key type")
}
} else {
return nil, errors.New("pk does not implement ssh.CryptoPublicKey")
}
return r, nil
}
func (r *SSHRSARecipient) Wrap(fileKey []byte) (*format.Recipient, error) {
h := sha256.New()
h.Write(r.sshKey.Marshal())
hh := h.Sum(nil)
l := &format.Recipient{
Type: "ssh-rsa",
Args: []string{format.EncodeToString(hh[:4])},
}
wrappedKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader,
r.pubKey, fileKey, []byte(oaepLabel))
if err != nil {
return nil, err
}
l.Body = []byte(format.EncodeToString(wrappedKey) + "\n")
return l, nil
}
type SSHRSAIdentity struct {
k *rsa.PrivateKey
sshKey ssh.PublicKey
}
var _ Identity = &SSHRSAIdentity{}
func (*SSHRSAIdentity) Type() string { return "ssh-rsa" }
func NewSSHRSAIdentity(key *rsa.PrivateKey) (*SSHRSAIdentity, error) {
s, err := ssh.NewSignerFromKey(key)
if err != nil {
return nil, err
}
i := &SSHRSAIdentity{
k: key, sshKey: s.PublicKey(),
}
return i, nil
}
func (i *SSHRSAIdentity) Unwrap(block *format.Recipient) ([]byte, error) {
if block.Type != "ssh-rsa" {
return nil, errors.New("wrong recipient block type")
}
if len(block.Args) != 1 {
return nil, errors.New("invalid ssh-rsa recipient block")
}
hash, err := format.DecodeString(block.Args[0])
if err != nil {
return nil, fmt.Errorf("failed to parse ssh-rsa recipient: %v", err)
}
if len(hash) != 4 {
return nil, errors.New("invalid ssh-rsa recipient block")
}
wrappedKey, err := format.DecodeString(string(block.Body))
if err != nil {
return nil, fmt.Errorf("failed to parse ssh-rsa recipient: %v", err)
}
h := sha256.New()
h.Write(i.sshKey.Marshal())
hh := h.Sum(nil)
if !bytes.Equal(hh[:4], hash) {
return nil, errors.New("wrong ssh-rsa key")
}
fileKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, i.k,
wrappedKey, []byte(oaepLabel))
if err != nil {
return nil, fmt.Errorf("failed to decrypt file key: %v", err)
}
return fileKey, nil
}

123
internal/age/x25519.go Normal file
View File

@@ -0,0 +1,123 @@
package age
import (
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"io"
"github.com/FiloSottile/age/internal/format"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/hkdf"
)
const x25519Label = "age-tool.com X25519"
type X25519Recipient struct {
theirPublicKey [32]byte
}
var _ Recipient = &X25519Recipient{}
func (*X25519Recipient) Type() string { return "X25519" }
func NewX25519Recipient(publicKey []byte) (*X25519Recipient, error) {
if len(publicKey) != 32 {
return nil, errors.New("invalid X25519 public key")
}
r := &X25519Recipient{}
copy(r.theirPublicKey[:], publicKey)
return r, nil
}
func (r *X25519Recipient) Wrap(fileKey []byte) (*format.Recipient, error) {
var ephemeral, ourPublicKey [32]byte
if _, err := rand.Read(ephemeral[:]); err != nil {
return nil, err
}
curve25519.ScalarBaseMult(&ourPublicKey, &ephemeral)
var sharedSecret [32]byte
curve25519.ScalarMult(&sharedSecret, &ephemeral, &r.theirPublicKey)
l := &format.Recipient{
Type: "X25519",
Args: []string{format.EncodeToString(ourPublicKey[:])},
}
salt := make([]byte, 0, 32*2)
salt = append(salt, ourPublicKey[:]...)
salt = append(salt, r.theirPublicKey[:]...)
h := hkdf.New(sha256.New, sharedSecret[:], salt, []byte(x25519Label))
wrappingKey := make([]byte, chacha20poly1305.KeySize)
if _, err := io.ReadFull(h, wrappingKey); err != nil {
return nil, err
}
wrappedKey, err := aeadEncrypt(wrappingKey, fileKey)
if err != nil {
return nil, err
}
l.Body = []byte(format.EncodeToString(wrappedKey) + "\n")
return l, nil
}
type X25519Identity struct {
secretKey, ourPublicKey [32]byte
}
var _ Identity = &X25519Identity{}
func (*X25519Identity) Type() string { return "X25519" }
func NewX25519Identity(secretKey []byte) (*X25519Identity, error) {
if len(secretKey) != 32 {
return nil, errors.New("invalid X25519 secret key")
}
i := &X25519Identity{}
copy(i.secretKey[:], secretKey)
curve25519.ScalarBaseMult(&i.ourPublicKey, &i.secretKey)
return i, nil
}
func (i *X25519Identity) Unwrap(block *format.Recipient) ([]byte, error) {
if block.Type != "X25519" {
return nil, errors.New("wrong recipient block type")
}
if len(block.Args) != 1 {
return nil, errors.New("invalid X25519 recipient block")
}
publicKey, err := format.DecodeString(block.Args[0])
if err != nil {
return nil, fmt.Errorf("failed to parse X25519 recipient: %v", err)
}
if len(publicKey) != 32 {
return nil, errors.New("invalid X25519 recipient block")
}
wrappedKey, err := format.DecodeString(string(block.Body))
if err != nil {
return nil, fmt.Errorf("failed to parse X25519 recipient: %v", err)
}
var sharedSecret, theirPublicKey [32]byte
copy(theirPublicKey[:], publicKey)
curve25519.ScalarMult(&sharedSecret, &i.secretKey, &theirPublicKey)
salt := make([]byte, 0, 32*2)
salt = append(salt, theirPublicKey[:]...)
salt = append(salt, i.ourPublicKey[:]...)
h := hkdf.New(sha256.New, sharedSecret[:], salt, []byte(x25519Label))
wrappingKey := make([]byte, chacha20poly1305.KeySize)
if _, err := io.ReadFull(h, wrappingKey); err != nil {
return nil, err
}
fileKey, err := aeadDecrypt(wrappingKey, wrappedKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt file key: %v", err)
}
return fileKey, nil
}

View File

@@ -24,7 +24,7 @@ type Recipient struct {
var b64 = base64.RawURLEncoding.Strict()
func decodeString(s string) ([]byte, error) {
func DecodeString(s string) ([]byte, error) {
// CR and LF are ignored by DecodeString. LF is handled by the parser,
// but CR can introduce malleability.
if strings.Contains(s, "\r") {
@@ -33,12 +33,14 @@ func decodeString(s string) ([]byte, error) {
return b64.DecodeString(s)
}
var EncodeToString = b64.EncodeToString // TODO: wrap lines
const intro = "This is a file encrypted with age-tool.com, version 1\n"
var recipientPrefix = []byte("->")
var footerPrefix = []byte("---")
func (h *Header) Marshal(w io.Writer) error {
func (h *Header) MarshalWithoutMAC(w io.Writer) error {
if _, err := io.WriteString(w, intro); err != nil {
return err
}
@@ -54,12 +56,21 @@ func (h *Header) Marshal(w io.Writer) error {
if _, err := io.WriteString(w, "\n"); err != nil {
return err
}
// TODO: check that Body ends with a newline.
if _, err := w.Write(r.Body); err != nil {
return err
}
}
_, err := fmt.Fprintf(w, "%s %s", footerPrefix, h.AEAD)
return err
}
func (h *Header) Marshal(w io.Writer) error {
if err := h.MarshalWithoutMAC(w); err != nil {
return err
}
mac := b64.EncodeToString(h.MAC)
_, err := fmt.Fprintf(w, "%s %s %s\n", footerPrefix, h.AEAD, mac)
_, err := fmt.Fprintf(w, " %s\n", mac)
return err
}
@@ -100,7 +111,7 @@ func Parse(input io.Reader) (*Header, io.Reader, error) {
return nil, nil, errorf("malformed closing line: %q", line)
}
h.AEAD = args[0]
h.MAC, err = decodeString(args[1])
h.MAC, err = DecodeString(args[1])
if err != nil {
return nil, nil, errorf("malformed closing line %q: %v", line, err)
}

View File

@@ -1,3 +1,5 @@
// +build gofuzz
package format
import (

208
internal/stream/stream.go Normal file
View File

@@ -0,0 +1,208 @@
package stream
import (
"crypto/cipher"
"errors"
"io"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
)
const ChunkSize = 64 * 1024
type Reader struct {
a cipher.AEAD
src io.Reader
unread []byte // decrypted but unread data, backed by buf
buf [encChunkSize]byte
err error
nonce [chacha20poly1305.NonceSize]byte
}
const (
encChunkSize = ChunkSize + poly1305.TagSize
lastChunkFlag = 0x01
)
func NewReader(key []byte, src io.Reader) (*Reader, error) {
aead, err := chacha20poly1305.New(key)
if err != nil {
return nil, err
}
return &Reader{
a: aead,
src: src,
}, nil
}
func (r *Reader) Read(p []byte) (int, error) {
if len(r.unread) > 0 {
n := copy(p, r.unread)
r.unread = r.unread[n:]
return n, nil
}
if r.err != nil {
return 0, r.err
}
if len(p) == 0 {
return 0, nil
}
last, err := r.readChunk()
if err != nil {
r.err = err
return 0, err
}
n := copy(p, r.unread)
r.unread = r.unread[n:]
if last {
r.err = io.EOF
}
return n, nil
}
// readChunk reads the next chunk of ciphertext from r.c and makes in available
// in r.unread. last is true if the chunk was marked as the end of the message.
// readChunk must not be called again after returning a last chunk or an error.
func (r *Reader) readChunk() (last bool, err error) {
if len(r.unread) != 0 {
panic("stream: internal error: readChunk called with dirty buffer")
}
in := r.buf[:]
n, err := io.ReadFull(r.src, in)
switch {
case err == io.EOF:
// A message can't end without a marked chunk. This message is truncated.
return false, io.ErrUnexpectedEOF
case err == io.ErrUnexpectedEOF:
// The last chunk can be short.
in = in[:n]
last = true
setLastChunkFlag(&r.nonce)
case err != nil:
return false, err
}
outBuf := make([]byte, 0, ChunkSize)
out, err := r.a.Open(outBuf, r.nonce[:], in, nil)
if err != nil && !last {
// Check if this was a full-length final chunk.
last = true
setLastChunkFlag(&r.nonce)
out, err = r.a.Open(outBuf, r.nonce[:], in, nil)
}
if err != nil {
return false, err
}
incNonce(&r.nonce)
r.unread = r.buf[:copy(r.buf[:], out)]
return last, nil
}
func incNonce(nonce *[chacha20poly1305.NonceSize]byte) {
for i := len(nonce) - 2; i >= 0; i-- {
nonce[i]++
if nonce[i] != 0 {
break
} else if i == 0 {
// The counter is 88 bits, this is unreachable.
panic("stream: chunk counter wrapped around")
}
}
}
func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
nonce[len(nonce)-1] = lastChunkFlag
}
type Writer struct {
a cipher.AEAD
dst io.Writer
unwritten []byte // backed by buf
buf [encChunkSize]byte
nonce [chacha20poly1305.NonceSize]byte
err error
}
func NewWriter(key []byte, dst io.Writer) (*Writer, error) {
aead, err := chacha20poly1305.New(key)
if err != nil {
return nil, err
}
w := &Writer{
a: aead,
dst: dst,
}
w.unwritten = w.buf[:0]
return w, nil
}
func (w *Writer) Write(p []byte) (n int, err error) {
// TODO: consider refactoring with a bytes.Buffer.
if w.err != nil {
return 0, w.err
}
if len(p) == 0 {
return 0, nil
}
total := len(p)
for len(p) > 0 {
free := ChunkSize - len(w.unwritten)
freeBuf := w.buf[len(w.unwritten) : len(w.unwritten)+free]
n := copy(freeBuf, p)
p = p[n:]
w.unwritten = w.unwritten[:len(w.unwritten)+n]
if len(w.unwritten) == ChunkSize && len(p) > 0 {
if err := w.flushChunk(notLastChunk); err != nil {
w.err = err
return 0, err
}
}
}
return total, nil
}
func (w *Writer) Close() error {
// TODO: close w.dst if it can be interface upgraded to io.Closer.
if w.err != nil {
return w.err
}
err := w.flushChunk(lastChunk)
if err != nil {
w.err = err
} else {
w.err = errors.New("stream.Writer is already closed")
}
return err
}
const (
lastChunk = true
notLastChunk = false
)
func (w *Writer) flushChunk(last bool) error {
if !last && len(w.unwritten) != ChunkSize {
panic("stream: internal error: flush called with partial chunk")
}
if last {
setLastChunkFlag(&w.nonce)
}
buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil)
_, err := w.dst.Write(buf)
w.unwritten = w.buf[:0]
incNonce(&w.nonce)
return err
}

View File

@@ -0,0 +1,93 @@
package stream_test
import (
"bytes"
"crypto/rand"
"fmt"
"testing"
"github.com/FiloSottile/age/internal/stream"
"golang.org/x/crypto/chacha20poly1305"
)
const cs = stream.ChunkSize
func TestRoundTrip(t *testing.T) {
for _, stepSize := range []int{512, 600, 1000, cs} {
for _, length := range []int{0, 1000, cs, cs + 100} {
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
}
}
}
func testRoundTrip(t *testing.T, stepSize, length int) {
src := make([]byte, length)
if _, err := rand.Read(src); err != nil {
t.Fatal(err)
}
buf := &bytes.Buffer{}
key := make([]byte, chacha20poly1305.KeySize)
if _, err := rand.Read(key); err != nil {
t.Fatal(err)
}
w, err := stream.NewWriter(key, buf)
if err != nil {
t.Fatal(err)
}
var n int
for n < length {
b := length - n
if b > stepSize {
b = stepSize
}
nn, err := w.Write(src[n : n+b])
if err != nil {
t.Fatal(err)
}
if nn != b {
t.Errorf("Write returned %d, expected %d", nn, b)
}
n += nn
nn, err = w.Write(src[n:n])
if err != nil {
t.Fatal(err)
}
if nn != 0 {
t.Errorf("Write returned %d, expected 0", nn)
}
}
if err := w.Close(); err != nil {
t.Error("Close returned an error:", err)
}
t.Logf("buffer size: %d", buf.Len())
r, err := stream.NewReader(key, buf)
if err != nil {
t.Fatal(err)
}
n = 0
readBuf := make([]byte, stepSize)
for n < length {
b := length - n
if b > stepSize {
b = stepSize
}
nn, err := r.Read(readBuf)
if err != nil {
t.Fatalf("Read error at index %d: %v", n, err)
}
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
}
n += nn
}
}