mirror of
https://github.com/cloudflare/redoctober.git
synced 2026-01-07 05:56:56 +00:00
Merge pull request #118 from cloudflare/brendan/bens-refactor
Import Ben's Refactor
This commit is contained in:
@@ -361,7 +361,7 @@ func (encrypted *EncryptedData) wrapKey(records *passvault.Records, clearKey []b
|
||||
return err
|
||||
}
|
||||
|
||||
db := msp.UserDatabase(UserDatabase{records: records})
|
||||
db := UserDatabase{records: records}
|
||||
shareSet, err := sss.DistributeShares(clearKey, &db)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -458,14 +458,14 @@ func (encrypted *EncryptedData) unwrapKey(cache *keycache.Cache, user string) (u
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
db := msp.UserDatabase(UserDatabase{
|
||||
db := UserDatabase{
|
||||
names: &names,
|
||||
cache: cache,
|
||||
user: user,
|
||||
labels: encrypted.Labels,
|
||||
keySet: encrypted.KeySetRSA,
|
||||
shareSet: encrypted.ShareSet,
|
||||
})
|
||||
}
|
||||
unwrappedKey, err = sss.RecoverSecret(&db)
|
||||
|
||||
return
|
||||
|
||||
@@ -121,18 +121,18 @@ func (f Formatted) String() string {
|
||||
out := fmt.Sprintf("(%v", f.Min)
|
||||
|
||||
for _, cond := range f.Conds {
|
||||
switch cond.(type) {
|
||||
switch cond := cond.(type) {
|
||||
case Name:
|
||||
out += fmt.Sprintf(", %v", cond.(Name).string)
|
||||
out += fmt.Sprintf(", %v", cond.string)
|
||||
case Formatted:
|
||||
out += fmt.Sprintf(", %v", (cond.(Formatted)).String())
|
||||
out += fmt.Sprintf(", %v", cond.String())
|
||||
}
|
||||
}
|
||||
|
||||
return out + ")"
|
||||
}
|
||||
|
||||
func (f Formatted) Ok(db *UserDatabase) bool {
|
||||
func (f Formatted) Ok(db UserDatabase) bool {
|
||||
// Goes through the smallest number of conditions possible to check if the
|
||||
// threshold gate returns true. Sometimes requires recursing down to check
|
||||
// nested threshold gates.
|
||||
@@ -161,9 +161,8 @@ func (f *Formatted) Compress() {
|
||||
continue
|
||||
}
|
||||
|
||||
switch cond.(type) {
|
||||
switch cond := cond.(type) {
|
||||
case Formatted:
|
||||
cond := cond.(Formatted)
|
||||
cond.Compress()
|
||||
f.Conds[i] = cond
|
||||
|
||||
@@ -184,9 +183,8 @@ func (f *Formatted) Compress() {
|
||||
continue
|
||||
}
|
||||
|
||||
switch cond.(type) {
|
||||
switch cond := cond.(type) {
|
||||
case Formatted:
|
||||
cond := cond.(Formatted)
|
||||
cond.Compress()
|
||||
f.Conds[i] = cond
|
||||
|
||||
|
||||
@@ -45,24 +45,24 @@ func TestFormatted(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
db := UserDatabase(Database(map[string][][]byte{
|
||||
db := &Database{
|
||||
"Alice": [][]byte{[]byte("blah")},
|
||||
"Carl": [][]byte{[]byte("herp")},
|
||||
}))
|
||||
}
|
||||
|
||||
if query1.Ok(&db) != true {
|
||||
if query1.Ok(db) != true {
|
||||
t.Fatalf("Query #1 was wrong.")
|
||||
}
|
||||
|
||||
if query2.Ok(&db) != false {
|
||||
if query2.Ok(db) != false {
|
||||
t.Fatalf("Query #2 was wrong.")
|
||||
}
|
||||
|
||||
if query3.Ok(&db) != true {
|
||||
if query3.Ok(db) != true {
|
||||
t.Fatalf("Query #3 was wrong.")
|
||||
}
|
||||
|
||||
if query4.Ok(&db) != false {
|
||||
if query4.Ok(db) != false {
|
||||
t.Fatalf("Query #4 was wrong.")
|
||||
}
|
||||
|
||||
|
||||
58
msp/msp.go
58
msp/msp.go
@@ -17,7 +17,7 @@ type UserDatabase interface {
|
||||
}
|
||||
|
||||
type Condition interface { // Represents one condition in a predicate
|
||||
Ok(*UserDatabase) bool
|
||||
Ok(UserDatabase) bool
|
||||
}
|
||||
|
||||
type Name struct { // Type of condition
|
||||
@@ -25,8 +25,8 @@ type Name struct { // Type of condition
|
||||
index int
|
||||
}
|
||||
|
||||
func (n Name) Ok(db *UserDatabase) bool {
|
||||
return (*db).CanGetShare(n.string)
|
||||
func (n Name) Ok(db UserDatabase) bool {
|
||||
return db.CanGetShare(n.string)
|
||||
}
|
||||
|
||||
type TraceElem struct {
|
||||
@@ -120,22 +120,22 @@ func StringToMSP(pred string) (m MSP, err error) {
|
||||
// names: The names in the top-level threshold gate that need to be delegated.
|
||||
// locs: The index in the treshold gate for each name.
|
||||
// trace: All names that must be delegated for for this gate to be satisfied.
|
||||
func (m MSP) DerivePath(db *UserDatabase) (ok bool, names []string, locs []int, trace []string) {
|
||||
func (m MSP) DerivePath(db UserDatabase) (ok bool, names []string, locs []int, trace []string) {
|
||||
ts := &TraceSlice{}
|
||||
|
||||
for i, cond := range m.Conds {
|
||||
switch cond.(type) {
|
||||
switch cond := cond.(type) {
|
||||
case Name:
|
||||
if (*db).CanGetShare(cond.(Name).string) {
|
||||
if db.CanGetShare(cond.string) {
|
||||
heap.Push(ts, TraceElem{
|
||||
i,
|
||||
[]string{cond.(Name).string},
|
||||
[]string{cond.(Name).string},
|
||||
[]string{cond.string},
|
||||
[]string{cond.string},
|
||||
})
|
||||
}
|
||||
|
||||
case Formatted:
|
||||
sok, _, _, strace := MSP(cond.(Formatted)).DerivePath(db)
|
||||
sok, _, _, strace := MSP(cond).DerivePath(db)
|
||||
if sok {
|
||||
heap.Push(ts, TraceElem{i, []string{}, strace})
|
||||
}
|
||||
@@ -153,7 +153,7 @@ func (m MSP) DerivePath(db *UserDatabase) (ok bool, names []string, locs []int,
|
||||
|
||||
// DistributeShares takes as input a secret and a user database and returns secret shares according to access structure
|
||||
// described by the MSP.
|
||||
func (m MSP) DistributeShares(sec []byte, db *UserDatabase) (map[string][][]byte, error) {
|
||||
func (m MSP) DistributeShares(sec []byte, db UserDatabase) (map[string][][]byte, error) {
|
||||
out := make(map[string][][]byte)
|
||||
|
||||
// Generate a Vandermonde matrix.
|
||||
@@ -187,31 +187,23 @@ func (m MSP) DistributeShares(sec []byte, db *UserDatabase) (map[string][][]byte
|
||||
for i, cond := range m.Conds {
|
||||
share := shares[i]
|
||||
|
||||
switch cond.(type) {
|
||||
switch cond := cond.(type) {
|
||||
case Name:
|
||||
name := cond.(Name).string
|
||||
if _, ok := out[name]; ok {
|
||||
out[name] = append(out[name], share)
|
||||
} else if (*db).ValidUser(name) {
|
||||
out[name] = [][]byte{share}
|
||||
} else {
|
||||
return out, errors.New("Unknown user in predicate.")
|
||||
name := cond.string
|
||||
if !db.ValidUser(name) {
|
||||
return nil, errors.New("Unknown user in predicate.")
|
||||
}
|
||||
|
||||
default:
|
||||
below := MSP(cond.(Formatted))
|
||||
out[name] = append(out[name], share)
|
||||
case Formatted:
|
||||
below := MSP(cond)
|
||||
subOut, err := below.DistributeShares(share, db)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
|
||||
for name, shares := range subOut {
|
||||
if _, ok := out[name]; ok {
|
||||
out[name] = append(out[name], shares...)
|
||||
} else {
|
||||
out[name] = shares
|
||||
}
|
||||
|
||||
out[name] = append(out[name], shares...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -220,12 +212,12 @@ func (m MSP) DistributeShares(sec []byte, db *UserDatabase) (map[string][][]byte
|
||||
}
|
||||
|
||||
// RecoverSecret takes a user database storing secret shares as input and returns the original secret.
|
||||
func (m MSP) RecoverSecret(db *UserDatabase) ([]byte, error) {
|
||||
func (m MSP) RecoverSecret(db UserDatabase) ([]byte, error) {
|
||||
cache := make(map[string][][]byte, 0) // Caches un-used shares for a user.
|
||||
return m.recoverSecret(db, cache)
|
||||
}
|
||||
|
||||
func (m MSP) recoverSecret(db *UserDatabase, cache map[string][][]byte) ([]byte, error) {
|
||||
func (m MSP) recoverSecret(db UserDatabase, cache map[string][][]byte) ([]byte, error) {
|
||||
var (
|
||||
index = []int{} // Indexes where given shares were in the matrix.
|
||||
shares = []FieldElem{} // Contains shares that will be used in reconstruction.
|
||||
@@ -238,7 +230,7 @@ func (m MSP) recoverSecret(db *UserDatabase, cache map[string][][]byte) ([]byte,
|
||||
|
||||
for _, name := range names {
|
||||
if _, cached := cache[name]; !cached {
|
||||
out, err := (*db).GetShare(name)
|
||||
out, err := db.GetShare(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -251,16 +243,16 @@ func (m MSP) recoverSecret(db *UserDatabase, cache map[string][][]byte) ([]byte,
|
||||
gate := m.Conds[loc]
|
||||
index = append(index, loc+1)
|
||||
|
||||
switch gate.(type) {
|
||||
switch gate := gate.(type) {
|
||||
case Name:
|
||||
if len(cache[gate.(Name).string]) <= gate.(Name).index {
|
||||
if len(cache[gate.string]) <= gate.index {
|
||||
return nil, errors.New("Predicate / database mismatch!")
|
||||
}
|
||||
|
||||
shares = append(shares, FieldElem(cache[gate.(Name).string][gate.(Name).index]))
|
||||
shares = append(shares, FieldElem(cache[gate.string][gate.index]))
|
||||
|
||||
case Formatted:
|
||||
share, err := MSP(gate.(Formatted)).recoverSecret(db, cache)
|
||||
share, err := MSP(gate).recoverSecret(db, cache)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -9,18 +9,18 @@ import (
|
||||
|
||||
type Database map[string][][]byte
|
||||
|
||||
func (d Database) ValidUser(name string) bool {
|
||||
_, ok := d[name]
|
||||
func (d *Database) ValidUser(name string) bool {
|
||||
_, ok := (*d)[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (d Database) CanGetShare(name string) bool {
|
||||
_, ok := d[name]
|
||||
func (d *Database) CanGetShare(name string) bool {
|
||||
_, ok := (*d)[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (d Database) GetShare(name string) ([][]byte, error) {
|
||||
out, ok := d[name]
|
||||
func (d *Database) GetShare(name string) ([][]byte, error) {
|
||||
out, ok := (*d)[name]
|
||||
|
||||
if ok {
|
||||
return out, nil
|
||||
@@ -30,11 +30,11 @@ func (d Database) GetShare(name string) ([][]byte, error) {
|
||||
}
|
||||
|
||||
func TestMSP(t *testing.T) {
|
||||
db := UserDatabase(Database(map[string][][]byte{
|
||||
db := &Database{
|
||||
"Alice": [][]byte{},
|
||||
"Bob": [][]byte{},
|
||||
"Carl": [][]byte{},
|
||||
}))
|
||||
}
|
||||
|
||||
sec := make([]byte, 16)
|
||||
rand.Read(sec)
|
||||
@@ -42,8 +42,8 @@ func TestMSP(t *testing.T) {
|
||||
|
||||
predicate, _ := StringToMSP("(2, (1, Alice, Bob), Carl)")
|
||||
|
||||
shares1, _ := predicate.DistributeShares(sec, &db)
|
||||
shares2, _ := predicate.DistributeShares(sec, &db)
|
||||
shares1, _ := predicate.DistributeShares(sec, db)
|
||||
shares2, _ := predicate.DistributeShares(sec, db)
|
||||
|
||||
alice := bytes.Compare(shares1["Alice"][0], shares2["Alice"][0])
|
||||
bob := bytes.Compare(shares1["Bob"][0], shares2["Bob"][0])
|
||||
@@ -53,8 +53,8 @@ func TestMSP(t *testing.T) {
|
||||
t.Fatalf("Key splitting isn't random! %v %v", shares1, shares2)
|
||||
}
|
||||
|
||||
db1 := UserDatabase(Database(shares1))
|
||||
db2 := UserDatabase(Database(shares2))
|
||||
db1 := Database(shares1)
|
||||
db2 := Database(shares2)
|
||||
|
||||
sec1, err := predicate.RecoverSecret(&db1)
|
||||
if err != nil {
|
||||
|
||||
44
msp/raw.go
44
msp/raw.go
@@ -24,8 +24,8 @@ type Layer struct {
|
||||
type Raw struct { // Represents one node in the tree.
|
||||
NodeType
|
||||
|
||||
Left *Condition
|
||||
Right *Condition
|
||||
Left Condition
|
||||
Right Condition
|
||||
}
|
||||
|
||||
func StringToRaw(r string) (out Raw, err error) {
|
||||
@@ -118,7 +118,7 @@ func StringToRaw(r string) (out Raw, err error) {
|
||||
// Copy left and right out of slice and THEN give a pointer for them!
|
||||
left, right := top.Conditions[i], top.Conditions[i+1] // Legal because of check 2.
|
||||
if oper == typ {
|
||||
built := Raw{typ, &left, &right}
|
||||
built := Raw{typ, left, right}
|
||||
|
||||
top.Conditions = append(
|
||||
top.Conditions[:i],
|
||||
@@ -170,11 +170,11 @@ func StringToRaw(r string) (out Raw, err error) {
|
||||
func (r Raw) String() string {
|
||||
out := ""
|
||||
|
||||
switch (*r.Left).(type) {
|
||||
switch left := r.Left.(type) {
|
||||
case Name:
|
||||
out += (*r.Left).(Name).string
|
||||
default:
|
||||
out += "(" + (*r.Left).(Raw).String() + ")"
|
||||
out += left.string
|
||||
case Raw:
|
||||
out += "(" + left.String() + ")"
|
||||
}
|
||||
|
||||
if r.Type() == NodeAnd {
|
||||
@@ -183,11 +183,11 @@ func (r Raw) String() string {
|
||||
out += " | "
|
||||
}
|
||||
|
||||
switch (*r.Right).(type) {
|
||||
switch right := r.Right.(type) {
|
||||
case Name:
|
||||
out += (*r.Right).(Name).string
|
||||
default:
|
||||
out += "(" + (*r.Right).(Raw).String() + ")"
|
||||
out += right.string
|
||||
case Raw:
|
||||
out += "(" + right.String() + ")"
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -202,28 +202,28 @@ func (r Raw) Formatted() (out Formatted) {
|
||||
out.Min = 1
|
||||
}
|
||||
|
||||
switch (*r.Left).(type) {
|
||||
switch left := r.Left.(type) {
|
||||
case Name:
|
||||
out.Conds = []Condition{(*r.Left).(Name)}
|
||||
default:
|
||||
out.Conds = []Condition{(*r.Left).(Raw).Formatted()}
|
||||
out.Conds = []Condition{left}
|
||||
case Raw:
|
||||
out.Conds = []Condition{left.Formatted()}
|
||||
}
|
||||
|
||||
switch (*r.Right).(type) {
|
||||
switch right := r.Right.(type) {
|
||||
case Name:
|
||||
out.Conds = append(out.Conds, (*r.Right).(Name))
|
||||
default:
|
||||
out.Conds = append(out.Conds, (*r.Right).(Raw).Formatted())
|
||||
out.Conds = append(out.Conds, right)
|
||||
case Raw:
|
||||
out.Conds = append(out.Conds, right.Formatted())
|
||||
}
|
||||
|
||||
out.Compress() // Small amount of predicate compression.
|
||||
return
|
||||
}
|
||||
|
||||
func (r Raw) Ok(db *UserDatabase) bool {
|
||||
func (r Raw) Ok(db UserDatabase) bool {
|
||||
if r.Type() == NodeAnd {
|
||||
return (*r.Left).Ok(db) && (*r.Right).Ok(db)
|
||||
return r.Left.Ok(db) && r.Right.Ok(db)
|
||||
} else {
|
||||
return (*r.Left).Ok(db) || (*r.Right).Ok(db)
|
||||
return r.Left.Ok(db) || r.Right.Ok(db)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,32 +11,32 @@ func TestRaw(t *testing.T) {
|
||||
|
||||
query1 := Raw{
|
||||
NodeType: NodeAnd,
|
||||
Left: &alice,
|
||||
Right: &bob,
|
||||
Left: alice,
|
||||
Right: bob,
|
||||
}
|
||||
|
||||
aliceOrBob := Condition(Raw{
|
||||
NodeType: NodeOr,
|
||||
Left: &alice,
|
||||
Right: &bob,
|
||||
Left: alice,
|
||||
Right: bob,
|
||||
})
|
||||
|
||||
query2 := Raw{
|
||||
NodeType: NodeAnd,
|
||||
Left: &aliceOrBob,
|
||||
Right: &carl,
|
||||
Left: aliceOrBob,
|
||||
Right: carl,
|
||||
}
|
||||
|
||||
db := UserDatabase(Database(map[string][][]byte{
|
||||
db := &Database{
|
||||
"Alice": [][]byte{[]byte("blah")},
|
||||
"Carl": [][]byte{[]byte("herp")},
|
||||
}))
|
||||
}
|
||||
|
||||
if query1.Ok(&db) != false {
|
||||
if query1.Ok(db) != false {
|
||||
t.Fatalf("Query #1 was wrong.")
|
||||
}
|
||||
|
||||
if query2.Ok(&db) != true {
|
||||
if query2.Ok(db) != true {
|
||||
t.Fatalf("Query #2 was wrong.")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user