Add db and crypto tests; fix minor bugs

This commit is contained in:
Vikas
2024-06-23 15:02:26 +05:30
parent eb40899606
commit a22b6a3444
9 changed files with 294 additions and 8 deletions

View File

@@ -5,5 +5,8 @@ build:
templ generate
go build -o ./bin/pastepass
test:
go test ./...
run:
./bin/pastepass

View File

@@ -63,15 +63,15 @@ func (k *EncryptionKey) Base64Key() string {
return base64.RawURLEncoding.EncodeToString(k.Key)
}
func NewEncryptionKey() *EncryptionKey {
func NewEncryptionKey() (*EncryptionKey, error) {
key := make([]byte, keyLength)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
panic(err)
return nil, err
}
return &EncryptionKey{
Key: key,
}
}, nil
}
func NewEncryptionKeyFromBase64(base64Key string) (*EncryptionKey, error) {

88
db/crypto_test.go Normal file
View File

@@ -0,0 +1,88 @@
package db
import (
"slices"
"testing"
)
func TestEncryptionKey(t *testing.T) {
// Test creating a new key
key, err := NewEncryptionKey()
if err != nil {
t.Fatal(err)
}
if len(key.Key) != 32 {
t.Fatalf("expected key length to be 32, got %d", len(key.Key))
}
if key.Base64Key() == "" {
t.Fatal("expected base64 key to be non-empty")
}
// Test loading key from base64
loadedKey, err := NewEncryptionKeyFromBase64(key.Base64Key())
if err != nil {
t.Fatal(err)
}
if key.Base64Key() != loadedKey.Base64Key() {
t.Fatalf("expected base64 keys to match, got %s and %s", key.Base64Key(), loadedKey.Base64Key())
}
if !slices.Equal(key.Key, loadedKey.Key) {
t.Fatalf("expected keys to match, got %v and %v", key.Key, loadedKey.Key)
}
}
func TestEncryptDecrypt(t *testing.T) {
key, err := NewEncryptionKey()
if err != nil {
t.Fatal(err)
}
plaintext := "hello, world!"
ciphertext, err := key.Encrypt(plaintext)
if err != nil {
t.Fatal(err)
}
decrypted, err := key.Decrypt(ciphertext)
if err != nil {
t.Fatal(err)
}
if decrypted != plaintext {
t.Fatalf("expected decrypted text to be %s, got %s", plaintext, decrypted)
}
}
func TestEncryptDecryptInvalidKey(t *testing.T) {
key, err := NewEncryptionKey()
if err != nil {
t.Fatal(err)
}
plaintext := "hello, world!"
ciphertext, err := key.Encrypt(plaintext)
if err != nil {
t.Fatal(err)
}
invalidKey, err := NewEncryptionKey()
if err != nil {
t.Fatal(err)
}
_, err = invalidKey.Decrypt(ciphertext)
if err == nil {
t.Fatal("expected decrypt to fail with invalid key")
}
}
func TestEncryptDecryptInvalidBase64Key(t *testing.T) {
if _, err := NewEncryptionKeyFromBase64("invalid"); err == nil {
t.Fatal("expected loading key from invalid base64 to fail")
}
}

View File

@@ -24,12 +24,22 @@ var (
type DB struct {
boltDB *bolt.DB
path string
}
func (d *DB) Close() error {
return d.boltDB.Close()
}
func (d *DB) reset() error {
if err := d.Close(); err != nil {
return err
}
removeDB(d.path)
return nil
}
func (d *DB) NewPaste(text string, expiresAt time.Time) (*Paste, error) {
paste, err := NewEncryptedPaste(text, expiresAt)
if err != nil {
@@ -98,7 +108,12 @@ func (d *DB) Get(id string) (*Paste, error) {
func (d *DB) Decrypt(id string, key string) (string, error) {
// delete paste if expired
if _, err := d.Get(id); errors.Is(err, ErrPasteExpired) {
return "", d.Delete(id)
// delete paste if expired
if err := d.Delete(id); err != nil {
slog.Error("error_deleting_expired_paste", "id", id, "error", err)
return "", err
}
return "", err
}
var decryptedText string

156
db/db_test.go Normal file
View File

@@ -0,0 +1,156 @@
package db
import (
"fmt"
"math/rand"
"testing"
"time"
"github.com/boltdb/bolt"
"github.com/stretchr/testify/assert"
)
func newTestDB() (*DB, error) {
testDbName := fmt.Sprintf(".test.%d.boltdb", rand.Int())
return NewDB(testDbName, true)
}
func TestNewPaste(t *testing.T) {
db, err := newTestDB()
assert.NoError(t, err)
defer db.reset()
paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour))
assert.NoError(t, err)
assert.NotNil(t, paste)
assert.NotEmpty(t, paste.ID)
assert.NotEmpty(t, paste.EncryptedBytes)
_, err = db.Get(paste.ID)
assert.NoError(t, err)
}
func TestGet(t *testing.T) {
db, err := newTestDB()
assert.NoError(t, err)
defer db.reset()
expirationTime := time.Now().Add(time.Hour)
paste, err := db.NewPaste("test paste", expirationTime)
assert.NoError(t, err)
t.Run("returns correct attributes", func(t *testing.T) {
// only metadata is returned for Get
savedPaste, err := db.Get(paste.ID)
assert.NoError(t, err)
assert.Equal(t, paste.ID, savedPaste.ID)
assert.Empty(t, savedPaste.Text)
assert.Empty(t, savedPaste.EncryptedBytes)
assert.Equal(t, expirationTime.Unix(), savedPaste.ExpiresAt.Unix())
})
t.Run("non existent paste", func(t *testing.T) {
// Test nonexistent paste
_, err = db.Get("nonexistent")
assert.Error(t, err)
})
t.Run("expired paste", func(t *testing.T) {
// Test expired paste
paste, err = db.NewPaste("test paste", time.Now().Add(-time.Hour))
assert.NoError(t, err)
_, err = db.Get(paste.ID)
assert.Error(t, err)
})
}
func TestDelete(t *testing.T) {
db, err := newTestDB()
assert.NoError(t, err)
defer db.reset()
paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour))
assert.NoError(t, err)
err = db.Delete(paste.ID)
assert.NoError(t, err)
_, err = db.Get(paste.ID)
assert.Error(t, err)
}
func TestDecrypt(t *testing.T) {
db, err := newTestDB()
assert.NoError(t, err)
defer db.reset()
t.Run("decrypt paste", func(t *testing.T) {
paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour))
assert.NoError(t, err)
// decrypt paste
decryptedText, err := db.Decrypt(paste.ID, paste.Key)
assert.NoError(t, err)
assert.Equal(t, "test paste", decryptedText)
// paste is deleted after decryption
_, err = db.Get(paste.ID)
assert.Error(t, err)
})
t.Run("invalid paste", func(t *testing.T) {
paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour))
assert.NoError(t, err)
// test wrong key
_, err = db.Decrypt(paste.ID, "wrong key")
assert.Error(t, err)
_, err = db.Decrypt("nonexistent", "key")
assert.Error(t, err)
})
t.Run("expired paste", func(t *testing.T) {
paste, err := db.NewPaste("test paste", time.Now().Add(-time.Hour))
assert.NoError(t, err)
decryptedText, err := db.Decrypt(paste.ID, paste.Key)
assert.Empty(t, decryptedText)
assert.Error(t, err)
})
}
func TestDeleteExpired(t *testing.T) {
db, err := newTestDB()
assert.NoError(t, err)
defer db.reset()
_, err = db.NewPaste("test paste", time.Now().Add(time.Hour))
assert.NoError(t, err)
_, err = db.NewPaste("test paste", time.Now().Add(-time.Hour))
assert.NoError(t, err)
err = db.DeleteExpired()
assert.NoError(t, err)
pasteCount := 0
db.boltDB.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(pastesBucketName)
if bucket == nil {
return nil
}
cursor := bucket.Cursor()
for k, _ := cursor.First(); k != nil; k, _ = cursor.Next() {
pasteCount++
}
return nil
})
assert.Equal(t, 1, pasteCount)
}

