From eec79dc9d9a863b1f6d88869da1696768b12f00c Mon Sep 17 00:00:00 2001 From: Vikas Date: Sun, 23 Jun 2024 12:18:37 +0530 Subject: [PATCH] Improve encryption/decryption - Use crypto.rand. Relying on math.rand does not ensure true randomness. - Bas64 encode random bytes instead of generating keys using a fixed set of letters. - Based on reddit feedback: https://www.reddit.com/r/golang/comments/1d73vsl/comment/l6xmtja/ --- db/crypto.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++----- db/db.go | 8 ++++++-- db/models.go | 6 +++--- 3 files changed, 57 insertions(+), 10 deletions(-) diff --git a/db/crypto.go b/db/crypto.go index 64c1a88..b146a4a 100644 --- a/db/crypto.go +++ b/db/crypto.go @@ -4,14 +4,15 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "encoding/base64" "errors" "fmt" "io" random "math/rand" ) -func encrypt(text string, key string) ([]byte, error) { - c, err := aes.NewCipher([]byte(key)) +func encrypt(text, key []byte) ([]byte, error) { + c, err := aes.NewCipher(key) if err != nil { return nil, err } @@ -26,11 +27,11 @@ func encrypt(text string, key string) ([]byte, error) { return nil, err } - return gcm.Seal(nonce, nonce, []byte(text), nil), nil + return gcm.Seal(nonce, nonce, text, nil), nil } -func decrypt(ciphertext []byte, key string) (string, error) { - c, err := aes.NewCipher([]byte(key)) +func decrypt(ciphertext, key []byte) (string, error) { + c, err := aes.NewCipher(key) if err != nil { return "", fmt.Errorf("failed to create cipher: %w", err) } @@ -54,6 +55,48 @@ func decrypt(ciphertext []byte, key string) (string, error) { return string(plaintext), nil } +type EncryptionKey struct { + Key []byte +} + +func (k *EncryptionKey) Base64Key() string { + return base64.RawURLEncoding.EncodeToString(k.Key) +} + +func NewEncryptionKey() *EncryptionKey { + key := make([]byte, keyLength) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + panic(err) + } + + return &EncryptionKey{ + Key: key, + } +} + +func NewEncryptionKeyFromBase64(base64Key string) (*EncryptionKey, error) { + key, err := base64.RawURLEncoding.DecodeString(base64Key) + if err != nil { + return nil, err + } + + if len(key) != 32 { + return nil, errors.New("invalid key length") + } + + return &EncryptionKey{ + Key: key, + }, nil +} + +func (k *EncryptionKey) Encrypt(text string) ([]byte, error) { + return encrypt([]byte(text), k.Key) +} + +func (k *EncryptionKey) Decrypt(ciphertext []byte) (string, error) { + return decrypt(ciphertext, k.Key) +} + var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") const keyLength = 32 diff --git a/db/db.go b/db/db.go index 53a8540..dcb9b36 100644 --- a/db/db.go +++ b/db/db.go @@ -113,8 +113,12 @@ func (d *DB) Decrypt(id string, key string) (string, error) { return ErrPasteNotFound } - var err error - decryptedText, err = decrypt(encryptedPaste, key) + decryptionKey, err := NewEncryptionKeyFromBase64(key) + if err != nil { + return err + } + + decryptedText, err = decryptionKey.Decrypt(encryptedPaste) if err != nil { return err } diff --git a/db/models.go b/db/models.go index 6b82c35..0ef0515 100644 --- a/db/models.go +++ b/db/models.go @@ -14,8 +14,8 @@ type Paste struct { } func NewEncryptedPaste(text string, expiresAt time.Time) (*Paste, error) { - key := randomKey() - encryptedText, err := encrypt(text, key) + key := NewEncryptionKey() + encryptedText, err := key.Encrypt(text) if err != nil { return nil, err } @@ -24,7 +24,7 @@ func NewEncryptedPaste(text string, expiresAt time.Time) (*Paste, error) { ID: randomKey(), Text: text, EncryptedBytes: encryptedText, - Key: key, + Key: key.Base64Key(), CreatedAt: time.Now(), ExpiresAt: expiresAt, }, nil