diff --git a/keycache/keycache.go b/keycache/keycache.go index 39c842a..1c26eab 100644 --- a/keycache/keycache.go +++ b/keycache/keycache.go @@ -7,10 +7,12 @@ package keycache import ( "crypto/aes" + "crypto/ecdsa" "crypto/rand" "crypto/rsa" "crypto/sha1" "errors" + "github.com/cloudflare/redoctober/ecdh" "github.com/cloudflare/redoctober/passvault" "log" "time" @@ -28,6 +30,7 @@ type ActiveUser struct { aesKey []byte rsaKey rsa.PrivateKey + eccKey *ecdsa.PrivateKey } // matchUser returns the matching active user if present @@ -93,6 +96,8 @@ func AddKeyFromRecord(record passvault.PasswordRecord, name string, password str current.aesKey, err = record.GetKeyAES(password) case passvault.RSARecord: current.rsaKey, err = record.GetKeyRSA(password) + case passvault.ECCRecord: + current.eccKey, err = record.GetKeyECC(password) default: err = errors.New("Unknown record type") } @@ -113,7 +118,8 @@ func AddKeyFromRecord(record passvault.PasswordRecord, name string, password str // EncryptKey encrypts a 16 byte key using the cached key corresponding to name. // For AES keys, use the cached key. -// For RSA keys, the cache is not necessary use the override key instead. +// For RSA and EC keys, the cache is not necessary; use the override +// key instead. func EncryptKey(in []byte, name string, override []byte) (out []byte, err error) { Refresh() @@ -150,9 +156,10 @@ func EncryptKey(in []byte, name string, override []byte) (out []byte, err error) // DecryptKey decrypts a 16 byte key using the key corresponding to the name parameter // for AES keys, the cached AES key is used directly to decrypt in -// for RSA keys, the cached RSA key is used to decrypt the rsaEncryptedKey -// which is then used to decrypt the input buffer. -func DecryptKey(in []byte, name string, rsaEncryptedKey []byte) (out []byte, err error) { +// for RSA and EC keys, the cached RSA/EC key is used to decrypt +// the pubEncryptedKey which is then used to decrypt the input +// buffer. +func DecryptKey(in []byte, name string, pubEncryptedKey []byte) (out []byte, err error) { Refresh() decryptKey, ok := matchUser(name) @@ -168,12 +175,18 @@ func DecryptKey(in []byte, name string, rsaEncryptedKey []byte) (out []byte, err aesKey = decryptKey.aesKey case passvault.RSARecord: - // extract the aes key from the rsaEncryptedKey - aesKey, err = rsa.DecryptOAEP(sha1.New(), rand.Reader, &decryptKey.rsaKey, rsaEncryptedKey, nil) + // extract the aes key from the pubEncryptedKey + aesKey, err = rsa.DecryptOAEP(sha1.New(), rand.Reader, &decryptKey.rsaKey, pubEncryptedKey, nil) if err != nil { return out, err } + case passvault.ECCRecord: + // extract the aes key from the pubEncryptedKey + aesKey, err = ecdh.Decrypt(decryptKey.eccKey, pubEncryptedKey) + if err != nil { + return out, err + } default: return nil, errors.New("unknown type") }