From 54cc5100f86105d11a8850d178368f183206392e Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Wed, 30 Jan 2019 20:32:53 +0400 Subject: [PATCH] nope --- consensus/byzantine_test.go | 3 +- consensus/common_test.go | 144 ++++++++++++++----------------- consensus/mempool_test.go | 6 +- consensus/reactor_test.go | 46 ++++------ consensus/replay.go | 7 +- consensus/replay_file.go | 3 +- libs/pubsub/pubsub.go | 12 ++- libs/pubsub/pubsub_test.go | 2 +- libs/pubsub/subscription.go | 14 ++- node/node.go | 6 +- node/node_test.go | 7 +- rpc/client/helpers.go | 22 ++--- rpc/client/localclient.go | 4 +- rpc/core/events.go | 20 ++--- rpc/core/mempool.go | 4 +- rpc/lib/server/handlers.go | 23 ++--- rpc/lib/types/types.go | 11 --- state/execution_test.go | 8 +- state/txindex/indexer_service.go | 8 +- types/event_bus.go | 10 ++- types/event_bus_test.go | 22 ++--- 21 files changed, 170 insertions(+), 212 deletions(-) diff --git a/consensus/byzantine_test.go b/consensus/byzantine_test.go index 5ca052d09..862e24ded 100644 --- a/consensus/byzantine_test.go +++ b/consensus/byzantine_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" cmn "github.com/tendermint/tendermint/libs/common" - tmpubsub "github.com/tendermint/tendermint/libs/pubsub" "github.com/tendermint/tendermint/p2p" "github.com/tendermint/tendermint/types" ) @@ -50,7 +49,7 @@ func TestByzantine(t *testing.T) { switches[i].SetLogger(p2pLogger.With("validator", i)) } - eventSubs := make([]*tmpubsub.Subscription, N) + eventSubs := make([]types.Subscription, N) reactors := make([]p2p.Reactor, N) for i := 0; i < N; i++ { // make first val byzantine diff --git a/consensus/common_test.go b/consensus/common_test.go index 49834d6c4..cecc01509 100644 --- a/consensus/common_test.go +++ b/consensus/common_test.go @@ -7,7 +7,6 @@ import ( "io/ioutil" "os" "path/filepath" - "reflect" "sort" "sync" "testing" @@ -220,22 +219,22 @@ func validatePrevoteAndPrecommit(t *testing.T, cs *ConsensusState, thisRound, lo } // genesis -func subscribeToVoter(cs *ConsensusState, addr []byte) chan interface{} { +func subscribeToVoter(cs *ConsensusState, addr []byte) <-chan tmpubsub.MsgAndTags { voteCh0Sub, err := cs.eventBus.Subscribe(context.Background(), testSubscriber, types.EventQueryVote) if err != nil { panic(fmt.Sprintf("failed to subscribe %s to %v", testSubscriber, types.EventQueryVote)) } - voteCh := make(chan interface{}) + ch := make(chan tmpubsub.MsgAndTags) go func() { - for msgAndTags := range voteCh0Sub.Out() { - vote := msgAndTags.Msg.(types.EventDataVote) + for mt := range voteCh0Sub.Out() { + vote := mt.Msg().(types.EventDataVote) // we only fire for our own votes if bytes.Equal(addr, vote.Vote.ValidatorAddress) { - voteCh <- msgAndTags.Msg + ch <- mt } } }() - return voteCh + return ch } //------------------------------------------------------------------------------- @@ -350,29 +349,21 @@ func ensureNoNewTimeout(stepCh <-chan tmpubsub.MsgAndTags, timeout int64) { "We should be stuck waiting, not receiving NewTimeout event") } -func ensureNewEvent( - ch <-chan tmpubsub.MsgAndTags, - height int64, - round int, - timeout time.Duration, - errorMessage string) { - +func ensureNewEvent(ch <-chan tmpubsub.MsgAndTags, height int64, round int, timeout time.Duration, errorMessage string) { select { case <-time.After(timeout): panic(errorMessage) - case ev := <-ch: - rs, ok := ev.Msg.(types.EventDataRoundState) + case mt := <-ch: + roundStateEvent, ok := mt.Msg().(types.EventDataRoundState) if !ok { - panic( - fmt.Sprintf( - "expected a EventDataRoundState, got %v.Wrong subscription channel?", - reflect.TypeOf(rs))) + panic(fmt.Sprintf("expected a EventDataRoundState, got %T. Wrong subscription channel?", + mt.Msg())) } - if rs.Height != height { - panic(fmt.Sprintf("expected height %v, got %v", height, rs.Height)) + if roundStateEvent.Height != height { + panic(fmt.Sprintf("expected height %v, got %v", height, roundStateEvent.Height)) } - if rs.Round != round { - panic(fmt.Sprintf("expected round %v, got %v", round, rs.Round)) + if roundStateEvent.Round != round { + panic(fmt.Sprintf("expected round %v, got %v", round, roundStateEvent.Round)) } // TODO: We could check also for a step at this point! } @@ -382,19 +373,17 @@ func ensureNewRound(roundCh <-chan tmpubsub.MsgAndTags, height int64, round int) select { case <-time.After(ensureTimeout): panic("Timeout expired while waiting for NewRound event") - case ev := <-roundCh: - rs, ok := ev.Msg.(types.EventDataNewRound) + case mt := <-roundCh: + newRoundEvent, ok := mt.Msg().(types.EventDataNewRound) if !ok { - panic( - fmt.Sprintf( - "expected a EventDataNewRound, got %v.Wrong subscription channel?", - reflect.TypeOf(rs))) + panic(fmt.Sprintf("expected a EventDataNewRound, got %T. Wrong subscription channel?", + mt.Msg())) } - if rs.Height != height { - panic(fmt.Sprintf("expected height %v, got %v", height, rs.Height)) + if newRoundEvent.Height != height { + panic(fmt.Sprintf("expected height %v, got %v", height, newRoundEvent.Height)) } - if rs.Round != round { - panic(fmt.Sprintf("expected round %v, got %v", round, rs.Round)) + if newRoundEvent.Round != round { + panic(fmt.Sprintf("expected round %v, got %v", round, newRoundEvent.Round)) } } } @@ -409,19 +398,17 @@ func ensureNewProposal(proposalCh <-chan tmpubsub.MsgAndTags, height int64, roun select { case <-time.After(ensureTimeout): panic("Timeout expired while waiting for NewProposal event") - case ev := <-proposalCh: - rs, ok := ev.Msg.(types.EventDataCompleteProposal) + case mt := <-proposalCh: + proposalEvent, ok := mt.Msg().(types.EventDataCompleteProposal) if !ok { - panic( - fmt.Sprintf( - "expected a EventDataCompleteProposal, got %v.Wrong subscription channel?", - reflect.TypeOf(rs))) + panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. Wrong subscription channel?", + mt.Msg())) } - if rs.Height != height { - panic(fmt.Sprintf("expected height %v, got %v", height, rs.Height)) + if proposalEvent.Height != height { + panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height)) } - if rs.Round != round { - panic(fmt.Sprintf("expected round %v, got %v", round, rs.Round)) + if proposalEvent.Round != round { + panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round)) } } } @@ -435,15 +422,14 @@ func ensureNewBlock(blockCh <-chan tmpubsub.MsgAndTags, height int64) { select { case <-time.After(ensureTimeout): panic("Timeout expired while waiting for NewBlock event") - case ev := <-blockCh: - block, ok := ev.Msg.(types.EventDataNewBlock) + case mt := <-blockCh: + blockEvent, ok := mt.Msg().(types.EventDataNewBlock) if !ok { - panic(fmt.Sprintf("expected a *types.EventDataNewBlock, "+ - "got %v. wrong subscription channel?", - reflect.TypeOf(block))) + panic(fmt.Sprintf("expected a EventDataNewBlock, got %T. Wrong subscription channel?", + mt.Msg())) } - if block.Block.Height != height { - panic(fmt.Sprintf("expected height %v, got %v", height, block.Block.Height)) + if blockEvent.Block.Height != height { + panic(fmt.Sprintf("expected height %v, got %v", height, blockEvent.Block.Height)) } } } @@ -452,18 +438,17 @@ func ensureNewBlockHeader(blockCh <-chan tmpubsub.MsgAndTags, height int64, bloc select { case <-time.After(ensureTimeout): panic("Timeout expired while waiting for NewBlockHeader event") - case ev := <-blockCh: - blockHeader, ok := ev.Msg.(types.EventDataNewBlockHeader) + case mt := <-blockCh: + blockHeaderEvent, ok := mt.Msg().(types.EventDataNewBlockHeader) if !ok { - panic(fmt.Sprintf("expected a *types.EventDataNewBlockHeader, "+ - "got %v. wrong subscription channel?", - reflect.TypeOf(blockHeader))) + panic(fmt.Sprintf("expected a EventDataNewBlockHeader, got %T. Wrong subscription channel?", + mt.Msg())) } - if blockHeader.Header.Height != height { - panic(fmt.Sprintf("expected height %v, got %v", height, blockHeader.Header.Height)) + if blockHeaderEvent.Header.Height != height { + panic(fmt.Sprintf("expected height %v, got %v", height, blockHeaderEvent.Header.Height)) } - if !bytes.Equal(blockHeader.Header.Hash(), blockHash) { - panic(fmt.Sprintf("expected header %X, got %X", blockHash, blockHeader.Header.Hash())) + if !bytes.Equal(blockHeaderEvent.Header.Hash(), blockHash) { + panic(fmt.Sprintf("expected header %X, got %X", blockHash, blockHeaderEvent.Header.Hash())) } } } @@ -473,51 +458,48 @@ func ensureNewUnlock(unlockCh <-chan tmpubsub.MsgAndTags, height int64, round in "Timeout expired while waiting for NewUnlock event") } -func ensureProposal(proposalCh <-chan tmpubsub.MsgAndTags, height int64, round int, propId types.BlockID) { +func ensureProposal(proposalCh <-chan tmpubsub.MsgAndTags, height int64, round int, propID types.BlockID) { select { case <-time.After(ensureTimeout): panic("Timeout expired while waiting for NewProposal event") - case ev := <-proposalCh: - rs, ok := ev.Msg.(types.EventDataCompleteProposal) + case mt := <-proposalCh: + proposalEvent, ok := mt.Msg().(types.EventDataCompleteProposal) if !ok { - panic( - fmt.Sprintf( - "expected a EventDataCompleteProposal, got %v.Wrong subscription channel?", - reflect.TypeOf(rs))) + panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. Wrong subscription channel?", + mt.Msg())) } - if rs.Height != height { - panic(fmt.Sprintf("expected height %v, got %v", height, rs.Height)) + if proposalEvent.Height != height { + panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height)) } - if rs.Round != round { - panic(fmt.Sprintf("expected round %v, got %v", round, rs.Round)) + if proposalEvent.Round != round { + panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round)) } - if !rs.BlockID.Equals(propId) { + if !proposalEvent.BlockID.Equals(propID) { panic("Proposed block does not match expected block") } } } -func ensurePrecommit(voteCh <-chan interface{}, height int64, round int) { +func ensurePrecommit(voteCh <-chan tmpubsub.MsgAndTags, height int64, round int) { ensureVote(voteCh, height, round, types.PrecommitType) } -func ensurePrevote(voteCh <-chan interface{}, height int64, round int) { +func ensurePrevote(voteCh <-chan tmpubsub.MsgAndTags, height int64, round int) { ensureVote(voteCh, height, round, types.PrevoteType) } -func ensureVote(voteCh <-chan interface{}, height int64, round int, +func ensureVote(voteCh <-chan tmpubsub.MsgAndTags, height int64, round int, voteType types.SignedMsgType) { select { case <-time.After(ensureTimeout): panic("Timeout expired while waiting for NewVote event") - case v := <-voteCh: - edv, ok := v.(types.EventDataVote) + case mt := <-voteCh: + voteEvent, ok := mt.Msg().(types.EventDataVote) if !ok { - panic(fmt.Sprintf("expected a *types.Vote, "+ - "got %v. wrong subscription channel?", - reflect.TypeOf(v))) + panic(fmt.Sprintf("expected a EventDataVote, got %T. Wrong subscription channel?", + mt.Msg())) } - vote := edv.Vote + vote := voteEvent.Vote if vote.Height != height { panic(fmt.Sprintf("expected height %v, got %v", height, vote.Height)) } diff --git a/consensus/mempool_test.go b/consensus/mempool_test.go index c6dd3779d..31fd6c248 100644 --- a/consensus/mempool_test.go +++ b/consensus/mempool_test.go @@ -117,9 +117,9 @@ func TestMempoolTxConcurrentWithCommit(t *testing.T) { for nTxs := 0; nTxs < NTxs; { ticker := time.NewTicker(time.Second * 30) select { - case b := <-newBlockCh: - evt := b.Msg.(types.EventDataNewBlock) - nTxs += int(evt.Block.Header.NumTxs) + case mt := <-newBlockCh: + blockEvent := mt.Msg().(types.EventDataNewBlock) + nTxs += int(blockEvent.Block.Header.NumTxs) case <-ticker.C: panic("Timed out waiting to commit blocks with transactions") } diff --git a/consensus/reactor_test.go b/consensus/reactor_test.go index 9c6df4261..382ff0cc5 100644 --- a/consensus/reactor_test.go +++ b/consensus/reactor_test.go @@ -21,7 +21,6 @@ import ( cfg "github.com/tendermint/tendermint/config" dbm "github.com/tendermint/tendermint/libs/db" "github.com/tendermint/tendermint/libs/log" - tmpubsub "github.com/tendermint/tendermint/libs/pubsub" mempl "github.com/tendermint/tendermint/mempool" "github.com/tendermint/tendermint/p2p" sm "github.com/tendermint/tendermint/state" @@ -37,12 +36,11 @@ func init() { func startConsensusNet(t *testing.T, css []*ConsensusState, N int) ( []*ConsensusReactor, - []*tmpubsub.Subscription, + []types.Subscription, []*types.EventBus, ) { - var err error reactors := make([]*ConsensusReactor, N) - eventSubs := make([]*tmpubsub.Subscription, N) + blocksSubs := make([]types.Subscription, 0) eventBuses := make([]*types.EventBus, N) for i := 0; i < N; i++ { /*logger, err := tmflags.ParseLogLevel("consensus:info,*:error", logger, "info") @@ -54,8 +52,9 @@ func startConsensusNet(t *testing.T, css []*ConsensusState, N int) ( eventBuses[i] = css[i].eventBus reactors[i].SetEventBus(eventBuses[i]) - eventSubs[i], err = eventBuses[i].Subscribe(context.Background(), testSubscriber, types.EventQueryNewBlock) + blocksSub, err := eventBuses[i].Subscribe(context.Background(), testSubscriber, types.EventQueryNewBlock) require.NoError(t, err) + blocksSubs = append(blocksSubs, blocksSub) } // make connected switches and start all reactors p2p.MakeConnectedSwitches(config.P2P, N, func(i int, s *p2p.Switch) *p2p.Switch { @@ -72,7 +71,7 @@ func startConsensusNet(t *testing.T, css []*ConsensusState, N int) ( s := reactors[i].conS.GetState() reactors[i].SwitchToConsensus(s, 0) } - return reactors, eventSubs, eventBuses + return reactors, blocksSubs, eventBuses } func stopConsensusNet(logger log.Logger, reactors []*ConsensusReactor, eventBuses []*types.EventBus) { @@ -173,15 +172,15 @@ func TestReactorWithEvidence(t *testing.T) { // wait till everyone makes the first new block with no evidence timeoutWaitGroup(t, nValidators, func(j int) { - blockI := <-eventSubs[j].Out() - block := blockI.Msg.(types.EventDataNewBlock).Block + mt := <-eventSubs[j].Out() + block := mt.Msg().(types.EventDataNewBlock).Block assert.True(t, len(block.Evidence.Evidence) == 0) }, css) // second block should have evidence timeoutWaitGroup(t, nValidators, func(j int) { - blockI := <-eventSubs[j].Out() - block := blockI.Msg.(types.EventDataNewBlock).Block + mt := <-eventSubs[j].Out() + block := mt.Msg().(types.EventDataNewBlock).Block assert.True(t, len(block.Evidence.Evidence) > 0) }, css) } @@ -445,17 +444,14 @@ func waitForAndValidateBlock( t *testing.T, n int, activeVals map[string]struct{}, - eventSubs []*tmpubsub.Subscription, + eventSubs []types.Subscription, css []*ConsensusState, txs ...[]byte, ) { timeoutWaitGroup(t, n, func(j int) { css[j].Logger.Debug("waitForAndValidateBlock") - newBlockI, ok := <-eventSubs[j].Out() - if !ok { - return - } - newBlock := newBlockI.Msg.(types.EventDataNewBlock).Block + mt := <-eventSubs[j].Out() + newBlock := mt.Msg().(types.EventDataNewBlock).Block css[j].Logger.Debug("waitForAndValidateBlock: Got block", "height", newBlock.Height) err := validateBlock(newBlock, activeVals) assert.Nil(t, err) @@ -470,7 +466,7 @@ func waitForAndValidateBlockWithTx( t *testing.T, n int, activeVals map[string]struct{}, - eventSubs []*tmpubsub.Subscription, + eventSubs []types.Subscription, css []*ConsensusState, txs ...[]byte, ) { @@ -479,11 +475,8 @@ func waitForAndValidateBlockWithTx( BLOCK_TX_LOOP: for { css[j].Logger.Debug("waitForAndValidateBlockWithTx", "ntxs", ntxs) - newBlockI, ok := <-eventSubs[j].Out() - if !ok { - return - } - newBlock := newBlockI.Msg.(types.EventDataNewBlock).Block + mt := <-eventSubs[j].Out() + newBlock := mt.Msg().(types.EventDataNewBlock).Block css[j].Logger.Debug("waitForAndValidateBlockWithTx: Got block", "height", newBlock.Height) err := validateBlock(newBlock, activeVals) assert.Nil(t, err) @@ -508,7 +501,7 @@ func waitForBlockWithUpdatedValsAndValidateIt( t *testing.T, n int, updatedVals map[string]struct{}, - eventSubs []*tmpubsub.Subscription, + eventSubs []types.Subscription, css []*ConsensusState, ) { timeoutWaitGroup(t, n, func(j int) { @@ -517,11 +510,8 @@ func waitForBlockWithUpdatedValsAndValidateIt( LOOP: for { css[j].Logger.Debug("waitForBlockWithUpdatedValsAndValidateIt") - newBlockI, ok := <-eventSubs[j].Out() - if !ok { - return - } - newBlock = newBlockI.Msg.(types.EventDataNewBlock).Block + mt := <-eventSubs[j].Out() + newBlock = mt.Msg().(types.EventDataNewBlock).Block if newBlock.LastCommit.Size() == len(updatedVals) { css[j].Logger.Debug("waitForBlockWithUpdatedValsAndValidateIt: Got block", "height", newBlock.Height) break LOOP diff --git a/consensus/replay.go b/consensus/replay.go index bdf9fc931..f3f50cec4 100644 --- a/consensus/replay.go +++ b/consensus/replay.go @@ -17,7 +17,6 @@ import ( dbm "github.com/tendermint/tendermint/libs/db" "github.com/tendermint/tendermint/libs/log" - tmpubsub "github.com/tendermint/tendermint/libs/pubsub" "github.com/tendermint/tendermint/proxy" sm "github.com/tendermint/tendermint/state" "github.com/tendermint/tendermint/types" @@ -43,7 +42,7 @@ var crc32c = crc32.MakeTable(crc32.Castagnoli) // Unmarshal and apply a single message to the consensus state as if it were // received in receiveRoutine. Lines that start with "#" are ignored. // NOTE: receiveRoutine should not be running. -func (cs *ConsensusState) readReplayMessage(msg *TimedWALMessage, newStepSub *tmpubsub.Subscription) error { +func (cs *ConsensusState) readReplayMessage(msg *TimedWALMessage, newStepSub types.Subscription) error { // Skip meta messages which exist for demarcating boundaries. if _, ok := msg.Msg.(EndHeightMessage); ok { return nil @@ -57,8 +56,8 @@ func (cs *ConsensusState) readReplayMessage(msg *TimedWALMessage, newStepSub *tm ticker := time.After(time.Second * 2) if newStepSub != nil { select { - case mi := <-newStepSub.Out(): - m2 := mi.Msg.(types.EventDataRoundState) + case mt := <-newStepSub.Out(): + m2 := mt.Msg().(types.EventDataRoundState) if m.Height != m2.Height || m.Round != m2.Round || m.Step != m2.Step { return fmt.Errorf("RoundState mismatch. Got %v; Expected %v", m2, m) } diff --git a/consensus/replay_file.go b/consensus/replay_file.go index fa1e472d6..cd1230485 100644 --- a/consensus/replay_file.go +++ b/consensus/replay_file.go @@ -16,7 +16,6 @@ import ( cmn "github.com/tendermint/tendermint/libs/common" dbm "github.com/tendermint/tendermint/libs/db" "github.com/tendermint/tendermint/libs/log" - tmpubsub "github.com/tendermint/tendermint/libs/pubsub" "github.com/tendermint/tendermint/proxy" sm "github.com/tendermint/tendermint/state" "github.com/tendermint/tendermint/types" @@ -122,7 +121,7 @@ func newPlayback(fileName string, fp *os.File, cs *ConsensusState, genState sm.S } // go back count steps by resetting the state and running (pb.count - count) steps -func (pb *playback) replayReset(count int, newStepSub *tmpubsub.Subscription) error { +func (pb *playback) replayReset(count int, newStepSub types.Subscription) error { pb.cs.Stop() pb.cs.Wait() diff --git a/libs/pubsub/pubsub.go b/libs/pubsub/pubsub.go index 774b11f23..f4fe7eb7f 100644 --- a/libs/pubsub/pubsub.go +++ b/libs/pubsub/pubsub.go @@ -131,10 +131,14 @@ func (s *Server) BufferCapacity() int { return s.cmdsCap } -// Subscribe creates a subscription for the given client. An error will be -// returned to the caller if the context is canceled or if subscription already -// exist for pair clientID and query. outCapacity can be used to set a -// capacity for Subscription#Out channel (1 by default). +// Subscribe creates a subscription for the given client. +// +// An error will be returned to the caller if the context is canceled or if +// subscription already exist for pair clientID and query. +// +// outCapacity can be used to set a capacity for Subscription#Out channel (1 by +// default). Panics if outCapacity is less than or equal to zero. If you want +// an unbuffered channel, use SubscribeUnbuffered. func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, outCapacity ...int) (*Subscription, error) { outCap := 1 if len(outCapacity) > 0 { diff --git a/libs/pubsub/pubsub_test.go b/libs/pubsub/pubsub_test.go index 7963be509..63a9c6e3e 100644 --- a/libs/pubsub/pubsub_test.go +++ b/libs/pubsub/pubsub_test.go @@ -306,7 +306,7 @@ func benchmarkNClientsOneQuery(n int, b *testing.B) { func assertReceive(t *testing.T, expected interface{}, ch <-chan pubsub.MsgAndTags, msgAndArgs ...interface{}) { select { case actual := <-ch: - assert.Equal(t, expected, actual.Msg, msgAndArgs...) + assert.Equal(t, expected, actual.Msg(), msgAndArgs...) case <-time.After(1 * time.Second): t.Errorf("Expected to receive %v from the channel, got nothing after 1s", expected) debug.PrintStack() diff --git a/libs/pubsub/subscription.go b/libs/pubsub/subscription.go index 3bdfb0a48..db4f00d56 100644 --- a/libs/pubsub/subscription.go +++ b/libs/pubsub/subscription.go @@ -55,6 +55,16 @@ func (s *Subscription) Err() error { // MsgAndTags glues a message and tags together. type MsgAndTags struct { - Msg interface{} - Tags TagMap + msg interface{} + tags TagMap +} + +// Msg returns a message. +func (mt MsgAndTags) Msg() interface{} { + return mt.msg +} + +// Tags returns tags. +func (mt MsgAndTags) Tags() TagMap { + return mt.tags } diff --git a/node/node.go b/node/node.go index 1b7319811..53cfc2780 100644 --- a/node/node.go +++ b/node/node.go @@ -676,7 +676,11 @@ func (n *Node) startRPC() ([]net.Listener, error) { for i, listenAddr := range listenAddrs { mux := http.NewServeMux() rpcLogger := n.Logger.With("module", "rpc-server") - wm := rpcserver.NewWebsocketManager(rpccore.Routes, coreCodec, rpcserver.EventSubscriber(n.eventBus)) + wm := rpcserver.NewWebsocketManager(rpccore.Routes, coreCodec, rpcserver.DisconnectCallback(func(remoteAddr string) { + // Unsubscribe a client upon disconnect since it won't be able to do it + // itself. + n.eventBus.UnsubscribeAll(context.TODO(), remoteAddr) + })) wm.SetLogger(rpcLogger.With("protocol", "websocket")) mux.HandleFunc("/websocket", wm.WebsocketHandler) rpcserver.RegisterRPCFuncs(mux, rpccore.Routes, coreCodec, rpcLogger) diff --git a/node/node_test.go b/node/node_test.go index 3218c8327..00b7ea244 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -41,11 +41,12 @@ func TestNodeStartStop(t *testing.T) { t.Logf("Started node %v", n.sw.NodeInfo()) // wait for the node to produce a block - blockCh := make(chan interface{}) - err = n.EventBus().Subscribe(context.Background(), "node_test", types.EventQueryNewBlock, blockCh) + blocksSub, err := n.EventBus().Subscribe(context.Background(), "node_test", types.EventQueryNewBlock) require.NoError(t, err) select { - case <-blockCh: + case <-blocksSub.Out(): + case <-blocksSub.Cancelled(): + t.Fatal("blocksSub was cancelled") case <-time.After(10 * time.Second): t.Fatal("timed out waiting for the node to produce a block") } diff --git a/rpc/client/helpers.go b/rpc/client/helpers.go index 2e80a3063..404c5d7b5 100644 --- a/rpc/client/helpers.go +++ b/rpc/client/helpers.go @@ -62,29 +62,19 @@ func WaitForOneEvent(c EventsClient, evtTyp string, timeout time.Duration) (type evts := make(chan interface{}, 1) // register for the next event of this type - query := types.QueryForEvent(evtTyp) - err := c.Subscribe(ctx, subscriber, query, evts) + sub, err := c.Subscribe(ctx, subscriber, types.QueryForEvent(evtTyp)) if err != nil { return nil, errors.Wrap(err, "failed to subscribe") } // make sure to unregister after the test is over - defer func() { - // drain evts to make sure we don't block - LOOP: - for { - select { - case <-evts: - default: - break LOOP - } - } - c.UnsubscribeAll(ctx, subscriber) - }() + defer c.UnsubscribeAll(ctx, subscriber) select { - case evt := <-evts: - return evt.(types.TMEventData), nil + case mt := <-sub.Out(): + return mt.Msg().(types.TMEventData), nil + case <-sub.Cancelled(): + return nil, errors.New("subscription was cancelled") case <-ctx.Done(): return nil, errors.New("timed out waiting for event") } diff --git a/rpc/client/localclient.go b/rpc/client/localclient.go index ba8fb3f17..33a1ce225 100644 --- a/rpc/client/localclient.go +++ b/rpc/client/localclient.go @@ -140,8 +140,8 @@ func (Local) TxSearch(query string, prove bool, page, perPage int) (*ctypes.Resu return core.TxSearch(query, prove, page, perPage) } -func (c *Local) Subscribe(ctx context.Context, subscriber string, query tmpubsub.Query, out chan<- interface{}) error { - return c.EventBus.Subscribe(ctx, subscriber, query, out) +func (c *Local) Subscribe(ctx context.Context, subscriber string, query tmpubsub.Query, outCapacity ...int) (types.Subscription, error) { + return c.EventBus.Subscribe(ctx, subscriber, query, outCapacity...) } func (c *Local) Unsubscribe(ctx context.Context, subscriber string, query tmpubsub.Query) error { diff --git a/rpc/core/events.go b/rpc/core/events.go index 36536bc83..92a8a7799 100644 --- a/rpc/core/events.go +++ b/rpc/core/events.go @@ -101,7 +101,7 @@ func Subscribe(wsCtx rpctypes.WSRPCContext, query string) (*ctypes.ResultSubscri ctx, cancel := context.WithTimeout(context.Background(), subscribeTimeout) defer cancel() - sub, err := eventBusFor(wsCtx).Subscribe(ctx, addr, q) + sub, err := eventBus.Subscribe(ctx, addr, q) if err != nil { return nil, err } @@ -109,13 +109,13 @@ func Subscribe(wsCtx rpctypes.WSRPCContext, query string) (*ctypes.ResultSubscri go func() { for { select { - case event := <-sub.Out(): - tmResult := &ctypes.ResultEvent{query, event.Msg.(tmtypes.TMEventData)} + case mt := <-sub.Out(): + resultEvent := &ctypes.ResultEvent{query, mt.Msg().(tmtypes.TMEventData)} wsCtx.TryWriteRPCResponse( rpctypes.NewRPCSuccessResponse( wsCtx.Codec(), rpctypes.JSONRPCStringID(fmt.Sprintf("%v#event", wsCtx.Request.ID)), - tmResult, + resultEvent, )) case <-sub.Cancelled(): wsCtx.TryWriteRPCResponse( @@ -168,7 +168,7 @@ func Unsubscribe(wsCtx rpctypes.WSRPCContext, query string) (*ctypes.ResultUnsub if err != nil { return nil, errors.Wrap(err, "failed to parse query") } - err = eventBusFor(wsCtx).Unsubscribe(context.Background(), addr, q) + err = eventBus.Unsubscribe(context.Background(), addr, q) if err != nil { return nil, err } @@ -202,17 +202,9 @@ func Unsubscribe(wsCtx rpctypes.WSRPCContext, query string) (*ctypes.ResultUnsub func UnsubscribeAll(wsCtx rpctypes.WSRPCContext) (*ctypes.ResultUnsubscribe, error) { addr := wsCtx.GetRemoteAddr() logger.Info("Unsubscribe from all", "remote", addr) - err := eventBusFor(wsCtx).UnsubscribeAll(context.Background(), addr) + err := eventBus.UnsubscribeAll(context.Background(), addr) if err != nil { return nil, err } return &ctypes.ResultUnsubscribe{}, nil } - -func eventBusFor(wsCtx rpctypes.WSRPCContext) tmtypes.EventBusSubscriber { - es := wsCtx.GetEventSubscriber() - if es == nil { - es = eventBus - } - return es -} diff --git a/rpc/core/mempool.go b/rpc/core/mempool.go index 16d03ae0c..6b55c46e7 100644 --- a/rpc/core/mempool.go +++ b/rpc/core/mempool.go @@ -201,8 +201,8 @@ func BroadcastTxCommit(tx types.Tx) (*ctypes.ResultBroadcastTxCommit, error) { // TODO: configurable? var deliverTxTimeout = rpcserver.WriteTimeout / 2 select { - case deliverTxResMsg := <-deliverTxSub.Out(): // The tx was included in a block. - deliverTxRes := deliverTxResMsg.Msg.(types.EventDataTx) + case mt := <-deliverTxSub.Out(): // The tx was included in a block. + deliverTxRes := mt.Msg().(types.EventDataTx) return &ctypes.ResultBroadcastTxCommit{ CheckTx: *checkTxRes, DeliverTx: deliverTxRes.Result, diff --git a/rpc/lib/server/handlers.go b/rpc/lib/server/handlers.go index edab88fe5..b968f43c3 100644 --- a/rpc/lib/server/handlers.go +++ b/rpc/lib/server/handlers.go @@ -2,7 +2,6 @@ package rpcserver import ( "bytes" - "context" "encoding/hex" "encoding/json" "fmt" @@ -434,8 +433,8 @@ type wsConnection struct { // Send pings to server with this period. Must be less than readWait, but greater than zero. pingPeriod time.Duration - // object that is used to subscribe / unsubscribe from events - eventSub types.EventSubscriber + // see DisconnectCallback option. + disconnectCallback func(remoteAddr string) } // NewWSConnection wraps websocket.Conn. @@ -468,12 +467,11 @@ func NewWSConnection( return wsc } -// EventSubscriber sets object that is used to subscribe / unsubscribe from -// events - not Goroutine-safe. If none given, default node's eventBus will be -// used. -func EventSubscriber(eventSub types.EventSubscriber) func(*wsConnection) { +// DisconnectCallback can be used optionally to set a callback, which will be +// called upon disconnect - not Goroutine-safe. +func DisconnectCallback(cb func(remoteAddr string)) func(*wsConnection) { return func(wsc *wsConnection) { - wsc.eventSub = eventSub + wsc.disconnectCallback = cb } } @@ -526,8 +524,8 @@ func (wsc *wsConnection) OnStart() error { func (wsc *wsConnection) OnStop() { // Both read and write loops close the websocket connection when they exit their loops. // The writeChan is never closed, to allow WriteRPCResponse() to fail. - if wsc.eventSub != nil { - wsc.eventSub.UnsubscribeAll(context.TODO(), wsc.remoteAddr) + if wsc.disconnectCallback != nil { + wsc.disconnectCallback(wsc.remoteAddr) } } @@ -537,11 +535,6 @@ func (wsc *wsConnection) GetRemoteAddr() string { return wsc.remoteAddr } -// GetEventSubscriber implements WSRPCConnection by returning event subscriber. -func (wsc *wsConnection) GetEventSubscriber() types.EventSubscriber { - return wsc.eventSub -} - // WriteRPCResponse pushes a response to the writeChan, and blocks until it is accepted. // It implements WSRPCConnection. It is Goroutine-safe. func (wsc *wsConnection) WriteRPCResponse(resp types.RPCResponse) { diff --git a/rpc/lib/types/types.go b/rpc/lib/types/types.go index e0753a03b..ceb7be83a 100644 --- a/rpc/lib/types/types.go +++ b/rpc/lib/types/types.go @@ -1,7 +1,6 @@ package rpctypes import ( - "context" "encoding/json" "fmt" "reflect" @@ -10,8 +9,6 @@ import ( "github.com/pkg/errors" amino "github.com/tendermint/go-amino" - - tmpubsub "github.com/tendermint/tendermint/libs/pubsub" ) // a wrapper to emulate a sum type: jsonrpcid = string | int @@ -240,17 +237,9 @@ type WSRPCConnection interface { GetRemoteAddr() string WriteRPCResponse(resp RPCResponse) TryWriteRPCResponse(resp RPCResponse) bool - GetEventSubscriber() EventSubscriber Codec() *amino.Codec } -// EventSubscriber mirros tendermint/tendermint/types.EventBusSubscriber -type EventSubscriber interface { - Subscribe(ctx context.Context, subscriber string, query tmpubsub.Query, out chan<- interface{}) error - Unsubscribe(ctx context.Context, subscriber string, query tmpubsub.Query) error - UnsubscribeAll(ctx context.Context, subscriber string) error -} - // websocket-only RPCFuncs take this as the first parameter. type WSRPCContext struct { Request RPCRequest diff --git a/state/execution_test.go b/state/execution_test.go index 1f1ee5e54..7faca096e 100644 --- a/state/execution_test.go +++ b/state/execution_test.go @@ -341,15 +341,15 @@ func TestEndBlockValidatorUpdates(t *testing.T) { // test we threw an event select { - case e := <-updatesSub.Out(): - event, ok := e.Msg.(types.EventDataValidatorSetUpdates) - require.True(t, ok, "Expected event of type EventDataValidatorSetUpdates, got %T", e) + case mt := <-updatesSub.Out(): + event, ok := mt.Msg().(types.EventDataValidatorSetUpdates) + require.True(t, ok, "Expected event of type EventDataValidatorSetUpdates, got %T", mt.Msg()) if assert.NotEmpty(t, event.ValidatorUpdates) { assert.Equal(t, pubkey, event.ValidatorUpdates[0].PubKey) assert.EqualValues(t, 10, event.ValidatorUpdates[0].VotingPower) } case <-updatesSub.Cancelled(): - t.Fatal("updatesSub was cancelled.") + t.Fatal(fmt.Sprintf("updatesSub was cancelled (reason: %v)", updatesSub.Err())) case <-time.After(1 * time.Second): t.Fatal("Did not receive EventValidatorSetUpdates within 1 sec.") } diff --git a/state/txindex/indexer_service.go b/state/txindex/indexer_service.go index c48f01bae..468753968 100644 --- a/state/txindex/indexer_service.go +++ b/state/txindex/indexer_service.go @@ -44,13 +44,13 @@ func (is *IndexerService) OnStart() error { go func() { for { select { - case msgAndTags := <-blockHeadersSub.Out(): - header := msgAndTags.Msg.(types.EventDataNewBlockHeader).Header + case mt := <-blockHeadersSub.Out(): + header := mt.Msg().(types.EventDataNewBlockHeader).Header batch := NewBatch(header.NumTxs) for i := int64(0); i < header.NumTxs; i++ { select { - case msgAndTags := <-txsSub.Out(): - txResult := msgAndTags.Msg.(types.EventDataTx).TxResult + case mt2 := <-txsSub.Out(): + txResult := mt2.Msg().(types.EventDataTx).TxResult batch.Add(&txResult) case <-txsSub.Cancelled(): is.Logger.Error("Failed to index a block. txsSub was cancelled. Did the Tendermint stop?", diff --git a/types/event_bus.go b/types/event_bus.go index 6b4e815ea..67a91e159 100644 --- a/types/event_bus.go +++ b/types/event_bus.go @@ -12,11 +12,17 @@ import ( const defaultCapacity = 0 type EventBusSubscriber interface { - Subscribe(ctx context.Context, subscriber string, query tmpubsub.Query, outCapacity ...int) (*tmpubsub.Subscription, error) + Subscribe(ctx context.Context, subscriber string, query tmpubsub.Query, outCapacity ...int) (Subscription, error) Unsubscribe(ctx context.Context, subscriber string, query tmpubsub.Query) error UnsubscribeAll(ctx context.Context, subscriber string) error } +type Subscription interface { + Out() <-chan tmpubsub.MsgAndTags + Cancelled() <-chan struct{} + Err() error +} + // EventBus is a common bus for all events going through the system. All calls // are proxied to underlying pubsub server. All events must be published using // EventBus to ensure correct data types. @@ -52,7 +58,7 @@ func (b *EventBus) OnStop() { b.pubsub.Stop() } -func (b *EventBus) Subscribe(ctx context.Context, subscriber string, query tmpubsub.Query, outCapacity ...int) (*tmpubsub.Subscription, error) { +func (b *EventBus) Subscribe(ctx context.Context, subscriber string, query tmpubsub.Query, outCapacity ...int) (Subscription, error) { return b.pubsub.Subscribe(ctx, subscriber, query, outCapacity...) } diff --git a/types/event_bus_test.go b/types/event_bus_test.go index 85f06237f..180301210 100644 --- a/types/event_bus_test.go +++ b/types/event_bus_test.go @@ -26,13 +26,13 @@ func TestEventBusPublishEventTx(t *testing.T) { // PublishEventTx adds all these 3 tags, so the query below should work query := fmt.Sprintf("tm.event='Tx' AND tx.height=1 AND tx.hash='%X' AND baz=1", tx.Hash()) - txEventsSub, err := eventBus.Subscribe(context.Background(), "test", tmquery.MustParse(query)) + txsSub, err := eventBus.Subscribe(context.Background(), "test", tmquery.MustParse(query)) require.NoError(t, err) done := make(chan struct{}) go func() { - e := <-txEventsSub.Out() - edt := e.Msg.(EventDataTx) + mt := <-txsSub.Out() + edt := mt.Msg().(EventDataTx) assert.Equal(t, int64(1), edt.Height) assert.Equal(t, uint32(0), edt.Index) assert.Equal(t, tx, edt.Tx) @@ -67,13 +67,13 @@ func TestEventBusPublishEventNewBlock(t *testing.T) { // PublishEventNewBlock adds the tm.event tag, so the query below should work query := "tm.event='NewBlock' AND baz=1 AND foz=2" - txEventsSub, err := eventBus.Subscribe(context.Background(), "test", tmquery.MustParse(query)) + blocksSub, err := eventBus.Subscribe(context.Background(), "test", tmquery.MustParse(query)) require.NoError(t, err) done := make(chan struct{}) go func() { - e := <-txEventsSub.Out() - edt := e.Msg.(EventDataNewBlock) + mt := <-blocksSub.Out() + edt := mt.Msg().(EventDataNewBlock) assert.Equal(t, block, edt.Block) assert.Equal(t, resultBeginBlock, edt.ResultBeginBlock) assert.Equal(t, resultEndBlock, edt.ResultEndBlock) @@ -106,13 +106,13 @@ func TestEventBusPublishEventNewBlockHeader(t *testing.T) { // PublishEventNewBlockHeader adds the tm.event tag, so the query below should work query := "tm.event='NewBlockHeader' AND baz=1 AND foz=2" - txEventsSub, err := eventBus.Subscribe(context.Background(), "test", tmquery.MustParse(query)) + headersSub, err := eventBus.Subscribe(context.Background(), "test", tmquery.MustParse(query)) require.NoError(t, err) done := make(chan struct{}) go func() { - e := <-txEventsSub.Out() - edt := e.Msg.(EventDataNewBlockHeader) + mt := <-headersSub.Out() + edt := mt.Msg().(EventDataNewBlockHeader) assert.Equal(t, block.Header, edt.Header) assert.Equal(t, resultBeginBlock, edt.ResultBeginBlock) assert.Equal(t, resultEndBlock, edt.ResultEndBlock) @@ -139,14 +139,14 @@ func TestEventBusPublish(t *testing.T) { require.NoError(t, err) defer eventBus.Stop() - eventsSub, err := eventBus.Subscribe(context.Background(), "test", tmquery.Empty{}) + sub, err := eventBus.Subscribe(context.Background(), "test", tmquery.Empty{}) require.NoError(t, err) const numEventsExpected = 14 done := make(chan struct{}) go func() { numEvents := 0 - for range eventsSub.Out() { + for range sub.Out() { numEvents++ if numEvents >= numEventsExpected { close(done)