From 85654fd875b7bd4d116618db3c85ac331585dcc8 Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Tue, 27 May 2014 22:31:47 -0700 Subject: [PATCH] changed implementation to an immutable AVL+ tree. --- merkle/binary.go | 5 +- merkle/iavl.go | 403 ++++++++++++++++++-------------------------- merkle/iavl_test.go | 314 +++++++++++++++------------------- merkle/types.go | 6 +- merkle/util.go | 42 +++-- 5 files changed, 337 insertions(+), 433 deletions(-) diff --git a/merkle/binary.go b/merkle/binary.go index 6d91c4761..97a951024 100644 --- a/merkle/binary.go +++ b/merkle/binary.go @@ -1,7 +1,8 @@ package merkle const ( - TYPE_BYTE = byte(0x00) + TYPE_NIL = byte(0x00) + TYPE_BYTE = byte(0x01) TYPE_INT8 = byte(0x02) TYPE_UINT8 = byte(0x03) TYPE_INT16 = byte(0x04) @@ -16,6 +17,7 @@ const ( func GetBinaryType(o Binary) byte { switch o.(type) { + case nil: return TYPE_NIL case Byte: return TYPE_BYTE case Int8: return TYPE_INT8 case UInt8: return TYPE_UINT8 @@ -36,6 +38,7 @@ func GetBinaryType(o Binary) byte { func LoadBinary(buf []byte, start int) (Binary, int) { typeByte := buf[start] switch typeByte { + case TYPE_NIL: return nil, start+1 case TYPE_BYTE: return LoadByte(buf[start+1:]), start+2 case TYPE_INT8: return LoadInt8(buf[start+1:]), start+2 case TYPE_UINT8: return LoadUInt8(buf[start+1:]), start+2 diff --git a/merkle/iavl.go b/merkle/iavl.go index 7ee457bba..1500a6302 100644 --- a/merkle/iavl.go +++ b/merkle/iavl.go @@ -4,6 +4,8 @@ import ( "crypto/sha256" ) +const HASH_BYTE_SIZE int = 4+32 + // Immutable AVL Tree (wraps the Node root) type IAVLTree struct { @@ -29,27 +31,36 @@ func (self *IAVLTree) Root() Node { } func (self *IAVLTree) Size() uint64 { + if self.root == nil { return 0 } return self.root.Size() } func (self *IAVLTree) Height() uint8 { + if self.root == nil { return 0 } return self.root.Height() } func (self *IAVLTree) Has(key Key) bool { + if self.root == nil { return false } return self.root.has(self.db, key) } func (self *IAVLTree) Put(key Key, value Value) (updated bool) { + if self.root == nil { + self.root = NewIAVLNode(key, value) + return false + } self.root, updated = self.root.put(self.db, key, value) return updated } func (self *IAVLTree) Hash() (ByteSlice, uint64) { + if self.root == nil { return nil, 0 } return self.root.Hash() } func (self *IAVLTree) Save() { + if self.root == nil { return } if self.root.hash == nil { self.root.Hash() } @@ -57,11 +68,13 @@ func (self *IAVLTree) Save() { } func (self *IAVLTree) Get(key Key) (value Value) { + if self.root == nil { return nil } return self.root.get(self.db, key) } func (self *IAVLTree) Remove(key Key) (value Value, err error) { - newRoot, value, err := self.root.remove(self.db, key) + if self.root == nil { return nil, NotFound(key) } + newRoot, _, value, err := self.root.remove(self.db, key) if err != nil { return nil, err } @@ -69,33 +82,9 @@ func (self *IAVLTree) Remove(key Key) (value Value, err error) { return value, nil } -func (self *IAVLTree) Iterator() NodeIterator { - pop := func (stack []*IAVLNode) ([]*IAVLNode, *IAVLNode) { - if len(stack) <= 0 { - return stack, nil - } else { - return stack[0:len(stack)-1], stack[len(stack)-1] - } - } - - stack := make([]*IAVLNode, 0, 10) - var cur *IAVLNode = self.root - var itr NodeIterator - itr = func()(tn Node) { - if len(stack) > 0 || cur != nil { - for cur != nil { - stack = append(stack, cur) - cur = cur.leftFilled(self.db) - } - stack, cur = pop(stack) - tn = cur - cur = cur.rightFilled(self.db) - return tn - } else { - return nil - } - } - return itr +func (self *IAVLTree) Traverse(cb func(Node) bool) { + if self.root == nil { return } + self.root.traverse(self.db, cb) } @@ -117,19 +106,22 @@ type IAVLNode struct { const ( IAVLNODE_FLAG_PERSISTED = byte(0x01) IAVLNODE_FLAG_PLACEHOLDER = byte(0x02) - - IAVLNODE_DESC_HAS_VALUE = byte(0x01) - IAVLNODE_DESC_HAS_LEFT = byte(0x02) - IAVLNODE_DESC_HAS_RIGHT = byte(0x04) ) +func NewIAVLNode(key Key, value Value) *IAVLNode { + return &IAVLNode{ + key: key, + value: value, + size: 1, + } +} + func (self *IAVLNode) Copy() *IAVLNode { - if self == nil { - return nil + if self.height == 0 { + panic("Why are you copying a value node?") } return &IAVLNode{ key: self.key, - value: self.value, size: self.size, height: self.height, left: self.left, @@ -156,39 +148,45 @@ func (self *IAVLNode) Value() Value { } func (self *IAVLNode) Size() uint64 { - if self == nil { return 0 } return self.size } func (self *IAVLNode) Height() uint8 { - if self == nil { return 0 } return self.height } func (self *IAVLNode) has(db Db, key Key) (has bool) { - if self == nil { return false } if self.key.Equals(key) { return true - } else if key.Less(self.key) { - return self.leftFilled(db).has(db, key) + } + if self.height == 0 { + return false } else { - return self.rightFilled(db).has(db, key) + if key.Less(self.key) { + return self.leftFilled(db).has(db, key) + } else { + return self.rightFilled(db).has(db, key) + } } } func (self *IAVLNode) get(db Db, key Key) (value Value) { - if self == nil { return nil } - if self.key.Equals(key) { - return self.value - } else if key.Less(self.key) { - return self.leftFilled(db).get(db, key) + if self.height == 0 { + if self.key.Equals(key) { + return self.value + } else { + return nil + } } else { - return self.rightFilled(db).get(db, key) + if key.Less(self.key) { + return self.leftFilled(db).get(db, key) + } else { + return self.rightFilled(db).get(db, key) + } } } func (self *IAVLNode) Hash() (ByteSlice, uint64) { - if self == nil { return nil, 0 } if self.hash != nil { return self.hash, 0 } @@ -204,9 +202,7 @@ func (self *IAVLNode) Hash() (ByteSlice, uint64) { } func (self *IAVLNode) Save(db Db) { - if self == nil { - return - } else if self.hash == nil { + if self.hash == nil { panic("savee.hash can't be nil") } if self.flags & IAVLNODE_FLAG_PERSISTED > 0 || @@ -214,119 +210,112 @@ func (self *IAVLNode) Save(db Db) { return } + // children + if self.height > 0 { + self.left.Save(db) + self.right.Save(db) + } + // save self buf := make([]byte, self.ByteSize(), self.ByteSize()) self.SaveTo(buf) db.Put([]byte(self.hash), buf) - // save left - self.left.Save(db) - - // save right - self.right.Save(db) - self.flags |= IAVLNODE_FLAG_PERSISTED } func (self *IAVLNode) put(db Db, key Key, value Value) (_ *IAVLNode, updated bool) { - if self == nil { - return &IAVLNode{key: key, value: value, height: 1, size: 1, hash: nil}, false - } - - self = self.Copy() - - if self.key.Equals(key) { - self.value = value - return self, true - } - - if key.Less(self.key) { - self.left, updated = self.leftFilled(db).put(db, key, value) + if self.height == 0 { + if key.Less(self.key) { + return &IAVLNode{ + key: self.key, + height: 1, + size: 2, + left: NewIAVLNode(key, value), + right: self, + }, false + } else if self.key.Equals(key) { + return NewIAVLNode(key, value), true + } else { + return &IAVLNode{ + key: key, + height: 1, + size: 2, + left: self, + right: NewIAVLNode(key, value), + }, false + } } else { - self.right, updated = self.rightFilled(db).put(db, key, value) - } - if updated { - return self, updated - } else { - self.calcHeightAndSize(db) - return self.balance(db), updated + self = self.Copy() + if key.Less(self.key) { + self.left, updated = self.leftFilled(db).put(db, key, value) + } else { + self.right, updated = self.rightFilled(db).put(db, key, value) + } + if updated { + return self, updated + } else { + self.calcHeightAndSize(db) + return self.balance(db), updated + } } } -func (self *IAVLNode) remove(db Db, key Key) (newSelf *IAVLNode, value Value, err error) { - if self == nil { return nil, nil, NotFound(key) } - - if self.key.Equals(key) { - if self.left != nil && self.right != nil { - if self.leftFilled(db).Size() < self.rightFilled(db).Size() { - self, newSelf = self.popNode(db, self.rightFilled(db).lmd(db)) - } else { - self, newSelf = self.popNode(db, self.leftFilled(db).rmd(db)) - } - newSelf.left = self.left - newSelf.right = self.right - newSelf.calcHeightAndSize(db) - return newSelf, self.value, nil - } else if self.left == nil { - return self.rightFilled(db), self.value, nil - } else if self.right == nil { - return self.leftFilled(db), self.value, nil +// newKey: new leftmost leaf key for tree after successfully removing 'key' if changed. +func (self *IAVLNode) remove(db Db, key Key) (newSelf *IAVLNode, newKey Key, value Value, err error) { + if self.height == 0 { + if self.key.Equals(key) { + return nil, nil, self.value, nil } else { - return nil, self.value, nil + return self, nil, nil, NotFound(key) } - } - - if key.Less(self.key) { - if self.left == nil { - return self, nil, NotFound(key) - } - var newLeft *IAVLNode - newLeft, value, err = self.leftFilled(db).remove(db, key) - if newLeft == self.leftFilled(db) { // not found - return self, nil, err - } else if err != nil { // some other error - return self, value, err - } - self = self.Copy() - self.left = newLeft } else { - if self.right == nil { - return self, nil, NotFound(key) + if key.Less(self.key) { + var newLeft *IAVLNode + newLeft, newKey, value, err = self.leftFilled(db).remove(db, key) + if err != nil { + return self, nil, value, err + } else if newLeft == nil { // left node held value, was removed + return self.right, self.key, value, nil + } + self = self.Copy() + self.left = newLeft + } else { + var newRight *IAVLNode + newRight, newKey, value, err = self.rightFilled(db).remove(db, key) + if err != nil { + return self, nil, value, err + } else if newRight == nil { // right node held value, was removed + return self.left, nil, value, nil + } + self = self.Copy() + self.right = newRight + if newKey != nil { + self.key = newKey + newKey = nil + } } - var newRight *IAVLNode - newRight, value, err = self.rightFilled(db).remove(db, key) - if newRight == self.rightFilled(db) { // not found - return self, nil, err - } else if err != nil { // some other error - return self, value, err - } - self = self.Copy() - self.right = newRight + self.calcHeightAndSize(db) + return self.balance(db), newKey, value, err } - self.calcHeightAndSize(db) - return self.balance(db), value, err } func (self *IAVLNode) ByteSize() int { - // 1 byte node descriptor - // 1 byte node neight + // 1 byte node height // 8 bytes node size - size := 10 + size := 9 // key size += 1 // type info size += self.key.ByteSize() - // value - if self.value != nil { + if self.height == 0 { + // value size += 1 // type info - size += self.value.ByteSize() + if self.value != nil { + size += self.value.ByteSize() + } } else { - size += 1 - } - // children - if self.left != nil { + // children size += HASH_BYTE_SIZE - } - if self.right != nil { size += HASH_BYTE_SIZE } return size @@ -341,38 +330,28 @@ func (self *IAVLNode) saveToCountHashes(buf []byte) (int, uint64) { cur := 0 hashCount := uint64(0) - // node descriptor - nodeDesc := byte(0) - if self.value != nil { nodeDesc |= IAVLNODE_DESC_HAS_VALUE } - if self.left != nil { nodeDesc |= IAVLNODE_DESC_HAS_LEFT } - if self.right != nil { nodeDesc |= IAVLNODE_DESC_HAS_RIGHT } - cur += UInt8(nodeDesc).SaveTo(buf[cur:]) - - // node height & size + // height & size cur += UInt8(self.height).SaveTo(buf[cur:]) cur += UInt64(self.size).SaveTo(buf[cur:]) - // node key + // key buf[cur] = GetBinaryType(self.key) cur += 1 cur += self.key.SaveTo(buf[cur:]) - // node value - if self.value != nil { + if self.height == 0 { + // value buf[cur] = GetBinaryType(self.value) cur += 1 - cur += self.value.SaveTo(buf[cur:]) - } - - // left child - if self.left != nil { + if self.value != nil { + cur += self.value.SaveTo(buf[cur:]) + } + } else { + // left leftHash, leftCount := self.left.Hash() hashCount += leftCount cur += leftHash.SaveTo(buf[cur:]) - } - - // right child - if self.right != nil { + // right rightHash, rightCount := self.right.Hash() hashCount += rightCount cur += rightHash.SaveTo(buf[cur:]) @@ -385,51 +364,44 @@ func (self *IAVLNode) saveToCountHashes(buf []byte) (int, uint64) { // load the rest of the data from db. // Not threadsafe. func (self *IAVLNode) fill(db Db) { - if self == nil { - panic("placeholder can't be nil") - } else if self.hash == nil { + if self.hash == nil { panic("placeholder.hash can't be nil") } buf := db.Get(self.hash) cur := 0 // node header - nodeDesc := byte(LoadUInt8(buf)) - self.height = uint8(LoadUInt8(buf[1:])) - self.size = uint64(LoadUInt64(buf[2:])) + self.height = uint8(LoadUInt8(buf[0:])) + self.size = uint64(LoadUInt64(buf[1:])) // key - key, cur := LoadBinary(buf, 10) + key, cur := LoadBinary(buf, 9) self.key = key.(Key) - // value - if nodeDesc & IAVLNODE_DESC_HAS_VALUE > 0 { + + if self.height == 0 { + // value self.value, cur = LoadBinary(buf, cur) - } - // children - if nodeDesc & IAVLNODE_DESC_HAS_LEFT > 0 { + } else { + // left var leftHash ByteSlice leftHash, cur = LoadByteSlice(buf, cur) self.left = &IAVLNode{ hash: leftHash, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, } - } - if nodeDesc & IAVLNODE_DESC_HAS_RIGHT > 0 { + // right var rightHash ByteSlice rightHash, cur = LoadByteSlice(buf, cur) self.right = &IAVLNode{ hash: rightHash, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, } - } - if cur != len(buf) { - panic("buf not all consumed") + if cur != len(buf) { + panic("buf not all consumed") + } } self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER } func (self *IAVLNode) leftFilled(db Db) *IAVLNode { - if self.left == nil { - return nil - } if self.left.flags & IAVLNODE_FLAG_PLACEHOLDER > 0 { self.left.fill(db) } @@ -437,56 +409,12 @@ func (self *IAVLNode) leftFilled(db Db) *IAVLNode { } func (self *IAVLNode) rightFilled(db Db) *IAVLNode { - if self.right == nil { - return nil - } if self.right.flags & IAVLNODE_FLAG_PLACEHOLDER > 0 { self.right.fill(db) } return self.right } -// Returns a new tree (unless node is the root) & a copy of the popped node. -// Can only pop nodes that have one or no children. -func (self *IAVLNode) popNode(db Db, node *IAVLNode) (newSelf, new_node *IAVLNode) { - if self == nil { - panic("self can't be nil") - } else if node == nil { - panic("node can't be nil") - } else if node.left != nil && node.right != nil { - panic("node hnot have both left and right") - } - - if self == node { - - var n *IAVLNode - if node.left != nil { - n = node.leftFilled(db) - } else if node.right != nil { - n = node.rightFilled(db) - } else { - n = nil - } - node = node.Copy() - node.left = nil - node.right = nil - node.calcHeightAndSize(db) - return n, node - - } else { - - self = self.Copy() - if node.key.Less(self.key) { - self.left, node = self.leftFilled(db).popNode(db, node) - } else { - self.right, node = self.rightFilled(db).popNode(db, node) - } - self.calcHeightAndSize(db) - return self, node - - } -} - func (self *IAVLNode) rotateRight(db Db) *IAVLNode { self = self.Copy() sl := self.leftFilled(db).Copy() @@ -517,13 +445,10 @@ func (self *IAVLNode) rotateLeft(db Db) *IAVLNode { func (self *IAVLNode) calcHeightAndSize(db Db) { self.height = maxUint8(self.leftFilled(db).Height(), self.rightFilled(db).Height()) + 1 - self.size = self.leftFilled(db).Size() + self.rightFilled(db).Size() + 1 + self.size = self.leftFilled(db).Size() + self.rightFilled(db).Size() } func (self *IAVLNode) calcBalance(db Db) int { - if self == nil { - return 0 - } return int(self.leftFilled(db).Height()) - int(self.rightFilled(db).Height()) } @@ -557,20 +482,28 @@ func (self *IAVLNode) balance(db Db) (newSelf *IAVLNode) { return self } -func (self *IAVLNode) _md(side func(*IAVLNode)*IAVLNode) (*IAVLNode) { - if self == nil { - return nil - } else if side(self) != nil { - return side(self)._md(side) - } else { +func (self *IAVLNode) lmd(db Db) (*IAVLNode) { + if self.height == 0 { return self } -} - -func (self *IAVLNode) lmd(db Db) (*IAVLNode) { - return self._md(func(node *IAVLNode)*IAVLNode { return node.leftFilled(db) }) + return self.leftFilled(db).lmd(db) } func (self *IAVLNode) rmd(db Db) (*IAVLNode) { - return self._md(func(node *IAVLNode)*IAVLNode { return node.rightFilled(db) }) + if self.height == 0 { + return self + } + return self.rightFilled(db).rmd(db) +} + +func (self *IAVLNode) traverse(db Db, cb func(Node)bool) bool { + stop := cb(self) + if stop { return stop } + if self.height > 0 { + stop = self.leftFilled(db).traverse(db, cb) + if stop { return stop } + stop = self.rightFilled(db).traverse(db, cb) + if stop { return stop } + } + return false } diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index d55b10a12..a6d51c0db 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -12,7 +12,6 @@ import ( "crypto/sha256" ) - func init() { if urandom, err := os.Open("/dev/urandom"); err != nil { return @@ -28,137 +27,26 @@ func init() { } } -func TestImmutableAvlPutHasGetRemove(t *testing.T) { - - type record struct { - key String - value String - } - - records := make([]*record, 400) - var node *IAVLNode - var err error - var val Value - var updated bool - - randomRecord := func() *record { - return &record{ randstr(20), randstr(20) } - } - - for i := range records { - r := randomRecord() - records[i] = r - node, updated = node.put(nil, r.key, String("")) - if updated { - t.Error("should have not been updated") - } - node, updated = node.put(nil, r.key, r.value) - if !updated { - t.Error("should have been updated") - } - if node.Size() != uint64(i+1) { - t.Error("size was wrong", node.Size(), i+1) - } - } - - for _, r := range records { - if has := node.has(nil, r.key); !has { - t.Error("Missing key") - } - if has := node.has(nil, randstr(12)); has { - t.Error("Table has extra key") - } - if val := node.get(nil, r.key); !(val.(String)).Equals(r.value) { - t.Error("wrong value") - } - } - - for i, x := range records { - if node, val, err = node.remove(nil, x.key); err != nil { - t.Error(err) - } else if !(val.(String)).Equals(x.value) { - t.Error("wrong value") - } - for _, r := range records[i+1:] { - if has := node.has(nil, r.key); !has { - t.Error("Missing key") - } - if has := node.has(nil, randstr(12)); has { - t.Error("Table has extra key") - } - if val := node.get(nil, r.key); !(val.(String)).Equals(r.value) { - t.Error("wrong value") - } - } - if node.Size() != uint64(len(records) - (i+1)) { - t.Error("size was wrong", node.Size(), (len(records) - (i+1))) - } - } -} - - -func BenchmarkImmutableAvlTree(b *testing.B) { - b.StopTimer() - - type record struct { - key String - value String - } - - randomRecord := func() *record { - return &record{ randstr(32), randstr(32) } - } - - t := NewIAVLTree(nil) - for i:=0; i<1000000; i++ { - r := randomRecord() - t.Put(r.key, r.value) - } - - b.StartTimer() - for i := 0; i < b.N; i++ { - r := randomRecord() - t.Put(r.key, r.value) - t.Remove(r.key) - } -} - - -func TestTraversals(t *testing.T) { - var data []int = []int{ - 1, 5, 7, 9, 12, 13, 17, 18, 19, 20, - } - var order []int = []int{ - 6, 1, 8, 2, 4 , 9 , 5 , 7 , 0 , 3 , - } - - test := func(T Tree) { - t.Logf("%T", T) - for j := range order { - T.Put(Int(data[order[j]]), Int(order[j])) - } - - j := 0 - itr := T.Iterator() - for node := itr(); node != nil; node = itr() { - if int(node.Key().(Int)) != data[j] { - t.Error("key in wrong spot in-order") - } - j += 1 - } - } - test(NewIAVLTree(nil)) -} - -// from http://stackoverflow.com/questions/3955680/how-to-check-if-my-avl-tree-implementation-is-correct -func TestGriffin(t *testing.T) { +func TestUnit(t *testing.T) { // Convenience for a new node - N := func(l *IAVLNode, i int, r *IAVLNode) *IAVLNode { + N := func(l, r interface{}) *IAVLNode { + var left, right *IAVLNode + if _, ok := l.(*IAVLNode); ok { + left = l.(*IAVLNode) + } else { + left = NewIAVLNode(Int32(l.(int)), nil) + } + if _, ok := r.(*IAVLNode); ok { + right = r.(*IAVLNode) + } else { + right = NewIAVLNode(Int32(r.(int)), nil) + } + n := &IAVLNode{ - key: Int32(i), - left: l, - right: r, + key: right.lmd(nil).key, + left: left, + right: right, } n.calcHeightAndSize(nil) n.Hash() @@ -168,14 +56,10 @@ func TestGriffin(t *testing.T) { // Convenience for simple printing of keys & tree structure var P func(*IAVLNode) string P = func(n *IAVLNode) string { - if n.left == nil && n.right == nil { + if n.height == 0 { return fmt.Sprintf("%v", n.key) - } else if n.left == nil { - return fmt.Sprintf("(- %v %v)", n.key, P(n.rightFilled(nil))) - } else if n.right == nil { - return fmt.Sprintf("(%v %v -)", P(n.leftFilled(nil)), n.key) } else { - return fmt.Sprintf("(%v %v %v)", P(n.leftFilled(nil)), n.key, P(n.rightFilled(nil))) + return fmt.Sprintf("(%v %v)", P(n.left), P(n.right)) } } @@ -186,12 +70,10 @@ func TestGriffin(t *testing.T) { t.Fatalf("Expected %v new hashes, got %v", hashCount, count) } // nuke hashes and reconstruct hash, ensure it's the same. - itr := (&IAVLTree{root:n2}).Iterator() - for node:=itr(); node!=nil; node = itr() { - if node != nil { - node.(*IAVLNode).hash = nil - } - } + (&IAVLTree{root:n2}).Traverse(func(node Node) bool { + node.(*IAVLNode).hash = nil + return false + }) // ensure that the new hash after nuking is the same as the old. newHash, _ := n2.Hash() if bytes.Compare(hash, newHash) != 0 { @@ -211,7 +93,7 @@ func TestGriffin(t *testing.T) { } expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) { - n2, value, err := n.remove(nil, Int32(i)) + n2, _, value, err := n.remove(nil, Int32(i)) // ensure node was added & structure is as expected. if value != nil || err != nil || P(n2) != repr { t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v", @@ -224,49 +106,108 @@ func TestGriffin(t *testing.T) { //////// Test Put cases: // Case 1: - n1 := N(N(nil, 4, nil), 20, nil) - if P(n1) != "(4 20 -)" { t.Fatalf("Got %v", P(n1)) } + n1 := N(4, 20) - expectPut(n1, 15, "(4 15 20)", 3) - expectPut(n1, 8, "(4 8 20)", 3) + expectPut(n1, 8, "((4 8) 20)", 3) + expectPut(n1, 25, "(4 (20 25))", 3) - // Case 2: - n2 := N(N(N(nil, 3, nil), 4, N(nil, 9, nil)), 20, N(nil, 26, nil)) - if P(n2) != "((3 4 9) 20 26)" { t.Fatalf("Got %v", P(n2)) } + n2 := N(4, N(20, 25)) - expectPut(n2, 15, "((3 4 -) 9 (15 20 26))", 4) - expectPut(n2, 8, "((3 4 8) 9 (- 20 26))", 4) + expectPut(n2, 8, "((4 8) (20 25))", 3) + expectPut(n2, 30, "((4 20) (25 30))", 4) - // Case 2: - n3 := N(N(N(N(nil, 2, nil), 3, nil), 4, N(N(nil, 7, nil), 9, N(nil, 11, nil))), - 20, N(N(nil, 21, nil), 26, N(nil, 30, nil))) - if P(n3) != "(((2 3 -) 4 (7 9 11)) 20 (21 26 30))" { t.Fatalf("Got %v", P(n3)) } + n3 := N(N(1, 2), 6) - expectPut(n3, 15, "(((2 3 -) 4 7) 9 ((- 11 15) 20 (21 26 30)))", 5) - expectPut(n3, 8, "(((2 3 -) 4 (- 7 8)) 9 (11 20 (21 26 30)))", 5) + expectPut(n3, 4, "((1 2) (4 6))", 4) + expectPut(n3, 8, "((1 2) (6 8))", 3) + n4 := N(N(1, 2), N(N(5, 6), N(7, 9))) + + expectPut(n4, 8, "(((1 2) (5 6)) ((7 8) 9))", 5) + expectPut(n4, 10, "(((1 2) (5 6)) (7 (9 10)))", 5) //////// Test Remove cases: - // Case 4: - n4 := N(N(nil, 1, nil), 2, N(N(nil, 3, nil), 4, N(nil, 5, nil))) - if P(n4) != "(1 2 (3 4 5))" { t.Fatalf("Got %v", P(n4)) } + n10 := N(N(1, 2), 3) - expectRemove(n4, 1, "((- 2 3) 4 5)", 2) + expectRemove(n10, 2, "(1 3)", 1) + expectRemove(n10, 3, "(1 2)", 0) - // Case 5: - n5 := N(N(N(nil, 1, nil), 2, N(N(nil, 3, nil), 4, N(nil, 5, nil))), 6, - N(N(N(nil, 7, nil), 8, nil), 9, N(N(nil, 10, nil), 11, N(nil, 12, N(nil, 13, nil))))) - if P(n5) != "((1 2 (3 4 5)) 6 ((7 8 -) 9 (10 11 (- 12 13))))" { t.Fatalf("Got %v", P(n5)) } + n11 := N(N(N(1, 2), 3), N(4, 5)) - expectRemove(n5, 1, "(((- 2 3) 4 5) 6 ((7 8 -) 9 (10 11 (- 12 13))))", 3) + expectRemove(n11, 4, "((1 2) (3 5))", 2) + expectRemove(n11, 3, "((1 2) (4 5))", 1) - // Case 6: - n6 := N(N(N(nil, 1, nil), 2, N(nil, 3, N(nil, 4, nil))), 5, - N(N(N(nil, 6, nil), 7, nil), 8, N(N(nil, 9, nil), 10, N(nil, 11, N(nil, 12, nil))))) - if P(n6) != "((1 2 (- 3 4)) 5 ((6 7 -) 8 (9 10 (- 11 12))))" { t.Fatalf("Got %v", P(n6)) } +} - expectRemove(n6, 1, "(((2 3 4) 5 (6 7 -)) 8 (9 10 (- 11 12)))", 4) +func TestIntegration(t *testing.T) { + + type record struct { + key String + value String + } + + records := make([]*record, 400) + var tree *IAVLTree = NewIAVLTree(nil) + var err error + var val Value + var updated bool + + randomRecord := func() *record { + return &record{ RandStr(20), RandStr(20) } + } + + for i := range records { + r := randomRecord() + records[i] = r + //t.Log("New record", r) + //PrintIAVLNode(tree.root) + updated = tree.Put(r.key, String("")) + if updated { + t.Error("should have not been updated") + } + updated = tree.Put(r.key, r.value) + if !updated { + t.Error("should have been updated") + } + if tree.Size() != uint64(i+1) { + t.Error("size was wrong", tree.Size(), i+1) + } + } + + for _, r := range records { + if has := tree.Has(r.key); !has { + t.Error("Missing key", r.key) + } + if has := tree.Has(RandStr(12)); has { + t.Error("Table has extra key") + } + if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { + t.Error("wrong value") + } + } + + for i, x := range records { + if val, err = tree.Remove(x.key); err != nil { + t.Error(err) + } else if !(val.(String)).Equals(x.value) { + t.Error("wrong value") + } + for _, r := range records[i+1:] { + if has := tree.Has(r.key); !has { + t.Error("Missing key", r.key) + } + if has := tree.Has(RandStr(12)); has { + t.Error("Table has extra key") + } + if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { + t.Error("wrong value") + } + } + if tree.Size() != uint64(len(records) - (i+1)) { + t.Error("size was wrong", tree.Size(), (len(records) - (i+1))) + } + } } func TestPersistence(t *testing.T) { @@ -275,7 +216,7 @@ func TestPersistence(t *testing.T) { // Create some random key value pairs records := make(map[String]String) for i:=0; i<10000; i++ { - records[String(randstr(20))] = String(randstr(20)) + records[String(RandStr(20))] = String(RandStr(20)) } // Construct some tree and save it @@ -295,13 +236,12 @@ func TestPersistence(t *testing.T) { t.Fatalf("Invalid value. Expected %v, got %v", value, t2value) } } - } func BenchmarkHash(b *testing.B) { b.StopTimer() - s := randstr(128) + s := RandStr(128) b.StartTimer() for i := 0; i < b.N; i++ { @@ -310,3 +250,29 @@ func BenchmarkHash(b *testing.B) { hasher.Sum(nil) } } + +func BenchmarkImmutableAvlTree(b *testing.B) { + b.StopTimer() + + type record struct { + key String + value String + } + + randomRecord := func() *record { + return &record{ RandStr(32), RandStr(32) } + } + + t := NewIAVLTree(nil) + for i:=0; i<1000000; i++ { + r := randomRecord() + t.Put(r.key, r.value) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + r := randomRecord() + t.Put(r.key, r.value) + t.Remove(r.key) + } +} diff --git a/merkle/types.go b/merkle/types.go index f096774a6..6d46b2dcb 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -4,8 +4,6 @@ import ( "fmt" ) -const HASH_BYTE_SIZE int = 4+32 - type Binary interface { ByteSize() int SaveTo([]byte) int @@ -36,8 +34,6 @@ type Node interface { Save(Db) } -type NodeIterator func() Node - type Tree interface { Root() Node Size() uint64 @@ -48,7 +44,7 @@ type Tree interface { Save() Put(Key, Value) bool Remove(Key) (Value, error) - Iterator() NodeIterator + Traverse(func(Node)bool) } func NotFound(key Key) error { diff --git a/merkle/util.go b/merkle/util.go index cc37f2578..95badde06 100644 --- a/merkle/util.go +++ b/merkle/util.go @@ -1,13 +1,16 @@ package merkle import ( - "os" + "math/big" + "crypto/rand" "fmt" ) func PrintIAVLNode(node *IAVLNode) { fmt.Println("==== NODE") - printIAVLNode(node, 0) + if node != nil { + printIAVLNode(node, 0) + } fmt.Println("==== END") } @@ -16,27 +19,30 @@ func printIAVLNode(node *IAVLNode, indent int) { for i:=0; i