mirror of
https://github.com/v1k45/pastepass.git
synced 2026-01-09 07:33:53 +00:00
Add db and crypto tests; fix minor bugs
This commit is contained in:
3
Makefile
3
Makefile
@@ -5,5 +5,8 @@ build:
|
||||
templ generate
|
||||
go build -o ./bin/pastepass
|
||||
|
||||
test:
|
||||
go test ./...
|
||||
|
||||
run:
|
||||
./bin/pastepass
|
||||
|
||||
@@ -63,15 +63,15 @@ func (k *EncryptionKey) Base64Key() string {
|
||||
return base64.RawURLEncoding.EncodeToString(k.Key)
|
||||
}
|
||||
|
||||
func NewEncryptionKey() *EncryptionKey {
|
||||
func NewEncryptionKey() (*EncryptionKey, error) {
|
||||
key := make([]byte, keyLength)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
panic(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &EncryptionKey{
|
||||
Key: key,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewEncryptionKeyFromBase64(base64Key string) (*EncryptionKey, error) {
|
||||
|
||||
88
db/crypto_test.go
Normal file
88
db/crypto_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncryptionKey(t *testing.T) {
|
||||
// Test creating a new key
|
||||
key, err := NewEncryptionKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(key.Key) != 32 {
|
||||
t.Fatalf("expected key length to be 32, got %d", len(key.Key))
|
||||
}
|
||||
|
||||
if key.Base64Key() == "" {
|
||||
t.Fatal("expected base64 key to be non-empty")
|
||||
}
|
||||
|
||||
// Test loading key from base64
|
||||
loadedKey, err := NewEncryptionKeyFromBase64(key.Base64Key())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if key.Base64Key() != loadedKey.Base64Key() {
|
||||
t.Fatalf("expected base64 keys to match, got %s and %s", key.Base64Key(), loadedKey.Base64Key())
|
||||
}
|
||||
|
||||
if !slices.Equal(key.Key, loadedKey.Key) {
|
||||
t.Fatalf("expected keys to match, got %v and %v", key.Key, loadedKey.Key)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
key, err := NewEncryptionKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plaintext := "hello, world!"
|
||||
ciphertext, err := key.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
decrypted, err := key.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if decrypted != plaintext {
|
||||
t.Fatalf("expected decrypted text to be %s, got %s", plaintext, decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptInvalidKey(t *testing.T) {
|
||||
key, err := NewEncryptionKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plaintext := "hello, world!"
|
||||
ciphertext, err := key.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
invalidKey, err := NewEncryptionKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = invalidKey.Decrypt(ciphertext)
|
||||
if err == nil {
|
||||
t.Fatal("expected decrypt to fail with invalid key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptInvalidBase64Key(t *testing.T) {
|
||||
if _, err := NewEncryptionKeyFromBase64("invalid"); err == nil {
|
||||
t.Fatal("expected loading key from invalid base64 to fail")
|
||||
}
|
||||
}
|
||||
17
db/db.go
17
db/db.go
@@ -24,12 +24,22 @@ var (
|
||||
|
||||
type DB struct {
|
||||
boltDB *bolt.DB
|
||||
path string
|
||||
}
|
||||
|
||||
func (d *DB) Close() error {
|
||||
return d.boltDB.Close()
|
||||
}
|
||||
|
||||
func (d *DB) reset() error {
|
||||
if err := d.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
removeDB(d.path)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) NewPaste(text string, expiresAt time.Time) (*Paste, error) {
|
||||
paste, err := NewEncryptedPaste(text, expiresAt)
|
||||
if err != nil {
|
||||
@@ -98,7 +108,12 @@ func (d *DB) Get(id string) (*Paste, error) {
|
||||
func (d *DB) Decrypt(id string, key string) (string, error) {
|
||||
// delete paste if expired
|
||||
if _, err := d.Get(id); errors.Is(err, ErrPasteExpired) {
|
||||
return "", d.Delete(id)
|
||||
// delete paste if expired
|
||||
if err := d.Delete(id); err != nil {
|
||||
slog.Error("error_deleting_expired_paste", "id", id, "error", err)
|
||||
return "", err
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
var decryptedText string
|
||||
|
||||
156
db/db_test.go
Normal file
156
db/db_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/boltdb/bolt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func newTestDB() (*DB, error) {
|
||||
testDbName := fmt.Sprintf(".test.%d.boltdb", rand.Int())
|
||||
return NewDB(testDbName, true)
|
||||
}
|
||||
|
||||
func TestNewPaste(t *testing.T) {
|
||||
db, err := newTestDB()
|
||||
assert.NoError(t, err)
|
||||
defer db.reset()
|
||||
|
||||
paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour))
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, paste)
|
||||
assert.NotEmpty(t, paste.ID)
|
||||
assert.NotEmpty(t, paste.EncryptedBytes)
|
||||
|
||||
_, err = db.Get(paste.ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
db, err := newTestDB()
|
||||
assert.NoError(t, err)
|
||||
defer db.reset()
|
||||
|
||||
expirationTime := time.Now().Add(time.Hour)
|
||||
paste, err := db.NewPaste("test paste", expirationTime)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("returns correct attributes", func(t *testing.T) {
|
||||
// only metadata is returned for Get
|
||||
savedPaste, err := db.Get(paste.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, paste.ID, savedPaste.ID)
|
||||
assert.Empty(t, savedPaste.Text)
|
||||
assert.Empty(t, savedPaste.EncryptedBytes)
|
||||
assert.Equal(t, expirationTime.Unix(), savedPaste.ExpiresAt.Unix())
|
||||
})
|
||||
|
||||
t.Run("non existent paste", func(t *testing.T) {
|
||||
// Test nonexistent paste
|
||||
_, err = db.Get("nonexistent")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("expired paste", func(t *testing.T) {
|
||||
// Test expired paste
|
||||
paste, err = db.NewPaste("test paste", time.Now().Add(-time.Hour))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = db.Get(paste.ID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
db, err := newTestDB()
|
||||
assert.NoError(t, err)
|
||||
defer db.reset()
|
||||
|
||||
paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = db.Delete(paste.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = db.Get(paste.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecrypt(t *testing.T) {
|
||||
db, err := newTestDB()
|
||||
assert.NoError(t, err)
|
||||
defer db.reset()
|
||||
|
||||
t.Run("decrypt paste", func(t *testing.T) {
|
||||
paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// decrypt paste
|
||||
decryptedText, err := db.Decrypt(paste.ID, paste.Key)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test paste", decryptedText)
|
||||
|
||||
// paste is deleted after decryption
|
||||
_, err = db.Get(paste.ID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid paste", func(t *testing.T) {
|
||||
paste, err := db.NewPaste("test paste", time.Now().Add(time.Hour))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test wrong key
|
||||
_, err = db.Decrypt(paste.ID, "wrong key")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = db.Decrypt("nonexistent", "key")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("expired paste", func(t *testing.T) {
|
||||
paste, err := db.NewPaste("test paste", time.Now().Add(-time.Hour))
|
||||
assert.NoError(t, err)
|
||||
|
||||
decryptedText, err := db.Decrypt(paste.ID, paste.Key)
|
||||
assert.Empty(t, decryptedText)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteExpired(t *testing.T) {
|
||||
db, err := newTestDB()
|
||||
assert.NoError(t, err)
|
||||
defer db.reset()
|
||||
|
||||
_, err = db.NewPaste("test paste", time.Now().Add(time.Hour))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = db.NewPaste("test paste", time.Now().Add(-time.Hour))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = db.DeleteExpired()
|
||||
assert.NoError(t, err)
|
||||
|
||||
pasteCount := 0
|
||||
db.boltDB.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(pastesBucketName)
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cursor := bucket.Cursor()
|
||||
for k, _ := cursor.First(); k != nil; k, _ = cursor.Next() {
|
||||
pasteCount++
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Equal(t, 1, pasteCount)
|
||||
}
|
||||
@@ -14,7 +14,11 @@ type Paste struct {
|
||||
}
|
||||
|
||||
func NewEncryptedPaste(text string, expiresAt time.Time) (*Paste, error) {
|
||||
key := NewEncryptionKey()
|
||||
key, err := NewEncryptionKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encryptedText, err := key.Encrypt(text)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/boltdb/bolt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/boltdb/bolt"
|
||||
)
|
||||
|
||||
func NewDB(path string, reset bool) (*DB, error) {
|
||||
@@ -17,7 +18,7 @@ func NewDB(path string, reset bool) (*DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &DB{boltDB: boltDB}, nil
|
||||
return &DB{boltDB: boltDB, path: path}, nil
|
||||
}
|
||||
|
||||
func removeDB(path string) {
|
||||
|
||||
11
go.mod
11
go.mod
@@ -7,4 +7,13 @@ require (
|
||||
github.com/boltdb/bolt v1.3.1
|
||||
)
|
||||
|
||||
require golang.org/x/sys v0.19.0 // indirect
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/stretchr/testify v1.9.0
|
||||
golang.org/x/sys v0.19.0 // indirect
|
||||
)
|
||||
|
||||
10
go.sum
10
go.sum
@@ -2,7 +2,17 @@ github.com/a-h/templ v0.2.707 h1:T1Gkd2ugbRglZ9rYw/VBchWOSZVKmetDbBkm4YubM7U=
|
||||
github.com/a-h/templ v0.2.707/go.mod h1:5cqsugkq9IerRNucNsI4DEamdHPsoGMQy99DzydLhM8=
|
||||
github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4=
|
||||
github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
Reference in New Issue
Block a user