View File

@@ -14,7 +14,11 @@ type Paste struct {
}
func NewEncryptedPaste(text string, expiresAt time.Time) (*Paste, error) {
key := NewEncryptionKey()
key, err := NewEncryptionKey()
if err != nil {
return nil, err
}
encryptedText, err := key.Encrypt(text)
if err != nil {
return nil, err

View File

@@ -1,10 +1,11 @@
package db
import (
"github.com/boltdb/bolt"
"log/slog"
"os"
"time"
"github.com/boltdb/bolt"
)
func NewDB(path string, reset bool) (*DB, error) {
@@ -17,7 +18,7 @@ func NewDB(path string, reset bool) (*DB, error) {
return nil, err
}
return &DB{boltDB: boltDB}, nil
return &DB{boltDB: boltDB, path: path}, nil
}
func removeDB(path string) {

11
go.mod
View File

@@ -7,4 +7,13 @@ require (
github.com/boltdb/bolt v1.3.1
)
require golang.org/x/sys v0.19.0 // indirect
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
require (
github.com/stretchr/testify v1.9.0
golang.org/x/sys v0.19.0 // indirect
)

10
go.sum
View File

@@ -2,7 +2,17 @@ github.com/a-h/templ v0.2.707 h1:T1Gkd2ugbRglZ9rYw/VBchWOSZVKmetDbBkm4YubM7U=
github.com/a-h/templ v0.2.707/go.mod h1:5cqsugkq9IerRNucNsI4DEamdHPsoGMQy99DzydLhM8=
github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4=
github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=