diff --git a/cryptor/cryptor.go b/cryptor/cryptor.go index c4c1daa..7d92902 100644 --- a/cryptor/cryptor.go +++ b/cryptor/cryptor.go @@ -284,6 +284,10 @@ func (encrypted *EncryptedData) wrapKey(records *passvault.Records, clearKey []b for _, leftName := range access.LeftNames { for _, rightName := range access.RightNames { + if leftName == rightName { + continue + } + keyBytes, err := encryptKey(leftName, rightName, clearKey) if err != nil { return err diff --git a/cryptor/cryptor_test.go b/cryptor/cryptor_test.go index 011c036..ce7a996 100644 --- a/cryptor/cryptor_test.go +++ b/cryptor/cryptor_test.go @@ -9,6 +9,9 @@ import ( "encoding/base64" "encoding/json" "testing" + + "github.com/cloudflare/redoctober/keycache" + "github.com/cloudflare/redoctober/passvault" ) func TestHash(t *testing.T) { @@ -66,3 +69,55 @@ func TestHash(t *testing.T) { } } + +func TestDuplicates(t *testing.T) { + // Setup total names and partitions. + names := []string{"Alice", "Bob", "Carl"} + recs := make(map[string]passvault.PasswordRecord, 0) + left := []string{"Alice", "Bob"} + right := []string{"Bob", "Carl"} + + // Add each user to the keycache. + cache := keycache.NewCache() + records, err := passvault.InitFrom("memory") + if err != nil { + t.Fatalf("%v", err) + } + + c := Cryptor{&records, &cache} + + for _, name := range names { + pr, err := records.AddNewRecord(name, "weakpassword", true, passvault.DefaultRecordType) + if err != nil { + t.Fatalf("%v", err) + } + + recs[name] = pr + } + + // Create candidate encryption of message. + ac := AccessStructure{ + LeftNames: left, + RightNames: right, + } + + resp, err := c.Encrypt([]byte("Hello World!"), []string{}, ac) + if err != nil { + t.Fatalf("Error: %s", err) + } + + // Delegate one key at a time and check that decryption fails. + for name, pr := range recs { + err = cache.AddKeyFromRecord(pr, name, "weakpassword", nil, nil, 2, "1h") + if err != nil { + t.Fatalf("%v", err) + } + + _, _, _, err := c.Decrypt(resp, name) + if err == nil { + t.Fatalf("That shouldn't have worked!") + } + + cache.FlushCache() + } +}