diff --git a/Makefile b/Makefile index 7c85ce1..48791a0 100644 --- a/Makefile +++ b/Makefile @@ -5,5 +5,8 @@ build: templ generate go build -o ./bin/pastepass +test: + go test ./... + run: ./bin/pastepass diff --git a/db/crypto.go b/db/crypto.go index b146a4a..472f074 100644 --- a/db/crypto.go +++ b/db/crypto.go @@ -63,15 +63,15 @@ func (k *EncryptionKey) Base64Key() string { return base64.RawURLEncoding.EncodeToString(k.Key) } -func NewEncryptionKey() *EncryptionKey { +func NewEncryptionKey() (*EncryptionKey, error) { key := make([]byte, keyLength) if _, err := io.ReadFull(rand.Reader, key); err != nil { - panic(err) + return nil, err } return &EncryptionKey{ Key: key, - } + }, nil } func NewEncryptionKeyFromBase64(base64Key string) (*EncryptionKey, error) { diff --git a/db/crypto_test.go b/db/crypto_test.go new file mode 100644 index 0000000..3a5c7e4 --- /dev/null +++ b/db/crypto_test.go @@ -0,0 +1,88 @@ +package db + +import ( + "slices" + "testing" +) + +func TestEncryptionKey(t *testing.T) { + // Test creating a new key + key, err := NewEncryptionKey() + if err != nil { + t.Fatal(err) + } + + if len(key.Key) != 32 { + t.Fatalf("expected key length to be 32, got %d", len(key.Key)) + } + + if key.Base64Key() == "" { + t.Fatal("expected base64 key to be non-empty") + } + + // Test loading key from base64 + loadedKey, err := NewEncryptionKeyFromBase64(key.Base64Key()) + if err != nil { + t.Fatal(err) + } + + if key.Base64Key() != loadedKey.Base64Key() { + t.Fatalf("expected base64 keys to match, got %s and %s", key.Base64Key(), loadedKey.Base64Key()) + } + + if !slices.Equal(key.Key, loadedKey.Key) { + t.Fatalf("expected keys to match, got %v and %v", key.Key, loadedKey.Key) + } + +} + +func TestEncryptDecrypt(t *testing.T) { + key, err := NewEncryptionKey() + if err != nil { + t.Fatal(err) + } + + plaintext := "hello, world!" + ciphertext, err := key.Encrypt(plaintext) + if err != nil { + t.Fatal(err) + } + + decrypted, err := key.Decrypt(ciphertext) + if err != nil { + t.Fatal(err) + } + + if decrypted != plaintext { + t.Fatalf("expected decrypted text to be %s, got %s", plaintext, decrypted) + } +} + +func TestEncryptDecryptInvalidKey(t *testing.T) { + key, err := NewEncryptionKey() + if err != nil { + t.Fatal(err) + } + + plaintext := "hello, world!" + ciphertext, err := key.Encrypt(plaintext) + if err != nil { + t.Fatal(err) + } + + invalidKey, err := NewEncryptionKey() + if err != nil { + t.Fatal(err) + } + + _, err = invalidKey.Decrypt(ciphertext) + if err == nil { + t.Fatal("expected decrypt to fail with invalid key") + } +} + +func TestEncryptDecryptInvalidBase64Key(t *testing.T) { + if _, err := NewEncryptionKeyFromBase64("invalid"); err == nil { + t.Fatal("expected loading key from invalid base64 to fail") + } +} diff --git a/db/db.go b/db/db.go index dcb9b36..63bbeba 100644 --- a/db/db.go +++ b/db/db.go @@ -24,12 +24,22 @@ var ( type DB struct { boltDB *bolt.DB + path string } func (d *DB) Close() error { return d.boltDB.Close() } +func (d *DB) reset() error { + if err := d.Close(); err != nil { + return err + } + + removeDB(d.path) + return nil +} + func (d *DB) NewPaste(text string, expiresAt time.Time) (*Paste, error) { paste, err := NewEncryptedPaste(text, expiresAt) if err != nil { @@ -98,7 +108,12 @@ func (d *DB) Get(id string) (*Paste, error) { func (d *DB) Decrypt(id string, key string) (string, error) { // delete paste if expired if _, err := d.Get(id); errors.Is(err, ErrPasteExpired) { - return "", d.Delete(id) + // delete paste if expired + if err := d.Delete(id); err != nil { + slog.Error("error_deleting_expired_paste", "id", id, "error", err) + return "", err + } + return "", err } var decryptedText string diff --git a/db/db_test.go b/db/db_test.go new file mode 100644 index 0000000..3b26337 --- /dev/null +++ b/db/db_test.go @@ -0,0 +1,156 @@ +package db + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/boltdb/bolt" + "github.com/stretchr/testify/assert" +) + +func newTestDB() (*DB, error) { + testDbName := fmt.Sprintf(".test.%d.boltdb", rand.Int()) + return NewDB(testDbName, true) +} + +func TestNewPaste(t *testing.T) { + db, err := newTestDB() + assert.NoError(t, err) + defer db.reset() + + paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour)) + assert.NoError(t, err) + + assert.NotNil(t, paste) + assert.NotEmpty(t, paste.ID) + assert.NotEmpty(t, paste.EncryptedBytes) + + _, err = db.Get(paste.ID) + assert.NoError(t, err) +} + +func TestGet(t *testing.T) { + db, err := newTestDB() + assert.NoError(t, err) + defer db.reset() + + expirationTime := time.Now().Add(time.Hour) + paste, err := db.NewPaste("test paste", expirationTime) + assert.NoError(t, err) + + t.Run("returns correct attributes", func(t *testing.T) { + // only metadata is returned for Get + savedPaste, err := db.Get(paste.ID) + assert.NoError(t, err) + + assert.Equal(t, paste.ID, savedPaste.ID) + assert.Empty(t, savedPaste.Text) + assert.Empty(t, savedPaste.EncryptedBytes) + assert.Equal(t, expirationTime.Unix(), savedPaste.ExpiresAt.Unix()) + }) + + t.Run("non existent paste", func(t *testing.T) { + // Test nonexistent paste + _, err = db.Get("nonexistent") + assert.Error(t, err) + }) + + t.Run("expired paste", func(t *testing.T) { + // Test expired paste + paste, err = db.NewPaste("test paste", time.Now().Add(-time.Hour)) + assert.NoError(t, err) + + _, err = db.Get(paste.ID) + assert.Error(t, err) + }) +} + +func TestDelete(t *testing.T) { + db, err := newTestDB() + assert.NoError(t, err) + defer db.reset() + + paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour)) + assert.NoError(t, err) + + err = db.Delete(paste.ID) + assert.NoError(t, err) + + _, err = db.Get(paste.ID) + assert.Error(t, err) +} + +func TestDecrypt(t *testing.T) { + db, err := newTestDB() + assert.NoError(t, err) + defer db.reset() + + t.Run("decrypt paste", func(t *testing.T) { + paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour)) + assert.NoError(t, err) + + // decrypt paste + decryptedText, err := db.Decrypt(paste.ID, paste.Key) + assert.NoError(t, err) + assert.Equal(t, "test paste", decryptedText) + + // paste is deleted after decryption + _, err = db.Get(paste.ID) + assert.Error(t, err) + }) + + t.Run("invalid paste", func(t *testing.T) { + paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour)) + assert.NoError(t, err) + + // test wrong key + _, err = db.Decrypt(paste.ID, "wrong key") + assert.Error(t, err) + + _, err = db.Decrypt("nonexistent", "key") + assert.Error(t, err) + }) + + t.Run("expired paste", func(t *testing.T) { + paste, err := db.NewPaste("test paste", time.Now().Add(-time.Hour)) + assert.NoError(t, err) + + decryptedText, err := db.Decrypt(paste.ID, paste.Key) + assert.Empty(t, decryptedText) + assert.Error(t, err) + }) +} + +func TestDeleteExpired(t *testing.T) { + db, err := newTestDB() + assert.NoError(t, err) + defer db.reset() + + _, err = db.NewPaste("test paste", time.Now().Add(time.Hour)) + assert.NoError(t, err) + + _, err = db.NewPaste("test paste", time.Now().Add(-time.Hour)) + assert.NoError(t, err) + + err = db.DeleteExpired() + assert.NoError(t, err) + + pasteCount := 0 + db.boltDB.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket(pastesBucketName) + if bucket == nil { + return nil + } + + cursor := bucket.Cursor() + for k, _ := cursor.First(); k != nil; k, _ = cursor.Next() { + pasteCount++ + } + + return nil + }) + + assert.Equal(t, 1, pasteCount) +} diff --git a/db/models.go b/db/models.go index 0ef0515..5d618f0 100644 --- a/db/models.go +++ b/db/models.go @@ -14,7 +14,11 @@ type Paste struct { } func NewEncryptedPaste(text string, expiresAt time.Time) (*Paste, error) { - key := NewEncryptionKey() + key, err := NewEncryptionKey() + if err != nil { + return nil, err + } + encryptedText, err := key.Encrypt(text) if err != nil { return nil, err diff --git a/db/utils.go b/db/utils.go index 0aa8dfa..36c6d53 100644 --- a/db/utils.go +++ b/db/utils.go @@ -1,10 +1,11 @@ package db import ( - "github.com/boltdb/bolt" "log/slog" "os" "time" + + "github.com/boltdb/bolt" ) func NewDB(path string, reset bool) (*DB, error) { @@ -17,7 +18,7 @@ func NewDB(path string, reset bool) (*DB, error) { return nil, err } - return &DB{boltDB: boltDB}, nil + return &DB{boltDB: boltDB, path: path}, nil } func removeDB(path string) { diff --git a/go.mod b/go.mod index 3e9fd80..a67623f 100644 --- a/go.mod +++ b/go.mod @@ -7,4 +7,13 @@ require ( github.com/boltdb/bolt v1.3.1 ) -require golang.org/x/sys v0.19.0 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +require ( + github.com/stretchr/testify v1.9.0 + golang.org/x/sys v0.19.0 // indirect +) diff --git a/go.sum b/go.sum index de68a32..cff2b38 100644 --- a/go.sum +++ b/go.sum @@ -2,7 +2,17 @@ github.com/a-h/templ v0.2.707 h1:T1Gkd2ugbRglZ9rYw/VBchWOSZVKmetDbBkm4YubM7U= github.com/a-h/templ v0.2.707/go.mod h1:5cqsugkq9IerRNucNsI4DEamdHPsoGMQy99DzydLhM8= github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4= github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=