mirror of
https://github.com/FiloSottile/age.git
synced 2026-01-10 05:37:20 +00:00
internal: implement STREAM, key exchange, encryption and decryption
Developed live over 6 hours of streaming on Twitch. https://twitter.com/FiloSottile/status/1180875486911766528
This commit is contained in:
2
go.mod
2
go.mod
@@ -1,3 +1,5 @@
|
||||
module github.com/FiloSottile/age
|
||||
|
||||
go 1.12
|
||||
|
||||
require golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
|
||||
|
||||
8
go.sum
Normal file
8
go.sum
Normal file
@@ -0,0 +1,8 @@
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
|
||||
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
118
internal/age/age.go
Normal file
118
internal/age/age.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package age
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/FiloSottile/age/internal/format"
|
||||
"github.com/FiloSottile/age/internal/stream"
|
||||
)
|
||||
|
||||
type Identity interface {
|
||||
Type() string
|
||||
Unwrap(block *format.Recipient) (fileKey []byte, err error)
|
||||
}
|
||||
|
||||
type Recipient interface {
|
||||
Type() string
|
||||
Wrap(fileKey []byte) (*format.Recipient, error)
|
||||
}
|
||||
|
||||
func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {
|
||||
if len(recipients) == 0 {
|
||||
return nil, errors.New("no recipients specified")
|
||||
}
|
||||
|
||||
fileKey := make([]byte, 16)
|
||||
if _, err := rand.Read(fileKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hdr := &format.Header{}
|
||||
// TODO: remove the AEAD marker from v1.
|
||||
hdr.AEAD = "ChaChaPoly"
|
||||
for i, r := range recipients {
|
||||
if r.Type() == "scrypt" && len(recipients) != 1 {
|
||||
return nil, errors.New("an scrypt recipient must be the only one")
|
||||
}
|
||||
|
||||
block, err := r.Wrap(fileKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err)
|
||||
}
|
||||
hdr.Recipients = append(hdr.Recipients, block)
|
||||
}
|
||||
if mac, err := headerMAC(fileKey, hdr); err != nil {
|
||||
return nil, fmt.Errorf("failed to compute header MAC: %v", err)
|
||||
} else {
|
||||
hdr.MAC = mac
|
||||
}
|
||||
if err := hdr.Marshal(dst); err != nil {
|
||||
return nil, fmt.Errorf("failed to write header: %v", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, 16)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := dst.Write(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to write nonce: %v", err)
|
||||
}
|
||||
|
||||
return stream.NewWriter(streamKey(fileKey, nonce), dst)
|
||||
}
|
||||
|
||||
func Decrypt(src io.Reader, identities ...Identity) (io.Reader, error) {
|
||||
if len(identities) == 0 {
|
||||
return nil, errors.New("no identities specified")
|
||||
}
|
||||
|
||||
hdr, payload, err := format.Parse(src)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read header: %v", err)
|
||||
}
|
||||
if hdr.AEAD != "ChaChaPoly" {
|
||||
return nil, fmt.Errorf("unsupported AEAD: %v", hdr.AEAD)
|
||||
}
|
||||
if len(hdr.Recipients) > 20 {
|
||||
return nil, errors.New("too many recipients")
|
||||
}
|
||||
|
||||
var fileKey []byte
|
||||
RecipientsLoop:
|
||||
for _, r := range hdr.Recipients {
|
||||
if r.Type == "scrypt" && len(hdr.Recipients) != 1 {
|
||||
return nil, errors.New("an scrypt recipient must be the only one")
|
||||
}
|
||||
for _, i := range identities {
|
||||
|
||||
if i.Type() != r.Type {
|
||||
continue
|
||||
}
|
||||
|
||||
fileKey, err = i.Unwrap(r)
|
||||
if err == nil {
|
||||
break RecipientsLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
if fileKey == nil {
|
||||
return nil, errors.New("no identity matched a recipient")
|
||||
}
|
||||
|
||||
if mac, err := headerMAC(fileKey, hdr); err != nil {
|
||||
return nil, fmt.Errorf("failed to compute header MAC: %v", err)
|
||||
} else if !hmac.Equal(mac, hdr.MAC) {
|
||||
return nil, errors.New("bad header MAC")
|
||||
}
|
||||
|
||||
nonce := make([]byte, 16)
|
||||
if _, err := io.ReadFull(payload, nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to read nonce: %v", err)
|
||||
}
|
||||
|
||||
return stream.NewReader(streamKey(fileKey, nonce), payload)
|
||||
}
|
||||
103
internal/age/age_test.go
Normal file
103
internal/age/age_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package age_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/FiloSottile/age/internal/age"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
const helloWorld = "Hello, Twitch!"
|
||||
|
||||
func TestEncryptDecryptX25519(t *testing.T) {
|
||||
var secretKeyA, publicKeyA, secretKeyB, publicKeyB [32]byte
|
||||
if _, err := rand.Read(secretKeyA[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := rand.Read(secretKeyB[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
curve25519.ScalarBaseMult(&publicKeyA, &secretKeyA)
|
||||
curve25519.ScalarBaseMult(&publicKeyB, &secretKeyB)
|
||||
|
||||
rA, err := age.NewX25519Recipient(publicKeyA[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rB, err := age.NewX25519Recipient(publicKeyB[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := age.Encrypt(buf, rA, rB)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := io.WriteString(w, helloWorld); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("%s", buf.Bytes())
|
||||
|
||||
i, err := age.NewX25519Identity(secretKeyB[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out, err := age.Decrypt(buf, i)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
outBytes, err := ioutil.ReadAll(out)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(outBytes) != helloWorld {
|
||||
t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptScrypt(t *testing.T) {
|
||||
password := "twitch.tv/filosottile"
|
||||
|
||||
r, err := age.NewScryptRecipient(password)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r.SetWorkFactor(1 << 15)
|
||||
buf := &bytes.Buffer{}
|
||||
w, err := age.Encrypt(buf, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := io.WriteString(w, helloWorld); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("%s", buf.Bytes())
|
||||
|
||||
i, err := age.NewScryptIdentity(password)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out, err := age.Decrypt(buf, i)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
outBytes, err := ioutil.ReadAll(out)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(outBytes) != helloWorld {
|
||||
t.Errorf("wrong data: %q, excepted %q", outBytes, helloWorld)
|
||||
}
|
||||
}
|
||||
51
internal/age/primitives.go
Normal file
51
internal/age/primitives.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package age
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
|
||||
"github.com/FiloSottile/age/internal/format"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
func aeadEncrypt(key, plaintext []byte) ([]byte, error) {
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonce := make([]byte, chacha20poly1305.NonceSize)
|
||||
return aead.Seal(nil, nonce, plaintext, nil), nil
|
||||
}
|
||||
|
||||
func aeadDecrypt(key, ciphertext []byte) ([]byte, error) {
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonce := make([]byte, chacha20poly1305.NonceSize)
|
||||
return aead.Open(nil, nonce, ciphertext, nil)
|
||||
}
|
||||
|
||||
func headerMAC(fileKey []byte, hdr *format.Header) ([]byte, error) {
|
||||
h := hkdf.New(sha256.New, fileKey, nil, []byte("header"))
|
||||
hmacKey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(h, hmacKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hh := hmac.New(sha256.New, hmacKey)
|
||||
if err := hdr.MarshalWithoutMAC(hh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hh.Sum(nil), nil
|
||||
}
|
||||
|
||||
func streamKey(fileKey, nonce []byte) []byte {
|
||||
h := hkdf.New(sha256.New, fileKey, nonce, []byte("payload"))
|
||||
streamKey := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := io.ReadFull(h, streamKey); err != nil {
|
||||
panic("age: internal error: failed to read from HKDF: " + err.Error())
|
||||
}
|
||||
return streamKey
|
||||
}
|
||||
131
internal/age/recipients_test.go
Normal file
131
internal/age/recipients_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package age_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"testing"
|
||||
|
||||
"github.com/FiloSottile/age/internal/age"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestX25519RoundTrip(t *testing.T) {
|
||||
var secretKey, publicKey, fileKey [32]byte
|
||||
if _, err := rand.Read(secretKey[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := rand.Read(fileKey[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
curve25519.ScalarBaseMult(&publicKey, &secretKey)
|
||||
|
||||
r, err := age.NewX25519Recipient(publicKey[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
i, err := age.NewX25519Identity(secretKey[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if r.Type() != i.Type() || r.Type() != "X25519" {
|
||||
t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
|
||||
}
|
||||
|
||||
block, err := r.Wrap(fileKey[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("%#v", block)
|
||||
|
||||
out, err := i.Unwrap(block)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(fileKey[:], out) {
|
||||
t.Errorf("invalid output: %x, expected %x", out, fileKey[:])
|
||||
}
|
||||
}
|
||||
|
||||
func TestScryptRoundTrip(t *testing.T) {
|
||||
password := "twitch.tv/filosottile"
|
||||
|
||||
r, err := age.NewScryptRecipient(password)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r.SetWorkFactor(1 << 15)
|
||||
i, err := age.NewScryptIdentity(password)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if r.Type() != i.Type() || r.Type() != "scrypt" {
|
||||
t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
|
||||
}
|
||||
|
||||
fileKey := make([]byte, 16)
|
||||
if _, err := rand.Read(fileKey[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
block, err := r.Wrap(fileKey[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("%#v", block)
|
||||
|
||||
out, err := i.Unwrap(block)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(fileKey[:], out) {
|
||||
t.Errorf("invalid output: %x, expected %x", out, fileKey[:])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHRSARoundTrip(t *testing.T) {
|
||||
pk, err := rsa.GenerateKey(rand.Reader, 768)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pub, err := ssh.NewPublicKey(&pk.PublicKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := age.NewSSHRSARecipient(pub)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
i, err := age.NewSSHRSAIdentity(pk)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if r.Type() != i.Type() || r.Type() != "ssh-rsa" {
|
||||
t.Errorf("invalid Type values: %v, %v", r.Type(), i.Type())
|
||||
}
|
||||
|
||||
fileKey := make([]byte, 16)
|
||||
if _, err := rand.Read(fileKey[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
block, err := r.Wrap(fileKey[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("%#v", block)
|
||||
|
||||
out, err := i.Unwrap(block)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(fileKey[:], out) {
|
||||
t.Errorf("invalid output: %x, expected %x", out, fileKey[:])
|
||||
}
|
||||
}
|
||||
125
internal/age/scrypt.go
Normal file
125
internal/age/scrypt.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package age
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/FiloSottile/age/internal/format"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
type ScryptRecipient struct {
|
||||
password []byte
|
||||
workFactor int
|
||||
}
|
||||
|
||||
var _ Recipient = &ScryptRecipient{}
|
||||
|
||||
func (*ScryptRecipient) Type() string { return "scrypt" }
|
||||
|
||||
func NewScryptRecipient(password string) (*ScryptRecipient, error) {
|
||||
if len(password) == 0 {
|
||||
return nil, errors.New("empty scrypt password")
|
||||
}
|
||||
r := &ScryptRecipient{
|
||||
password: []byte(password),
|
||||
workFactor: 1 << 18, // 1s on a modern machine
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *ScryptRecipient) SetWorkFactor(N int) {
|
||||
// TODO: automatically scale this to 1s (with a min) in the CLI.
|
||||
r.workFactor = N
|
||||
}
|
||||
|
||||
func (r *ScryptRecipient) Wrap(fileKey []byte) (*format.Recipient, error) {
|
||||
salt := make([]byte, 16)
|
||||
if _, err := rand.Read(salt[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
N := r.workFactor
|
||||
l := &format.Recipient{
|
||||
Type: "scrypt",
|
||||
Args: []string{format.EncodeToString(salt), strconv.Itoa(N)},
|
||||
}
|
||||
|
||||
k, err := scrypt.Key(r.password, salt, N, 8, 1, chacha20poly1305.KeySize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate scrypt hash: %v", err)
|
||||
}
|
||||
|
||||
wrappedKey, err := aeadEncrypt(k, fileKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Body = []byte(format.EncodeToString(wrappedKey) + "\n")
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
type ScryptIdentity struct {
|
||||
password []byte
|
||||
maxWorkFactor int
|
||||
}
|
||||
|
||||
var _ Identity = &ScryptIdentity{}
|
||||
|
||||
func (*ScryptIdentity) Type() string { return "scrypt" }
|
||||
|
||||
func NewScryptIdentity(password string) (*ScryptIdentity, error) {
|
||||
if len(password) == 0 {
|
||||
return nil, errors.New("empty scrypt password")
|
||||
}
|
||||
i := &ScryptIdentity{
|
||||
password: []byte(password),
|
||||
maxWorkFactor: 1 << 22, // 15s on a modern machine
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (i *ScryptIdentity) SetMaxWorkFactor(N int) {
|
||||
i.maxWorkFactor = N
|
||||
}
|
||||
|
||||
func (i *ScryptIdentity) Unwrap(block *format.Recipient) ([]byte, error) {
|
||||
if block.Type != "scrypt" {
|
||||
return nil, errors.New("wrong recipient block type")
|
||||
}
|
||||
if len(block.Args) != 2 {
|
||||
return nil, errors.New("invalid scrypt recipient block")
|
||||
}
|
||||
salt, err := format.DecodeString(block.Args[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scrypt salt: %v", err)
|
||||
}
|
||||
if len(salt) != 16 {
|
||||
return nil, errors.New("invalid scrypt recipient block")
|
||||
}
|
||||
N, err := strconv.Atoi(block.Args[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scrypt work factor: %v", err)
|
||||
}
|
||||
if N > i.maxWorkFactor {
|
||||
return nil, fmt.Errorf("scrypt work factor too large: %v", N)
|
||||
}
|
||||
wrappedKey, err := format.DecodeString(string(block.Body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scrypt recipient: %v", err)
|
||||
}
|
||||
|
||||
k, err := scrypt.Key(i.password, salt, N, 8, 1, chacha20poly1305.KeySize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate scrypt hash: %v", err)
|
||||
}
|
||||
|
||||
fileKey, err := aeadDecrypt(k, wrappedKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt file key: %v", err)
|
||||
}
|
||||
return fileKey, nil
|
||||
}
|
||||
118
internal/age/ssh.go
Normal file
118
internal/age/ssh.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package age
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/FiloSottile/age/internal/format"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const oaepLabel = "age-tool.com ssh-rsa"
|
||||
|
||||
type SSHRSARecipient struct {
|
||||
sshKey ssh.PublicKey
|
||||
pubKey *rsa.PublicKey
|
||||
}
|
||||
|
||||
var _ Recipient = &SSHRSARecipient{}
|
||||
|
||||
func (*SSHRSARecipient) Type() string { return "ssh-rsa" }
|
||||
|
||||
func NewSSHRSARecipient(pk ssh.PublicKey) (*SSHRSARecipient, error) {
|
||||
if pk.Type() != "ssh-rsa" {
|
||||
return nil, errors.New("SSH public key is not an RSA key")
|
||||
}
|
||||
r := &SSHRSARecipient{
|
||||
sshKey: pk,
|
||||
}
|
||||
|
||||
if pk, ok := pk.(ssh.CryptoPublicKey); ok {
|
||||
if pk, ok := pk.CryptoPublicKey().(*rsa.PublicKey); ok {
|
||||
r.pubKey = pk
|
||||
} else {
|
||||
return nil, errors.New("unexpected public key type")
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("pk does not implement ssh.CryptoPublicKey")
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *SSHRSARecipient) Wrap(fileKey []byte) (*format.Recipient, error) {
|
||||
h := sha256.New()
|
||||
h.Write(r.sshKey.Marshal())
|
||||
hh := h.Sum(nil)
|
||||
|
||||
l := &format.Recipient{
|
||||
Type: "ssh-rsa",
|
||||
Args: []string{format.EncodeToString(hh[:4])},
|
||||
}
|
||||
|
||||
wrappedKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader,
|
||||
r.pubKey, fileKey, []byte(oaepLabel))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Body = []byte(format.EncodeToString(wrappedKey) + "\n")
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
type SSHRSAIdentity struct {
|
||||
k *rsa.PrivateKey
|
||||
sshKey ssh.PublicKey
|
||||
}
|
||||
|
||||
var _ Identity = &SSHRSAIdentity{}
|
||||
|
||||
func (*SSHRSAIdentity) Type() string { return "ssh-rsa" }
|
||||
|
||||
func NewSSHRSAIdentity(key *rsa.PrivateKey) (*SSHRSAIdentity, error) {
|
||||
s, err := ssh.NewSignerFromKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i := &SSHRSAIdentity{
|
||||
k: key, sshKey: s.PublicKey(),
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (i *SSHRSAIdentity) Unwrap(block *format.Recipient) ([]byte, error) {
|
||||
if block.Type != "ssh-rsa" {
|
||||
return nil, errors.New("wrong recipient block type")
|
||||
}
|
||||
if len(block.Args) != 1 {
|
||||
return nil, errors.New("invalid ssh-rsa recipient block")
|
||||
}
|
||||
hash, err := format.DecodeString(block.Args[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ssh-rsa recipient: %v", err)
|
||||
}
|
||||
if len(hash) != 4 {
|
||||
return nil, errors.New("invalid ssh-rsa recipient block")
|
||||
}
|
||||
wrappedKey, err := format.DecodeString(string(block.Body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ssh-rsa recipient: %v", err)
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
h.Write(i.sshKey.Marshal())
|
||||
hh := h.Sum(nil)
|
||||
if !bytes.Equal(hh[:4], hash) {
|
||||
return nil, errors.New("wrong ssh-rsa key")
|
||||
}
|
||||
|
||||
fileKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, i.k,
|
||||
wrappedKey, []byte(oaepLabel))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt file key: %v", err)
|
||||
}
|
||||
return fileKey, nil
|
||||
}
|
||||
123
internal/age/x25519.go
Normal file
123
internal/age/x25519.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package age
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/FiloSottile/age/internal/format"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
const x25519Label = "age-tool.com X25519"
|
||||
|
||||
type X25519Recipient struct {
|
||||
theirPublicKey [32]byte
|
||||
}
|
||||
|
||||
var _ Recipient = &X25519Recipient{}
|
||||
|
||||
func (*X25519Recipient) Type() string { return "X25519" }
|
||||
|
||||
func NewX25519Recipient(publicKey []byte) (*X25519Recipient, error) {
|
||||
if len(publicKey) != 32 {
|
||||
return nil, errors.New("invalid X25519 public key")
|
||||
}
|
||||
r := &X25519Recipient{}
|
||||
copy(r.theirPublicKey[:], publicKey)
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *X25519Recipient) Wrap(fileKey []byte) (*format.Recipient, error) {
|
||||
var ephemeral, ourPublicKey [32]byte
|
||||
if _, err := rand.Read(ephemeral[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
curve25519.ScalarBaseMult(&ourPublicKey, &ephemeral)
|
||||
|
||||
var sharedSecret [32]byte
|
||||
curve25519.ScalarMult(&sharedSecret, &ephemeral, &r.theirPublicKey)
|
||||
|
||||
l := &format.Recipient{
|
||||
Type: "X25519",
|
||||
Args: []string{format.EncodeToString(ourPublicKey[:])},
|
||||
}
|
||||
|
||||
salt := make([]byte, 0, 32*2)
|
||||
salt = append(salt, ourPublicKey[:]...)
|
||||
salt = append(salt, r.theirPublicKey[:]...)
|
||||
h := hkdf.New(sha256.New, sharedSecret[:], salt, []byte(x25519Label))
|
||||
wrappingKey := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := io.ReadFull(h, wrappingKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wrappedKey, err := aeadEncrypt(wrappingKey, fileKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Body = []byte(format.EncodeToString(wrappedKey) + "\n")
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
type X25519Identity struct {
|
||||
secretKey, ourPublicKey [32]byte
|
||||
}
|
||||
|
||||
var _ Identity = &X25519Identity{}
|
||||
|
||||
func (*X25519Identity) Type() string { return "X25519" }
|
||||
|
||||
func NewX25519Identity(secretKey []byte) (*X25519Identity, error) {
|
||||
if len(secretKey) != 32 {
|
||||
return nil, errors.New("invalid X25519 secret key")
|
||||
}
|
||||
i := &X25519Identity{}
|
||||
copy(i.secretKey[:], secretKey)
|
||||
curve25519.ScalarBaseMult(&i.ourPublicKey, &i.secretKey)
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (i *X25519Identity) Unwrap(block *format.Recipient) ([]byte, error) {
|
||||
if block.Type != "X25519" {
|
||||
return nil, errors.New("wrong recipient block type")
|
||||
}
|
||||
if len(block.Args) != 1 {
|
||||
return nil, errors.New("invalid X25519 recipient block")
|
||||
}
|
||||
publicKey, err := format.DecodeString(block.Args[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse X25519 recipient: %v", err)
|
||||
}
|
||||
if len(publicKey) != 32 {
|
||||
return nil, errors.New("invalid X25519 recipient block")
|
||||
}
|
||||
wrappedKey, err := format.DecodeString(string(block.Body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse X25519 recipient: %v", err)
|
||||
}
|
||||
|
||||
var sharedSecret, theirPublicKey [32]byte
|
||||
copy(theirPublicKey[:], publicKey)
|
||||
curve25519.ScalarMult(&sharedSecret, &i.secretKey, &theirPublicKey)
|
||||
|
||||
salt := make([]byte, 0, 32*2)
|
||||
salt = append(salt, theirPublicKey[:]...)
|
||||
salt = append(salt, i.ourPublicKey[:]...)
|
||||
h := hkdf.New(sha256.New, sharedSecret[:], salt, []byte(x25519Label))
|
||||
wrappingKey := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := io.ReadFull(h, wrappingKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fileKey, err := aeadDecrypt(wrappingKey, wrappedKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt file key: %v", err)
|
||||
}
|
||||
return fileKey, nil
|
||||
}
|
||||
@@ -24,7 +24,7 @@ type Recipient struct {
|
||||
|
||||
var b64 = base64.RawURLEncoding.Strict()
|
||||
|
||||
func decodeString(s string) ([]byte, error) {
|
||||
func DecodeString(s string) ([]byte, error) {
|
||||
// CR and LF are ignored by DecodeString. LF is handled by the parser,
|
||||
// but CR can introduce malleability.
|
||||
if strings.Contains(s, "\r") {
|
||||
@@ -33,12 +33,14 @@ func decodeString(s string) ([]byte, error) {
|
||||
return b64.DecodeString(s)
|
||||
}
|
||||
|
||||
var EncodeToString = b64.EncodeToString // TODO: wrap lines
|
||||
|
||||
const intro = "This is a file encrypted with age-tool.com, version 1\n"
|
||||
|
||||
var recipientPrefix = []byte("->")
|
||||
var footerPrefix = []byte("---")
|
||||
|
||||
func (h *Header) Marshal(w io.Writer) error {
|
||||
func (h *Header) MarshalWithoutMAC(w io.Writer) error {
|
||||
if _, err := io.WriteString(w, intro); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -54,12 +56,21 @@ func (h *Header) Marshal(w io.Writer) error {
|
||||
if _, err := io.WriteString(w, "\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: check that Body ends with a newline.
|
||||
if _, err := w.Write(r.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := fmt.Fprintf(w, "%s %s", footerPrefix, h.AEAD)
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Header) Marshal(w io.Writer) error {
|
||||
if err := h.MarshalWithoutMAC(w); err != nil {
|
||||
return err
|
||||
}
|
||||
mac := b64.EncodeToString(h.MAC)
|
||||
_, err := fmt.Fprintf(w, "%s %s %s\n", footerPrefix, h.AEAD, mac)
|
||||
_, err := fmt.Fprintf(w, " %s\n", mac)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -100,7 +111,7 @@ func Parse(input io.Reader) (*Header, io.Reader, error) {
|
||||
return nil, nil, errorf("malformed closing line: %q", line)
|
||||
}
|
||||
h.AEAD = args[0]
|
||||
h.MAC, err = decodeString(args[1])
|
||||
h.MAC, err = DecodeString(args[1])
|
||||
if err != nil {
|
||||
return nil, nil, errorf("malformed closing line %q: %v", line, err)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// +build gofuzz
|
||||
|
||||
package format
|
||||
|
||||
import (
|
||||
|
||||
208
internal/stream/stream.go
Normal file
208
internal/stream/stream.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/poly1305"
|
||||
)
|
||||
|
||||
const ChunkSize = 64 * 1024
|
||||
|
||||
type Reader struct {
|
||||
a cipher.AEAD
|
||||
src io.Reader
|
||||
|
||||
unread []byte // decrypted but unread data, backed by buf
|
||||
buf [encChunkSize]byte
|
||||
|
||||
err error
|
||||
nonce [chacha20poly1305.NonceSize]byte
|
||||
}
|
||||
|
||||
const (
|
||||
encChunkSize = ChunkSize + poly1305.TagSize
|
||||
lastChunkFlag = 0x01
|
||||
)
|
||||
|
||||
func NewReader(key []byte, src io.Reader) (*Reader, error) {
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Reader{
|
||||
a: aead,
|
||||
src: src,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *Reader) Read(p []byte) (int, error) {
|
||||
if len(r.unread) > 0 {
|
||||
n := copy(p, r.unread)
|
||||
r.unread = r.unread[n:]
|
||||
return n, nil
|
||||
}
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
last, err := r.readChunk()
|
||||
if err != nil {
|
||||
r.err = err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
n := copy(p, r.unread)
|
||||
r.unread = r.unread[n:]
|
||||
|
||||
if last {
|
||||
r.err = io.EOF
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// readChunk reads the next chunk of ciphertext from r.c and makes in available
|
||||
// in r.unread. last is true if the chunk was marked as the end of the message.
|
||||
// readChunk must not be called again after returning a last chunk or an error.
|
||||
func (r *Reader) readChunk() (last bool, err error) {
|
||||
if len(r.unread) != 0 {
|
||||
panic("stream: internal error: readChunk called with dirty buffer")
|
||||
}
|
||||
|
||||
in := r.buf[:]
|
||||
n, err := io.ReadFull(r.src, in)
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
// A message can't end without a marked chunk. This message is truncated.
|
||||
return false, io.ErrUnexpectedEOF
|
||||
case err == io.ErrUnexpectedEOF:
|
||||
// The last chunk can be short.
|
||||
in = in[:n]
|
||||
last = true
|
||||
setLastChunkFlag(&r.nonce)
|
||||
case err != nil:
|
||||
return false, err
|
||||
}
|
||||
|
||||
outBuf := make([]byte, 0, ChunkSize)
|
||||
out, err := r.a.Open(outBuf, r.nonce[:], in, nil)
|
||||
if err != nil && !last {
|
||||
// Check if this was a full-length final chunk.
|
||||
last = true
|
||||
setLastChunkFlag(&r.nonce)
|
||||
out, err = r.a.Open(outBuf, r.nonce[:], in, nil)
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
incNonce(&r.nonce)
|
||||
r.unread = r.buf[:copy(r.buf[:], out)]
|
||||
return last, nil
|
||||
}
|
||||
|
||||
func incNonce(nonce *[chacha20poly1305.NonceSize]byte) {
|
||||
for i := len(nonce) - 2; i >= 0; i-- {
|
||||
nonce[i]++
|
||||
if nonce[i] != 0 {
|
||||
break
|
||||
} else if i == 0 {
|
||||
// The counter is 88 bits, this is unreachable.
|
||||
panic("stream: chunk counter wrapped around")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setLastChunkFlag(nonce *[chacha20poly1305.NonceSize]byte) {
|
||||
nonce[len(nonce)-1] = lastChunkFlag
|
||||
}
|
||||
|
||||
type Writer struct {
|
||||
a cipher.AEAD
|
||||
dst io.Writer
|
||||
unwritten []byte // backed by buf
|
||||
buf [encChunkSize]byte
|
||||
nonce [chacha20poly1305.NonceSize]byte
|
||||
err error
|
||||
}
|
||||
|
||||
func NewWriter(key []byte, dst io.Writer) (*Writer, error) {
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w := &Writer{
|
||||
a: aead,
|
||||
dst: dst,
|
||||
}
|
||||
w.unwritten = w.buf[:0]
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (w *Writer) Write(p []byte) (n int, err error) {
|
||||
// TODO: consider refactoring with a bytes.Buffer.
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
total := len(p)
|
||||
for len(p) > 0 {
|
||||
free := ChunkSize - len(w.unwritten)
|
||||
freeBuf := w.buf[len(w.unwritten) : len(w.unwritten)+free]
|
||||
n := copy(freeBuf, p)
|
||||
p = p[n:]
|
||||
w.unwritten = w.unwritten[:len(w.unwritten)+n]
|
||||
|
||||
if len(w.unwritten) == ChunkSize && len(p) > 0 {
|
||||
if err := w.flushChunk(notLastChunk); err != nil {
|
||||
w.err = err
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (w *Writer) Close() error {
|
||||
// TODO: close w.dst if it can be interface upgraded to io.Closer.
|
||||
if w.err != nil {
|
||||
return w.err
|
||||
}
|
||||
|
||||
err := w.flushChunk(lastChunk)
|
||||
if err != nil {
|
||||
w.err = err
|
||||
} else {
|
||||
w.err = errors.New("stream.Writer is already closed")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
const (
|
||||
lastChunk = true
|
||||
notLastChunk = false
|
||||
)
|
||||
|
||||
func (w *Writer) flushChunk(last bool) error {
|
||||
if !last && len(w.unwritten) != ChunkSize {
|
||||
panic("stream: internal error: flush called with partial chunk")
|
||||
}
|
||||
|
||||
if last {
|
||||
setLastChunkFlag(&w.nonce)
|
||||
}
|
||||
buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil)
|
||||
_, err := w.dst.Write(buf)
|
||||
w.unwritten = w.buf[:0]
|
||||
incNonce(&w.nonce)
|
||||
return err
|
||||
}
|
||||
93
internal/stream/stream_test.go
Normal file
93
internal/stream/stream_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package stream_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/FiloSottile/age/internal/stream"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
const cs = stream.ChunkSize
|
||||
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
for _, stepSize := range []int{512, 600, 1000, cs} {
|
||||
for _, length := range []int{0, 1000, cs, cs + 100} {
|
||||
t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize),
|
||||
func(t *testing.T) { testRoundTrip(t, stepSize, length) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testRoundTrip(t *testing.T, stepSize, length int) {
|
||||
src := make([]byte, length)
|
||||
if _, err := rand.Read(src); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w, err := stream.NewWriter(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var n int
|
||||
for n < length {
|
||||
b := length - n
|
||||
if b > stepSize {
|
||||
b = stepSize
|
||||
}
|
||||
nn, err := w.Write(src[n : n+b])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if nn != b {
|
||||
t.Errorf("Write returned %d, expected %d", nn, b)
|
||||
}
|
||||
n += nn
|
||||
|
||||
nn, err = w.Write(src[n:n])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if nn != 0 {
|
||||
t.Errorf("Write returned %d, expected 0", nn)
|
||||
}
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Error("Close returned an error:", err)
|
||||
}
|
||||
|
||||
t.Logf("buffer size: %d", buf.Len())
|
||||
|
||||
r, err := stream.NewReader(key, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
n = 0
|
||||
readBuf := make([]byte, stepSize)
|
||||
for n < length {
|
||||
b := length - n
|
||||
if b > stepSize {
|
||||
b = stepSize
|
||||
}
|
||||
nn, err := r.Read(readBuf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read error at index %d: %v", n, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(readBuf[:nn], src[n:n+nn]) {
|
||||
t.Errorf("wrong data at indexes %d - %d", n, n+nn)
|
||||
}
|
||||
|
||||
n += nn
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user