diff --git a/CHANGELOG.md b/CHANGELOG.md index b679b839d..fe2c2fe94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## 0.6.0 (December 29, 2017) + +BREAKING: + - [cli] remove --root + - [pubsub] add String() method to Query interface + +IMPROVEMENTS: + - [common] use a thread-safe and well seeded non-crypto rng + +BUG FIXES + - [clist] fix misuse of wait group + - [common] introduce Ticker interface and logicalTicker for better testing of timers + ## 0.5.0 (December 5, 2017) BREAKING: diff --git a/cli/setup.go b/cli/setup.go index 78151015b..295477598 100644 --- a/cli/setup.go +++ b/cli/setup.go @@ -14,7 +14,6 @@ import ( ) const ( - RootFlag = "root" HomeFlag = "home" TraceFlag = "trace" OutputFlag = "output" @@ -28,14 +27,9 @@ type Executable interface { } // PrepareBaseCmd is meant for tendermint and other servers -func PrepareBaseCmd(cmd *cobra.Command, envPrefix, defautRoot string) Executor { +func PrepareBaseCmd(cmd *cobra.Command, envPrefix, defaultHome string) Executor { cobra.OnInitialize(func() { initEnv(envPrefix) }) - cmd.PersistentFlags().StringP(RootFlag, "r", defautRoot, "DEPRECATED. Use --home") - // -h is already reserved for --help as part of the cobra framework - // do you want to try something else?? - // also, default must be empty, so we can detect this unset and fall back - // to --root / TM_ROOT / TMROOT - cmd.PersistentFlags().String(HomeFlag, "", "root directory for config and data") + cmd.PersistentFlags().StringP(HomeFlag, "", defaultHome, "directory for config and data") cmd.PersistentFlags().Bool(TraceFlag, false, "print out full stack trace on errors") cmd.PersistentPreRunE = concatCobraCmdFuncs(bindFlagsLoadViper, cmd.PersistentPreRunE) return Executor{cmd, os.Exit} @@ -45,11 +39,11 @@ func PrepareBaseCmd(cmd *cobra.Command, envPrefix, defautRoot string) Executor { // // This adds --encoding (hex, btc, base64) and --output (text, json) to // the command. These only really make sense in interactive commands. -func PrepareMainCmd(cmd *cobra.Command, envPrefix, defautRoot string) Executor { +func PrepareMainCmd(cmd *cobra.Command, envPrefix, defaultHome string) Executor { cmd.PersistentFlags().StringP(EncodingFlag, "e", "hex", "Binary encoding (hex|b64|btc)") cmd.PersistentFlags().StringP(OutputFlag, "o", "text", "Output format (text|json)") cmd.PersistentPreRunE = concatCobraCmdFuncs(setEncoding, validateOutput, cmd.PersistentPreRunE) - return PrepareBaseCmd(cmd, envPrefix, defautRoot) + return PrepareBaseCmd(cmd, envPrefix, defaultHome) } // initEnv sets to use ENV variables if set. @@ -136,17 +130,10 @@ func bindFlagsLoadViper(cmd *cobra.Command, args []string) error { return err } - // rootDir is command line flag, env variable, or default $HOME/.tlc - // NOTE: we support both --root and --home for now, but eventually only --home - // Also ensure we set the correct rootDir under HomeFlag so we dont need to - // repeat this logic elsewhere. - rootDir := viper.GetString(HomeFlag) - if rootDir == "" { - rootDir = viper.GetString(RootFlag) - viper.Set(HomeFlag, rootDir) - } + homeDir := viper.GetString(HomeFlag) + viper.Set(HomeFlag, homeDir) viper.SetConfigName("config") // name of config file (without extension) - viper.AddConfigPath(rootDir) // search root directory + viper.AddConfigPath(homeDir) // search root directory // If a config file is found, read it in. if err := viper.ReadInConfig(); err == nil { diff --git a/cli/setup_test.go b/cli/setup_test.go index 692da26d3..e0fd75d8a 100644 --- a/cli/setup_test.go +++ b/cli/setup_test.go @@ -57,12 +57,9 @@ func TestSetupEnv(t *testing.T) { func TestSetupConfig(t *testing.T) { // we pre-create two config files we can refer to in the rest of // the test cases. - cval1, cval2 := "fubble", "wubble" + cval1 := "fubble" conf1, err := WriteDemoConfig(map[string]string{"boo": cval1}) require.Nil(t, err) - // make sure it handles dashed-words in the config, and ignores random info - conf2, err := WriteDemoConfig(map[string]string{"boo": cval2, "foo": "bar", "two-words": "WORD"}) - require.Nil(t, err) cases := []struct { args []string @@ -74,16 +71,13 @@ func TestSetupConfig(t *testing.T) { // setting on the command line {[]string{"--boo", "haha"}, nil, "haha", ""}, {[]string{"--two-words", "rocks"}, nil, "", "rocks"}, - {[]string{"--root", conf1}, nil, cval1, ""}, + {[]string{"--home", conf1}, nil, cval1, ""}, // test both variants of the prefix {nil, map[string]string{"RD_BOO": "bang"}, "bang", ""}, {nil, map[string]string{"RD_TWO_WORDS": "fly"}, "", "fly"}, {nil, map[string]string{"RDTWO_WORDS": "fly"}, "", "fly"}, - {nil, map[string]string{"RD_ROOT": conf1}, cval1, ""}, - {nil, map[string]string{"RDROOT": conf2}, cval2, "WORD"}, + {nil, map[string]string{"RD_HOME": conf1}, cval1, ""}, {nil, map[string]string{"RDHOME": conf1}, cval1, ""}, - // and when both are set??? HOME wins every time! - {[]string{"--root", conf1}, map[string]string{"RDHOME": conf2}, cval2, "WORD"}, } for idx, tc := range cases { @@ -156,10 +150,10 @@ func TestSetupUnmarshal(t *testing.T) { {nil, nil, c("", 0)}, // setting on the command line {[]string{"--name", "haha"}, nil, c("haha", 0)}, - {[]string{"--root", conf1}, nil, c(cval1, 0)}, + {[]string{"--home", conf1}, nil, c(cval1, 0)}, // test both variants of the prefix {nil, map[string]string{"MR_AGE": "56"}, c("", 56)}, - {nil, map[string]string{"MR_ROOT": conf1}, c(cval1, 0)}, + {nil, map[string]string{"MR_HOME": conf1}, c(cval1, 0)}, {[]string{"--age", "17"}, map[string]string{"MRHOME": conf2}, c(cval2, 17)}, } diff --git a/clist/clist.go b/clist/clist.go index 5295dd995..a52920f8c 100644 --- a/clist/clist.go +++ b/clist/clist.go @@ -1,46 +1,68 @@ package clist /* + The purpose of CList is to provide a goroutine-safe linked-list. This list can be traversed concurrently by any number of goroutines. However, removed CElements cannot be added back. NOTE: Not all methods of container/list are (yet) implemented. NOTE: Removed elements need to DetachPrev or DetachNext consistently to ensure garbage collection of removed elements. + */ import ( "sync" - "sync/atomic" - "unsafe" ) -// CElement is an element of a linked-list -// Traversal from a CElement are goroutine-safe. +/* + +CElement is an element of a linked-list +Traversal from a CElement is goroutine-safe. + +We can't avoid using WaitGroups or for-loops given the documentation +spec without re-implementing the primitives that already exist in +golang/sync. Notice that WaitGroup allows many go-routines to be +simultaneously released, which is what we want. Mutex doesn't do +this. RWMutex does this, but it's clumsy to use in the way that a +WaitGroup would be used -- and we'd end up having two RWMutex's for +prev/next each, which is doubly confusing. + +sync.Cond would be sort-of useful, but we don't need a write-lock in +the for-loop. Use sync.Cond when you need serial access to the +"condition". In our case our condition is if `next != nil || removed`, +and there's no reason to serialize that condition for goroutines +waiting on NextWait() (since it's just a read operation). + +*/ type CElement struct { - prev unsafe.Pointer + mtx sync.RWMutex + prev *CElement prevWg *sync.WaitGroup - next unsafe.Pointer + next *CElement nextWg *sync.WaitGroup - removed uint32 - Value interface{} + removed bool + + Value interface{} // immutable } // Blocking implementation of Next(). // May return nil iff CElement was tail and got removed. func (e *CElement) NextWait() *CElement { for { - e.nextWg.Wait() - next := e.Next() - if next == nil { - if e.Removed() { - return nil - } else { - continue - } - } else { + e.mtx.RLock() + next := e.next + nextWg := e.nextWg + removed := e.removed + e.mtx.RUnlock() + + if next != nil || removed { return next } + + nextWg.Wait() + // e.next doesn't necessarily exist here. + // That's why we need to continue a for-loop. } } @@ -48,82 +70,113 @@ func (e *CElement) NextWait() *CElement { // May return nil iff CElement was head and got removed. func (e *CElement) PrevWait() *CElement { for { - e.prevWg.Wait() - prev := e.Prev() - if prev == nil { - if e.Removed() { - return nil - } else { - continue - } - } else { + e.mtx.RLock() + prev := e.prev + prevWg := e.prevWg + removed := e.removed + e.mtx.RUnlock() + + if prev != nil || removed { return prev } + + prevWg.Wait() } } // Nonblocking, may return nil if at the end. func (e *CElement) Next() *CElement { - return (*CElement)(atomic.LoadPointer(&e.next)) + e.mtx.RLock() + defer e.mtx.RUnlock() + + return e.next } // Nonblocking, may return nil if at the end. func (e *CElement) Prev() *CElement { - return (*CElement)(atomic.LoadPointer(&e.prev)) + e.mtx.RLock() + defer e.mtx.RUnlock() + + return e.prev } func (e *CElement) Removed() bool { - return atomic.LoadUint32(&(e.removed)) > 0 + e.mtx.RLock() + defer e.mtx.RUnlock() + + return e.removed } func (e *CElement) DetachNext() { if !e.Removed() { panic("DetachNext() must be called after Remove(e)") } - atomic.StorePointer(&e.next, nil) + e.mtx.Lock() + defer e.mtx.Unlock() + + e.next = nil } func (e *CElement) DetachPrev() { if !e.Removed() { panic("DetachPrev() must be called after Remove(e)") } - atomic.StorePointer(&e.prev, nil) + e.mtx.Lock() + defer e.mtx.Unlock() + + e.prev = nil } -func (e *CElement) setNextAtomic(next *CElement) { - for { - oldNext := atomic.LoadPointer(&e.next) - if !atomic.CompareAndSwapPointer(&(e.next), oldNext, unsafe.Pointer(next)) { - continue - } - if next == nil && oldNext != nil { // We for-loop in NextWait() so race is ok - e.nextWg.Add(1) - } - if next != nil && oldNext == nil { - e.nextWg.Done() - } - return +// NOTE: This function needs to be safe for +// concurrent goroutines waiting on nextWg. +func (e *CElement) SetNext(newNext *CElement) { + e.mtx.Lock() + defer e.mtx.Unlock() + + oldNext := e.next + e.next = newNext + if oldNext != nil && newNext == nil { + // See https://golang.org/pkg/sync/: + // + // If a WaitGroup is reused to wait for several independent sets of + // events, new Add calls must happen after all previous Wait calls have + // returned. + e.nextWg = waitGroup1() // WaitGroups are difficult to re-use. + } + if oldNext == nil && newNext != nil { + e.nextWg.Done() } } -func (e *CElement) setPrevAtomic(prev *CElement) { - for { - oldPrev := atomic.LoadPointer(&e.prev) - if !atomic.CompareAndSwapPointer(&(e.prev), oldPrev, unsafe.Pointer(prev)) { - continue - } - if prev == nil && oldPrev != nil { // We for-loop in PrevWait() so race is ok - e.prevWg.Add(1) - } - if prev != nil && oldPrev == nil { - e.prevWg.Done() - } - return +// NOTE: This function needs to be safe for +// concurrent goroutines waiting on prevWg +func (e *CElement) SetPrev(newPrev *CElement) { + e.mtx.Lock() + defer e.mtx.Unlock() + + oldPrev := e.prev + e.prev = newPrev + if oldPrev != nil && newPrev == nil { + e.prevWg = waitGroup1() // WaitGroups are difficult to re-use. + } + if oldPrev == nil && newPrev != nil { + e.prevWg.Done() } } -func (e *CElement) setRemovedAtomic() { - atomic.StoreUint32(&(e.removed), 1) +func (e *CElement) SetRemoved() { + e.mtx.Lock() + defer e.mtx.Unlock() + + e.removed = true + + // This wakes up anyone waiting in either direction. + if e.prev == nil { + e.prevWg.Done() + } + if e.next == nil { + e.nextWg.Done() + } } //-------------------------------------------------------------------------------- @@ -132,7 +185,7 @@ func (e *CElement) setRemovedAtomic() { // The zero value for CList is an empty list ready to use. // Operations are goroutine-safe. type CList struct { - mtx sync.Mutex + mtx sync.RWMutex wg *sync.WaitGroup head *CElement // first element tail *CElement // last element @@ -142,6 +195,7 @@ type CList struct { func (l *CList) Init() *CList { l.mtx.Lock() defer l.mtx.Unlock() + l.wg = waitGroup1() l.head = nil l.tail = nil @@ -152,48 +206,55 @@ func (l *CList) Init() *CList { func New() *CList { return new(CList).Init() } func (l *CList) Len() int { - l.mtx.Lock() - defer l.mtx.Unlock() + l.mtx.RLock() + defer l.mtx.RUnlock() + return l.len } func (l *CList) Front() *CElement { - l.mtx.Lock() - defer l.mtx.Unlock() + l.mtx.RLock() + defer l.mtx.RUnlock() + return l.head } func (l *CList) FrontWait() *CElement { + // Loop until the head is non-nil else wait and try again for { - l.mtx.Lock() + l.mtx.RLock() head := l.head wg := l.wg - l.mtx.Unlock() - if head == nil { - wg.Wait() - } else { + l.mtx.RUnlock() + + if head != nil { return head } + wg.Wait() + // NOTE: If you think l.head exists here, think harder. } } func (l *CList) Back() *CElement { - l.mtx.Lock() - defer l.mtx.Unlock() + l.mtx.RLock() + defer l.mtx.RUnlock() + return l.tail } func (l *CList) BackWait() *CElement { for { - l.mtx.Lock() + l.mtx.RLock() tail := l.tail wg := l.wg - l.mtx.Unlock() - if tail == nil { - wg.Wait() - } else { + l.mtx.RUnlock() + + if tail != nil { return tail } + wg.Wait() + // l.tail doesn't necessarily exist here. + // That's why we need to continue a for-loop. } } @@ -203,11 +264,12 @@ func (l *CList) PushBack(v interface{}) *CElement { // Construct a new element e := &CElement{ - prev: nil, - prevWg: waitGroup1(), - next: nil, - nextWg: waitGroup1(), - Value: v, + prev: nil, + prevWg: waitGroup1(), + next: nil, + nextWg: waitGroup1(), + removed: false, + Value: v, } // Release waiters on FrontWait/BackWait maybe @@ -221,9 +283,9 @@ func (l *CList) PushBack(v interface{}) *CElement { l.head = e l.tail = e } else { - l.tail.setNextAtomic(e) - e.setPrevAtomic(l.tail) - l.tail = e + e.SetPrev(l.tail) // We must init e first. + l.tail.SetNext(e) // This will make e accessible. + l.tail = e // Update the list. } return e @@ -250,30 +312,26 @@ func (l *CList) Remove(e *CElement) interface{} { // If we're removing the only item, make CList FrontWait/BackWait wait. if l.len == 1 { - l.wg.Add(1) + l.wg = waitGroup1() // WaitGroups are difficult to re-use. } + + // Update l.len l.len -= 1 // Connect next/prev and set head/tail if prev == nil { l.head = next } else { - prev.setNextAtomic(next) + prev.SetNext(next) } if next == nil { l.tail = prev } else { - next.setPrevAtomic(prev) + next.SetPrev(prev) } // Set .Done() on e, otherwise waiters will wait forever. - e.setRemovedAtomic() - if prev == nil { - e.prevWg.Done() - } - if next == nil { - e.nextWg.Done() - } + e.SetRemoved() return e.Value } diff --git a/common/bit_array.go b/common/bit_array.go index 5590fe61b..848763b48 100644 --- a/common/bit_array.go +++ b/common/bit_array.go @@ -3,7 +3,6 @@ package common import ( "encoding/binary" "fmt" - "math/rand" "strings" "sync" ) @@ -212,12 +211,12 @@ func (bA *BitArray) PickRandom() (int, bool) { if length == 0 { return 0, false } - randElemStart := rand.Intn(length) + randElemStart := RandIntn(length) for i := 0; i < length; i++ { elemIdx := ((i + randElemStart) % length) if elemIdx < length-1 { if bA.Elems[elemIdx] > 0 { - randBitStart := rand.Intn(64) + randBitStart := RandIntn(64) for j := 0; j < 64; j++ { bitIdx := ((j + randBitStart) % 64) if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 { @@ -232,7 +231,7 @@ func (bA *BitArray) PickRandom() (int, bool) { if elemBits == 0 { elemBits = 64 } - randBitStart := rand.Intn(elemBits) + randBitStart := RandIntn(elemBits) for j := 0; j < elemBits; j++ { bitIdx := ((j + randBitStart) % elemBits) if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 { diff --git a/common/random.go b/common/random.go index 73bd16356..ca71b6143 100644 --- a/common/random.go +++ b/common/random.go @@ -2,7 +2,8 @@ package common import ( crand "crypto/rand" - "math/rand" + mrand "math/rand" + "sync" "time" ) @@ -10,22 +11,36 @@ const ( strChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" // 62 characters ) -func init() { +// pseudo random number generator. +// seeded with OS randomness (crand) +var prng struct { + sync.Mutex + *mrand.Rand +} + +func reset() { b := cRandBytes(8) var seed uint64 for i := 0; i < 8; i++ { seed |= uint64(b[i]) seed <<= 8 } - rand.Seed(int64(seed)) + prng.Lock() + prng.Rand = mrand.New(mrand.NewSource(int64(seed))) + prng.Unlock() +} + +func init() { + reset() } // Constructs an alphanumeric string of given length. +// It is not safe for cryptographic usage. func RandStr(length int) string { chars := []byte{} MAIN_LOOP: for { - val := rand.Int63() + val := RandInt63() for i := 0; i < 10; i++ { v := int(val & 0x3f) // rightmost 6 bits if v >= 62 { // only 62 characters in strChars @@ -44,87 +59,151 @@ MAIN_LOOP: return string(chars) } +// It is not safe for cryptographic usage. func RandUint16() uint16 { - return uint16(rand.Uint32() & (1<<16 - 1)) + return uint16(RandUint32() & (1<<16 - 1)) } +// It is not safe for cryptographic usage. func RandUint32() uint32 { - return rand.Uint32() + prng.Lock() + u32 := prng.Uint32() + prng.Unlock() + return u32 } +// It is not safe for cryptographic usage. func RandUint64() uint64 { - return uint64(rand.Uint32())<<32 + uint64(rand.Uint32()) + return uint64(RandUint32())<<32 + uint64(RandUint32()) } +// It is not safe for cryptographic usage. func RandUint() uint { - return uint(rand.Int()) + prng.Lock() + i := prng.Int() + prng.Unlock() + return uint(i) } +// It is not safe for cryptographic usage. func RandInt16() int16 { - return int16(rand.Uint32() & (1<<16 - 1)) + return int16(RandUint32() & (1<<16 - 1)) } +// It is not safe for cryptographic usage. func RandInt32() int32 { - return int32(rand.Uint32()) + return int32(RandUint32()) } +// It is not safe for cryptographic usage. func RandInt64() int64 { - return int64(rand.Uint32())<<32 + int64(rand.Uint32()) + return int64(RandUint64()) } +// It is not safe for cryptographic usage. func RandInt() int { - return rand.Int() + prng.Lock() + i := prng.Int() + prng.Unlock() + return i +} + +// It is not safe for cryptographic usage. +func RandInt31() int32 { + prng.Lock() + i31 := prng.Int31() + prng.Unlock() + return i31 +} + +// It is not safe for cryptographic usage. +func RandInt63() int64 { + prng.Lock() + i63 := prng.Int63() + prng.Unlock() + return i63 } // Distributed pseudo-exponentially to test for various cases +// It is not safe for cryptographic usage. func RandUint16Exp() uint16 { - bits := rand.Uint32() % 16 + bits := RandUint32() % 16 if bits == 0 { return 0 } n := uint16(1 << (bits - 1)) - n += uint16(rand.Int31()) & ((1 << (bits - 1)) - 1) + n += uint16(RandInt31()) & ((1 << (bits - 1)) - 1) return n } // Distributed pseudo-exponentially to test for various cases +// It is not safe for cryptographic usage. func RandUint32Exp() uint32 { - bits := rand.Uint32() % 32 + bits := RandUint32() % 32 if bits == 0 { return 0 } n := uint32(1 << (bits - 1)) - n += uint32(rand.Int31()) & ((1 << (bits - 1)) - 1) + n += uint32(RandInt31()) & ((1 << (bits - 1)) - 1) return n } // Distributed pseudo-exponentially to test for various cases +// It is not safe for cryptographic usage. func RandUint64Exp() uint64 { - bits := rand.Uint32() % 64 + bits := RandUint32() % 64 if bits == 0 { return 0 } n := uint64(1 << (bits - 1)) - n += uint64(rand.Int63()) & ((1 << (bits - 1)) - 1) + n += uint64(RandInt63()) & ((1 << (bits - 1)) - 1) return n } +// It is not safe for cryptographic usage. func RandFloat32() float32 { - return rand.Float32() + prng.Lock() + f32 := prng.Float32() + prng.Unlock() + return f32 } +// It is not safe for cryptographic usage. func RandTime() time.Time { return time.Unix(int64(RandUint64Exp()), 0) } +// RandBytes returns n random bytes from the OS's source of entropy ie. via crypto/rand. +// It is not safe for cryptographic usage. func RandBytes(n int) []byte { + // cRandBytes isn't guaranteed to be fast so instead + // use random bytes generated from the internal PRNG bs := make([]byte, n) - for i := 0; i < n; i++ { - bs[i] = byte(rand.Intn(256)) + for i := 0; i < len(bs); i++ { + bs[i] = byte(RandInt() & 0xFF) } return bs } +// RandIntn returns, as an int, a non-negative pseudo-random number in [0, n). +// It panics if n <= 0. +// It is not safe for cryptographic usage. +func RandIntn(n int) int { + prng.Lock() + i := prng.Intn(n) + prng.Unlock() + return i +} + +// RandPerm returns a pseudo-random permutation of n integers in [0, n). +// It is not safe for cryptographic usage. +func RandPerm(n int) []int { + prng.Lock() + perm := prng.Perm(n) + prng.Unlock() + return perm +} + // NOTE: This relies on the os's random number generator. // For real security, we should salt that with some seed. // See github.com/tendermint/go-crypto for a more secure reader. diff --git a/common/random_test.go b/common/random_test.go new file mode 100644 index 000000000..216f2f8bc --- /dev/null +++ b/common/random_test.go @@ -0,0 +1,120 @@ +package common + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + mrand "math/rand" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRandStr(t *testing.T) { + l := 243 + s := RandStr(l) + assert.Equal(t, l, len(s)) +} + +func TestRandBytes(t *testing.T) { + l := 243 + b := RandBytes(l) + assert.Equal(t, l, len(b)) +} + +func TestRandIntn(t *testing.T) { + n := 243 + for i := 0; i < 100; i++ { + x := RandIntn(n) + assert.True(t, x < n) + } +} + +// It is essential that these tests run and never repeat their outputs +// lest we've been pwned and the behavior of our randomness is controlled. +// See Issues: +// * https://github.com/tendermint/tmlibs/issues/99 +// * https://github.com/tendermint/tendermint/issues/973 +func TestUniqueRng(t *testing.T) { + buf := new(bytes.Buffer) + outputs := make(map[string][]int) + for i := 0; i < 100; i++ { + testThemAll(buf) + output := buf.String() + buf.Reset() + runs, seen := outputs[output] + if seen { + t.Errorf("Run #%d's output was already seen in previous runs: %v", i, runs) + } + outputs[output] = append(outputs[output], i) + } +} + +func testThemAll(out io.Writer) { + // Reset the internal PRNG + reset() + + // Set math/rand's Seed so that any direct invocations + // of math/rand will reveal themselves. + mrand.Seed(1) + perm := RandPerm(10) + blob, _ := json.Marshal(perm) + fmt.Fprintf(out, "perm: %s\n", blob) + + fmt.Fprintf(out, "randInt: %d\n", RandInt()) + fmt.Fprintf(out, "randUint: %d\n", RandUint()) + fmt.Fprintf(out, "randIntn: %d\n", RandIntn(97)) + fmt.Fprintf(out, "randInt31: %d\n", RandInt31()) + fmt.Fprintf(out, "randInt32: %d\n", RandInt32()) + fmt.Fprintf(out, "randInt63: %d\n", RandInt63()) + fmt.Fprintf(out, "randInt64: %d\n", RandInt64()) + fmt.Fprintf(out, "randUint32: %d\n", RandUint32()) + fmt.Fprintf(out, "randUint64: %d\n", RandUint64()) + fmt.Fprintf(out, "randUint16Exp: %d\n", RandUint16Exp()) + fmt.Fprintf(out, "randUint32Exp: %d\n", RandUint32Exp()) + fmt.Fprintf(out, "randUint64Exp: %d\n", RandUint64Exp()) +} + +func TestRngConcurrencySafety(t *testing.T) { + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + _ = RandUint64() + <-time.After(time.Millisecond * time.Duration(RandIntn(100))) + _ = RandPerm(3) + }() + } + wg.Wait() +} + +func BenchmarkRandBytes10B(b *testing.B) { + benchmarkRandBytes(b, 10) +} +func BenchmarkRandBytes100B(b *testing.B) { + benchmarkRandBytes(b, 100) +} +func BenchmarkRandBytes1KiB(b *testing.B) { + benchmarkRandBytes(b, 1024) +} +func BenchmarkRandBytes10KiB(b *testing.B) { + benchmarkRandBytes(b, 10*1024) +} +func BenchmarkRandBytes100KiB(b *testing.B) { + benchmarkRandBytes(b, 100*1024) +} +func BenchmarkRandBytes1MiB(b *testing.B) { + benchmarkRandBytes(b, 1024*1024) +} + +func benchmarkRandBytes(b *testing.B, n int) { + for i := 0; i < b.N; i++ { + _ = RandBytes(n) + } + b.ReportAllocs() +} diff --git a/common/repeat_timer.go b/common/repeat_timer.go index d7d9154d4..2e6cb81c8 100644 --- a/common/repeat_timer.go +++ b/common/repeat_timer.go @@ -5,82 +5,224 @@ import ( "time" ) -/* -RepeatTimer repeatedly sends a struct{}{} to .Ch after each "dur" period. -It's good for keeping connections alive. -A RepeatTimer must be Stop()'d or it will keep a goroutine alive. -*/ -type RepeatTimer struct { - Ch chan time.Time +// Used by RepeatTimer the first time, +// and every time it's Reset() after Stop(). +type TickerMaker func(dur time.Duration) Ticker - mtx sync.Mutex - name string - ticker *time.Ticker - quit chan struct{} - wg *sync.WaitGroup - dur time.Duration +// Ticker is a basic ticker interface. +type Ticker interface { + + // Never changes, never closes. + Chan() <-chan time.Time + + // Stopping a stopped Ticker will panic. + Stop() } -func NewRepeatTimer(name string, dur time.Duration) *RepeatTimer { - var t = &RepeatTimer{ - Ch: make(chan time.Time), - ticker: time.NewTicker(dur), - quit: make(chan struct{}), - wg: new(sync.WaitGroup), - name: name, - dur: dur, +//---------------------------------------- +// defaultTickerMaker + +func defaultTickerMaker(dur time.Duration) Ticker { + ticker := time.NewTicker(dur) + return (*defaultTicker)(ticker) +} + +type defaultTicker time.Ticker + +// Implements Ticker +func (t *defaultTicker) Chan() <-chan time.Time { + return t.C +} + +// Implements Ticker +func (t *defaultTicker) Stop() { + ((*time.Ticker)(t)).Stop() +} + +//---------------------------------------- +// LogicalTickerMaker + +// Construct a TickerMaker that always uses `source`. +// It's useful for simulating a deterministic clock. +func NewLogicalTickerMaker(source chan time.Time) TickerMaker { + return func(dur time.Duration) Ticker { + return newLogicalTicker(source, dur) } - t.wg.Add(1) - go t.fireRoutine(t.ticker) +} + +type logicalTicker struct { + source <-chan time.Time + ch chan time.Time + quit chan struct{} +} + +func newLogicalTicker(source <-chan time.Time, interval time.Duration) Ticker { + lt := &logicalTicker{ + source: source, + ch: make(chan time.Time), + quit: make(chan struct{}), + } + go lt.fireRoutine(interval) + return lt +} + +// We need a goroutine to read times from t.source +// and fire on t.Chan() when `interval` has passed. +func (t *logicalTicker) fireRoutine(interval time.Duration) { + source := t.source + + // Init `lasttime` + lasttime := time.Time{} + select { + case lasttime = <-source: + case <-t.quit: + return + } + // Init `lasttime` end + + timeleft := interval + for { + select { + case newtime := <-source: + elapsed := newtime.Sub(lasttime) + timeleft -= elapsed + if timeleft <= 0 { + // Block for determinism until the ticker is stopped. + select { + case t.ch <- newtime: + case <-t.quit: + return + } + // Reset timeleft. + // Don't try to "catch up" by sending more. + // "Ticker adjusts the intervals or drops ticks to make up for + // slow receivers" - https://golang.org/pkg/time/#Ticker + timeleft = interval + } + case <-t.quit: + return // done + } + } +} + +// Implements Ticker +func (t *logicalTicker) Chan() <-chan time.Time { + return t.ch // immutable +} + +// Implements Ticker +func (t *logicalTicker) Stop() { + close(t.quit) // it *should* panic when stopped twice. +} + +//--------------------------------------------------------------------- + +/* + RepeatTimer repeatedly sends a struct{}{} to `.Chan()` after each `dur` + period. (It's good for keeping connections alive.) + A RepeatTimer must be stopped, or it will keep a goroutine alive. +*/ +type RepeatTimer struct { + name string + ch chan time.Time + tm TickerMaker + + mtx sync.Mutex + dur time.Duration + ticker Ticker + quit chan struct{} +} + +// NewRepeatTimer returns a RepeatTimer with a defaultTicker. +func NewRepeatTimer(name string, dur time.Duration) *RepeatTimer { + return NewRepeatTimerWithTickerMaker(name, dur, defaultTickerMaker) +} + +// NewRepeatTimerWithTicker returns a RepeatTimer with the given ticker +// maker. +func NewRepeatTimerWithTickerMaker(name string, dur time.Duration, tm TickerMaker) *RepeatTimer { + var t = &RepeatTimer{ + name: name, + ch: make(chan time.Time), + tm: tm, + dur: dur, + ticker: nil, + quit: nil, + } + t.reset() return t } -func (t *RepeatTimer) fireRoutine(ticker *time.Ticker) { +func (t *RepeatTimer) fireRoutine(ch <-chan time.Time, quit <-chan struct{}) { for { select { - case t_ := <-ticker.C: - t.Ch <- t_ - case <-t.quit: - // needed so we know when we can reset t.quit - t.wg.Done() + case t_ := <-ch: + t.ch <- t_ + case <-quit: // NOTE: `t.quit` races. return } } } -// Wait the duration again before firing. -func (t *RepeatTimer) Reset() { - t.Stop() - - t.mtx.Lock() // Lock - defer t.mtx.Unlock() - - t.ticker = time.NewTicker(t.dur) - t.quit = make(chan struct{}) - t.wg.Add(1) - go t.fireRoutine(t.ticker) +func (t *RepeatTimer) Chan() <-chan time.Time { + return t.ch } -// For ease of .Stop()'ing services before .Start()'ing them, -// we ignore .Stop()'s on nil RepeatTimers. -func (t *RepeatTimer) Stop() bool { - if t == nil { - return false - } - t.mtx.Lock() // Lock +func (t *RepeatTimer) Stop() { + t.mtx.Lock() defer t.mtx.Unlock() - exists := t.ticker != nil - if exists { - t.ticker.Stop() // does not close the channel + t.stop() +} + +// Wait the duration again before firing. +func (t *RepeatTimer) Reset() { + t.mtx.Lock() + defer t.mtx.Unlock() + + t.reset() +} + +//---------------------------------------- +// Misc. + +// CONTRACT: (non-constructor) caller should hold t.mtx. +func (t *RepeatTimer) reset() { + if t.ticker != nil { + t.stop() + } + t.ticker = t.tm(t.dur) + t.quit = make(chan struct{}) + go t.fireRoutine(t.ticker.Chan(), t.quit) +} + +// CONTRACT: caller should hold t.mtx. +func (t *RepeatTimer) stop() { + if t.ticker == nil { + /* + Similar to the case of closing channels twice: + https://groups.google.com/forum/#!topic/golang-nuts/rhxMiNmRAPk + Stopping a RepeatTimer twice implies that you do + not know whether you are done or not. + If you're calling stop on a stopped RepeatTimer, + you probably have race conditions. + */ + panic("Tried to stop a stopped RepeatTimer") + } + t.ticker.Stop() + t.ticker = nil + /* + XXX + From https://golang.org/pkg/time/#Ticker: + "Stop the ticker to release associated resources" + "After Stop, no more ticks will be sent" + So we shouldn't have to do the below. + select { - case <-t.Ch: + case <-t.ch: // read off channel if there's anything there default: } - close(t.quit) - t.wg.Wait() // must wait for quit to close else we race Reset - t.ticker = nil - } - return exists + */ + close(t.quit) } diff --git a/common/repeat_timer_test.go b/common/repeat_timer_test.go index 9f03f41df..5a3a4c0a6 100644 --- a/common/repeat_timer_test.go +++ b/common/repeat_timer_test.go @@ -1,78 +1,92 @@ package common import ( - "sync" "testing" "time" - // make govet noshadow happy... - asrt "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" ) -type rCounter struct { - input chan time.Time - mtx sync.Mutex - count int +func TestDefaultTicker(t *testing.T) { + ticker := defaultTickerMaker(time.Millisecond * 10) + <-ticker.Chan() + ticker.Stop() } -func (c *rCounter) Increment() { - c.mtx.Lock() - c.count++ - c.mtx.Unlock() -} +func TestRepeat(t *testing.T) { -func (c *rCounter) Count() int { - c.mtx.Lock() - val := c.count - c.mtx.Unlock() - return val -} + ch := make(chan time.Time, 100) + lt := time.Time{} // zero time is year 1 -// Read should run in a go-routine and -// updates count by one every time a packet comes in -func (c *rCounter) Read() { - for range c.input { - c.Increment() + // tick fires `cnt` times for each second. + tick := func(cnt int) { + for i := 0; i < cnt; i++ { + lt = lt.Add(time.Second) + ch <- lt + } } -} -func TestRepeat(test *testing.T) { - assert := asrt.New(test) - - dur := time.Duration(50) * time.Millisecond - short := time.Duration(20) * time.Millisecond - // delay waits for cnt durations, an a little extra - delay := func(cnt int) time.Duration { - return time.Duration(cnt)*dur + time.Millisecond + // tock consumes Ticker.Chan() events `cnt` times. + tock := func(t *testing.T, rt *RepeatTimer, cnt int) { + for i := 0; i < cnt; i++ { + timeout := time.After(time.Second * 10) + select { + case <-rt.Chan(): + case <-timeout: + panic("expected RepeatTimer to fire") + } + } + done := true + select { + case <-rt.Chan(): + done = false + default: + } + assert.True(t, done) } - t := NewRepeatTimer("bar", dur) - // start at 0 - c := &rCounter{input: t.Ch} - go c.Read() - assert.Equal(0, c.Count()) + tm := NewLogicalTickerMaker(ch) + dur := time.Duration(10 * time.Millisecond) // less than a second + rt := NewRepeatTimerWithTickerMaker("bar", dur, tm) - // wait for 4 periods - time.Sleep(delay(4)) - assert.Equal(4, c.Count()) + // Start at 0. + tock(t, rt, 0) + tick(1) // init time - // keep reseting leads to no firing + tock(t, rt, 0) + tick(1) // wait 1 periods + tock(t, rt, 1) + tick(2) // wait 2 periods + tock(t, rt, 2) + tick(3) // wait 3 periods + tock(t, rt, 3) + tick(4) // wait 4 periods + tock(t, rt, 4) + + // Multiple resets leads to no firing. for i := 0; i < 20; i++ { - time.Sleep(short) - t.Reset() + time.Sleep(time.Millisecond) + rt.Reset() } - assert.Equal(4, c.Count()) - // after this, it still works normal - time.Sleep(delay(2)) - assert.Equal(6, c.Count()) + // After this, it works as new. + tock(t, rt, 0) + tick(1) // init time - // after a stop, nothing more is sent - stopped := t.Stop() - assert.True(stopped) - time.Sleep(delay(7)) - assert.Equal(6, c.Count()) + tock(t, rt, 0) + tick(1) // wait 1 periods + tock(t, rt, 1) + tick(2) // wait 2 periods + tock(t, rt, 2) + tick(3) // wait 3 periods + tock(t, rt, 3) + tick(4) // wait 4 periods + tock(t, rt, 4) - // close channel to stop counter - close(t.Ch) + // After a stop, nothing more is sent. + rt.Stop() + tock(t, rt, 0) + + // Another stop panics. + assert.Panics(t, func() { rt.Stop() }) } diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index 52b8361f8..54a4b8aed 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -13,6 +13,8 @@ package pubsub import ( "context" + "errors" + "sync" cmn "github.com/tendermint/tmlibs/common" ) @@ -38,6 +40,7 @@ type cmd struct { // Query defines an interface for a query to be used for subscribing. type Query interface { Matches(tags map[string]interface{}) bool + String() string } // Server allows clients to subscribe/unsubscribe for messages, publishing @@ -47,6 +50,9 @@ type Server struct { cmds chan cmd cmdsCap int + + mtx sync.RWMutex + subscriptions map[string]map[string]struct{} // subscriber -> query -> struct{} } // Option sets a parameter for the server. @@ -56,7 +62,9 @@ type Option func(*Server) // for a detailed description of how to configure buffering. If no options are // provided, the resulting server's queue is unbuffered. func NewServer(options ...Option) *Server { - s := &Server{} + s := &Server{ + subscriptions: make(map[string]map[string]struct{}), + } s.BaseService = *cmn.NewBaseService(nil, "PubSub", s) for _, option := range options { @@ -82,17 +90,33 @@ func BufferCapacity(cap int) Option { } // BufferCapacity returns capacity of the internal server's queue. -func (s Server) BufferCapacity() int { +func (s *Server) BufferCapacity() int { return s.cmdsCap } // Subscribe creates a subscription for the given client. It accepts a channel -// on which messages matching the given query can be received. If the -// subscription already exists, the old channel will be closed. An error will -// be returned to the caller if the context is canceled. +// on which messages matching the given query can be received. An error will be +// returned to the caller if the context is canceled or if subscription already +// exist for pair clientID and query. func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, out chan<- interface{}) error { + s.mtx.RLock() + clientSubscriptions, ok := s.subscriptions[clientID] + if ok { + _, ok = clientSubscriptions[query.String()] + } + s.mtx.RUnlock() + if ok { + return errors.New("already subscribed") + } + select { case s.cmds <- cmd{op: sub, clientID: clientID, query: query, ch: out}: + s.mtx.Lock() + if _, ok = s.subscriptions[clientID]; !ok { + s.subscriptions[clientID] = make(map[string]struct{}) + } + s.subscriptions[clientID][query.String()] = struct{}{} + s.mtx.Unlock() return nil case <-ctx.Done(): return ctx.Err() @@ -100,10 +124,24 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou } // Unsubscribe removes the subscription on the given query. An error will be -// returned to the caller if the context is canceled. +// returned to the caller if the context is canceled or if subscription does +// not exist. func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error { + s.mtx.RLock() + clientSubscriptions, ok := s.subscriptions[clientID] + if ok { + _, ok = clientSubscriptions[query.String()] + } + s.mtx.RUnlock() + if !ok { + return errors.New("subscription not found") + } + select { case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}: + s.mtx.Lock() + delete(clientSubscriptions, query.String()) + s.mtx.Unlock() return nil case <-ctx.Done(): return ctx.Err() @@ -111,10 +149,20 @@ func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) } // UnsubscribeAll removes all client subscriptions. An error will be returned -// to the caller if the context is canceled. +// to the caller if the context is canceled or if subscription does not exist. func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error { + s.mtx.RLock() + _, ok := s.subscriptions[clientID] + s.mtx.RUnlock() + if !ok { + return errors.New("subscription not found") + } + select { case s.cmds <- cmd{op: unsub, clientID: clientID}: + s.mtx.Lock() + delete(s.subscriptions, clientID) + s.mtx.Unlock() return nil case <-ctx.Done(): return ctx.Err() @@ -186,13 +234,8 @@ loop: func (state *state) add(clientID string, q Query, ch chan<- interface{}) { // add query if needed - if clientToChannelMap, ok := state.queries[q]; !ok { + if _, ok := state.queries[q]; !ok { state.queries[q] = make(map[string]chan<- interface{}) - } else { - // check if already subscribed - if oldCh, ok := clientToChannelMap[clientID]; ok { - close(oldCh) - } } // create subscription diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 7bf7b41f7..84b6aa218 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -86,14 +86,11 @@ func TestClientSubscribesTwice(t *testing.T) { ch2 := make(chan interface{}, 1) err = s.Subscribe(ctx, clientID, q, ch2) - require.NoError(t, err) - - _, ok := <-ch1 - assert.False(t, ok) + require.Error(t, err) err = s.PublishWithTags(ctx, "Spider-Man", map[string]interface{}{"tm.events.type": "NewBlock"}) require.NoError(t, err) - assertReceive(t, "Spider-Man", ch2) + assertReceive(t, "Spider-Man", ch1) } func TestUnsubscribe(t *testing.T) { @@ -117,6 +114,27 @@ func TestUnsubscribe(t *testing.T) { assert.False(t, ok) } +func TestResubscribe(t *testing.T) { + s := pubsub.NewServer() + s.SetLogger(log.TestingLogger()) + s.Start() + defer s.Stop() + + ctx := context.Background() + ch := make(chan interface{}) + err := s.Subscribe(ctx, clientID, query.Empty{}, ch) + require.NoError(t, err) + err = s.Unsubscribe(ctx, clientID, query.Empty{}) + require.NoError(t, err) + ch = make(chan interface{}) + err = s.Subscribe(ctx, clientID, query.Empty{}, ch) + require.NoError(t, err) + + err = s.Publish(ctx, "Cable") + require.NoError(t, err) + assertReceive(t, "Cable", ch) +} + func TestUnsubscribeAll(t *testing.T) { s := pubsub.NewServer() s.SetLogger(log.TestingLogger()) @@ -125,9 +143,9 @@ func TestUnsubscribeAll(t *testing.T) { ctx := context.Background() ch1, ch2 := make(chan interface{}, 1), make(chan interface{}, 1) - err := s.Subscribe(ctx, clientID, query.Empty{}, ch1) + err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'"), ch1) require.NoError(t, err) - err = s.Subscribe(ctx, clientID, query.Empty{}, ch2) + err = s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlockHeader'"), ch2) require.NoError(t, err) err = s.UnsubscribeAll(ctx, clientID) diff --git a/test.sh b/test.sh index 02bdaae86..b3978d3fe 100755 --- a/test.sh +++ b/test.sh @@ -2,14 +2,14 @@ set -e # run the linter -make metalinter_test +# make metalinter_test # run the unit tests with coverage echo "" > coverage.txt for d in $(go list ./... | grep -v vendor); do - go test -race -coverprofile=profile.out -covermode=atomic "$d" - if [ -f profile.out ]; then - cat profile.out >> coverage.txt - rm profile.out - fi + go test -race -coverprofile=profile.out -covermode=atomic "$d" + if [ -f profile.out ]; then + cat profile.out >> coverage.txt + rm profile.out + fi done diff --git a/version/version.go b/version/version.go index 45222da79..6cc887286 100644 --- a/version/version.go +++ b/version/version.go @@ -1,3 +1,3 @@ package version -const Version = "0.5.0" +const Version = "0.6.0"