diff --git a/abci/client/grpc_client.go b/abci/client/grpc_client.go index ef88736ab..0a4bdab34 100644 --- a/abci/client/grpc_client.go +++ b/abci/client/grpc_client.go @@ -85,12 +85,19 @@ func (cli *grpcClient) OnStart(ctx context.Context) error { cb(reqres.Response) } } - for reqres := range cli.chReqRes { - if reqres != nil { - callCb(reqres) - } else { - cli.Logger.Error("Received nil reqres") + + for { + select { + case reqres := <-cli.chReqRes: + if reqres != nil { + callCb(reqres) + } else { + cli.Logger.Error("Received nil reqres") + } + case <-ctx.Done(): + return } + } }() diff --git a/abci/server/grpc_server.go b/abci/server/grpc_server.go index 78da22cdb..b4e314182 100644 --- a/abci/server/grpc_server.go +++ b/abci/server/grpc_server.go @@ -50,6 +50,11 @@ func (s *GRPCServer) OnStart(ctx context.Context) error { s.Logger.Info("Listening", "proto", s.proto, "addr", s.addr) go func() { + go func() { + <-ctx.Done() + s.server.GracefulStop() + }() + if err := s.server.Serve(s.listener); err != nil { s.Logger.Error("Error serving gRPC server", "err", err) } diff --git a/cmd/priv_val_server/main.go b/cmd/priv_val_server/main.go index cda123d7f..8bd67fbda 100644 --- a/cmd/priv_val_server/main.go +++ b/cmd/priv_val_server/main.go @@ -50,6 +50,9 @@ func main() { ) flag.Parse() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger.Info( "Starting private validator", "addr", *addr, @@ -131,7 +134,7 @@ func main() { } // Stop upon receiving SIGTERM or CTRL-C. - tmos.TrapSignal(logger, func() { + tmos.TrapSignal(ctx, logger, func() { logger.Debug("SignerServer: calling Close") if *prometheusAddr != "" { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) diff --git a/cmd/tendermint/commands/inspect.go b/cmd/tendermint/commands/inspect.go index 3cd6ef572..bb6c5c2f1 100644 --- a/cmd/tendermint/commands/inspect.go +++ b/cmd/tendermint/commands/inspect.go @@ -1,8 +1,6 @@ package commands import ( - "context" - "os" "os/signal" "syscall" @@ -40,16 +38,9 @@ func init() { } func runInspect(cmd *cobra.Command, args []string) error { - ctx, cancel := context.WithCancel(cmd.Context()) + ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGTERM, syscall.SIGINT) defer cancel() - c := make(chan os.Signal, 1) - signal.Notify(c, syscall.SIGTERM, syscall.SIGINT) - go func() { - <-c - cancel() - }() - ins, err := inspect.NewFromConfig(logger, config) if err != nil { return err diff --git a/cmd/tendermint/commands/light.go b/cmd/tendermint/commands/light.go index 0e1894ccf..dc466ea0e 100644 --- a/cmd/tendermint/commands/light.go +++ b/cmd/tendermint/commands/light.go @@ -189,7 +189,7 @@ func runProxy(cmd *cobra.Command, args []string) error { } // Stop upon receiving SIGTERM or CTRL-C. - tmos.TrapSignal(logger, func() { + tmos.TrapSignal(cmd.Context(), logger, func() { p.Listener.Close() }) diff --git a/internal/blocksync/reactor.go b/internal/blocksync/reactor.go index 5fe8b2123..994c01991 100644 --- a/internal/blocksync/reactor.go +++ b/internal/blocksync/reactor.go @@ -161,11 +161,11 @@ func (r *Reactor) OnStart(ctx context.Context) error { go r.requestRoutine(ctx) r.poolWG.Add(1) - go r.poolRoutine(false) + go r.poolRoutine(ctx, false) } - go r.processBlockSyncCh() - go r.processPeerUpdates() + go r.processBlockSyncCh(ctx) + go r.processPeerUpdates(ctx) return nil } @@ -186,10 +186,6 @@ func (r *Reactor) OnStop() { // p2p Channels should execute Close(). close(r.closeCh) - // Wait for all p2p Channels to be closed before returning. 
This ensures we - // can easily reason about synchronization of all p2p Channels and ensure no - // panics will occur. - <-r.blockSyncCh.Done() <-r.peerUpdates.Done() } @@ -293,11 +289,11 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err // message execution will result in a PeerError being sent on the BlockSyncChannel. // When the reactor is stopped, we will catch the signal and close the p2p Channel // gracefully. -func (r *Reactor) processBlockSyncCh() { - defer r.blockSyncCh.Close() - +func (r *Reactor) processBlockSyncCh(ctx context.Context) { for { select { + case <-ctx.Done(): + return case envelope := <-r.blockSyncCh.In: if err := r.handleMessage(r.blockSyncCh.ID, envelope); err != nil { r.Logger.Error("failed to process message", "ch_id", r.blockSyncCh.ID, "envelope", envelope, "err", err) @@ -346,11 +342,13 @@ func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { // processPeerUpdates initiates a blocking process where we listen for and handle // PeerUpdate messages. When the reactor is stopped, we will catch the signal and // close the p2p PeerUpdatesCh gracefully. -func (r *Reactor) processPeerUpdates() { +func (r *Reactor) processPeerUpdates(ctx context.Context) { defer r.peerUpdates.Close() for { select { + case <-ctx.Done(): + return case peerUpdate := <-r.peerUpdates.Updates(): r.processPeerUpdate(peerUpdate) @@ -378,7 +376,7 @@ func (r *Reactor) SwitchToBlockSync(ctx context.Context, state sm.State) error { go r.requestRoutine(ctx) r.poolWG.Add(1) - go r.poolRoutine(true) + go r.poolRoutine(ctx, true) return nil } @@ -415,31 +413,23 @@ func (r *Reactor) requestRoutine(ctx context.Context) { go func() { defer r.poolWG.Done() - r.blockSyncOutBridgeCh <- p2p.Envelope{ + select { + case r.blockSyncOutBridgeCh <- p2p.Envelope{ Broadcast: true, Message: &bcproto.StatusRequest{}, + }: + case <-ctx.Done(): } }() } } } -func (r *Reactor) stopCtx() context.Context { - ctx, cancel := context.WithCancel(context.Background()) - - go func() { - <-r.closeCh - cancel() - }() - - return ctx -} - // poolRoutine handles messages from the poolReactor telling the reactor what to // do. // // NOTE: Don't sleep in the FOR_LOOP or otherwise slow it down! 
-func (r *Reactor) poolRoutine(stateSynced bool) { +func (r *Reactor) poolRoutine(ctx context.Context, stateSynced bool) { var ( trySyncTicker = time.NewTicker(trySyncIntervalMS * time.Millisecond) switchToConsensusTicker = time.NewTicker(switchToConsensusIntervalSeconds * time.Second) @@ -453,7 +443,6 @@ func (r *Reactor) poolRoutine(stateSynced bool) { lastRate = 0.0 didProcessCh = make(chan struct{}, 1) - ctx = r.stopCtx() ) defer trySyncTicker.Stop() @@ -605,6 +594,8 @@ FOR_LOOP: continue FOR_LOOP + case <-ctx.Done(): + break FOR_LOOP case <-r.closeCh: break FOR_LOOP case <-r.pool.exitedCh: diff --git a/internal/blocksync/reactor_test.go b/internal/blocksync/reactor_test.go index c5f76066a..5345cb5c4 100644 --- a/internal/blocksync/reactor_test.go +++ b/internal/blocksync/reactor_test.go @@ -71,7 +71,7 @@ func setup( } chDesc := &p2p.ChannelDescriptor{ID: BlockSyncChannel, MessageType: new(bcproto.Message)} - rts.blockSyncChannels = rts.network.MakeChannelsNoCleanup(t, chDesc) + rts.blockSyncChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc) i := 0 for nodeID := range rts.network.Nodes { @@ -166,7 +166,7 @@ func (rts *reactorTestSuite) addNode( rts.peerChans[nodeID] = make(chan p2p.PeerUpdate) rts.peerUpdates[nodeID] = p2p.NewPeerUpdates(rts.peerChans[nodeID], 1) - rts.network.Nodes[nodeID].PeerManager.Register(rts.peerUpdates[nodeID]) + rts.network.Nodes[nodeID].PeerManager.Register(ctx, rts.peerUpdates[nodeID]) rts.reactors[nodeID], err = NewReactor( rts.logger.With("nodeID", nodeID), state.Copy(), @@ -183,9 +183,9 @@ func (rts *reactorTestSuite) addNode( require.True(t, rts.reactors[nodeID].IsRunning()) } -func (rts *reactorTestSuite) start(t *testing.T) { +func (rts *reactorTestSuite) start(ctx context.Context, t *testing.T) { t.Helper() - rts.network.Start(t) + rts.network.Start(ctx, t) require.Len(t, rts.network.RandomNode().PeerManager.Peers(), len(rts.nodes)-1, @@ -207,7 +207,7 @@ func TestReactor_AbruptDisconnect(t *testing.T) { require.Equal(t, maxBlockHeight, rts.reactors[rts.nodes[0]].store.Height()) - rts.start(t) + rts.start(ctx, t) secondaryPool := rts.reactors[rts.nodes[1]].pool @@ -244,7 +244,7 @@ func TestReactor_SyncTime(t *testing.T) { rts := setup(ctx, t, genDoc, privVals[0], []int64{maxBlockHeight, 0}, 0) require.Equal(t, maxBlockHeight, rts.reactors[rts.nodes[0]].store.Height()) - rts.start(t) + rts.start(ctx, t) require.Eventually( t, @@ -274,7 +274,7 @@ func TestReactor_NoBlockResponse(t *testing.T) { require.Equal(t, maxBlockHeight, rts.reactors[rts.nodes[0]].store.Height()) - rts.start(t) + rts.start(ctx, t) testCases := []struct { height int64 @@ -325,7 +325,7 @@ func TestReactor_BadBlockStopsPeer(t *testing.T) { require.Equal(t, maxBlockHeight, rts.reactors[rts.nodes[0]].store.Height()) - rts.start(t) + rts.start(ctx, t) require.Eventually( t, diff --git a/internal/consensus/byzantine_test.go b/internal/consensus/byzantine_test.go index 4c9ccbb2e..b53f3181d 100644 --- a/internal/consensus/byzantine_test.go +++ b/internal/consensus/byzantine_test.go @@ -123,7 +123,7 @@ func TestByzantinePrevoteEquivocation(t *testing.T) { bzReactor := rts.reactors[bzNodeID] // alter prevote so that the byzantine node double votes when height is 2 - bzNodeState.doPrevote = func(height int64, round int32) { + bzNodeState.doPrevote = func(ctx context.Context, height int64, round int32) { // allow first height to happen normally so that byzantine validator is no longer proposer if height == prevoteHeight { prevote1, err := bzNodeState.signVote( @@ -161,7 +161,7 @@ func 
TestByzantinePrevoteEquivocation(t *testing.T) { } } else { bzNodeState.Logger.Info("behaving normally") - bzNodeState.defaultDoPrevote(height, round) + bzNodeState.defaultDoPrevote(ctx, height, round) } } @@ -218,10 +218,12 @@ func TestByzantinePrevoteEquivocation(t *testing.T) { proposal.Signature = p.Signature // send proposal and block parts on internal msg queue - lazyNodeState.sendInternalMessage(msgInfo{&ProposalMessage{proposal}, ""}) + lazyNodeState.sendInternalMessage(ctx, msgInfo{&ProposalMessage{proposal}, ""}) for i := 0; i < int(blockParts.Total()); i++ { part := blockParts.GetPart(i) - lazyNodeState.sendInternalMessage(msgInfo{&BlockPartMessage{lazyNodeState.Height, lazyNodeState.Round, part}, ""}) + lazyNodeState.sendInternalMessage(ctx, msgInfo{&BlockPartMessage{ + lazyNodeState.Height, lazyNodeState.Round, part, + }, ""}) } lazyNodeState.Logger.Info("Signed proposal", "height", height, "round", round, "proposal", proposal) lazyNodeState.Logger.Debug(fmt.Sprintf("Signed proposal block: %v", block)) diff --git a/internal/consensus/invalid_test.go b/internal/consensus/invalid_test.go index fc872a9fa..053f92264 100644 --- a/internal/consensus/invalid_test.go +++ b/internal/consensus/invalid_test.go @@ -50,7 +50,7 @@ func TestReactorInvalidPrecommit(t *testing.T) { // block and otherwise disable the priv validator. byzState.mtx.Lock() privVal := byzState.privValidator - byzState.doPrevote = func(height int64, round int32) { + byzState.doPrevote = func(ctx context.Context, height int64, round int32) { invalidDoPrevoteFunc(ctx, t, height, round, byzState, byzReactor, privVal) } byzState.mtx.Unlock() diff --git a/internal/consensus/reactor.go b/internal/consensus/reactor.go index 4e2b1ad3a..5f9a7ff73 100644 --- a/internal/consensus/reactor.go +++ b/internal/consensus/reactor.go @@ -183,7 +183,7 @@ func (r *Reactor) OnStart(ctx context.Context) error { // // TODO: Evaluate if we need this to be synchronized via WaitGroup as to not // leak the goroutine when stopping the reactor. - go r.peerStatsRoutine() + go r.peerStatsRoutine(ctx) r.subscribeToBroadcastEvents() @@ -193,11 +193,11 @@ func (r *Reactor) OnStart(ctx context.Context) error { } } - go r.processStateCh() - go r.processDataCh() - go r.processVoteCh() - go r.processVoteSetBitsCh() - go r.processPeerUpdates() + go r.processStateCh(ctx) + go r.processDataCh(ctx) + go r.processVoteCh(ctx) + go r.processVoteSetBitsCh(ctx) + go r.processPeerUpdates(ctx) return nil } @@ -231,18 +231,11 @@ func (r *Reactor) OnStop() { // Close the StateChannel goroutine separately since it uses its own channel // to signal closure. close(r.stateCloseCh) - <-r.stateCh.Done() // Close closeCh to signal to all spawned goroutines to gracefully exit. All // p2p Channels should execute Close(). close(r.closeCh) - // Wait for all p2p Channels to be closed before returning. This ensures we - // can easily reason about synchronization of all p2p Channels and ensure no - // panics will occur. 
- <-r.voteSetBitsCh.Done() - <-r.dataCh.Done() - <-r.voteCh.Done() <-r.peerUpdates.Done() } @@ -430,12 +423,15 @@ func makeRoundStepMessage(rs *cstypes.RoundState) *tmcons.NewRoundStep { } } -func (r *Reactor) sendNewRoundStepMessage(peerID types.NodeID) { +func (r *Reactor) sendNewRoundStepMessage(ctx context.Context, peerID types.NodeID) { rs := r.state.GetRoundState() msg := makeRoundStepMessage(rs) - r.stateCh.Out <- p2p.Envelope{ + select { + case <-ctx.Done(): + case r.stateCh.Out <- p2p.Envelope{ To: peerID, Message: msg, + }: } } @@ -503,11 +499,14 @@ func (r *Reactor) gossipDataForCatchup(rs *cstypes.RoundState, prs *cstypes.Peer time.Sleep(r.state.config.PeerGossipSleepDuration) } -func (r *Reactor) gossipDataRoutine(ps *PeerState) { +func (r *Reactor) gossipDataRoutine(ctx context.Context, ps *PeerState) { logger := r.Logger.With("peer", ps.peerID) defer ps.broadcastWG.Done() + timer := time.NewTimer(0) + defer timer.Stop() + OUTER_LOOP: for { if !r.IsRunning() { @@ -515,6 +514,8 @@ OUTER_LOOP: } select { + case <-ctx.Done(): + return case <-ps.closer.Done(): // The peer is marked for removal via a PeerUpdate as the doneCh was // explicitly closed to signal we should exit. @@ -566,7 +567,13 @@ OUTER_LOOP: "blockstoreBase", blockStoreBase, "blockstoreHeight", r.state.blockStore.Height(), ) - time.Sleep(r.state.config.PeerGossipSleepDuration) + + timer.Reset(r.state.config.PeerGossipSleepDuration) + select { + case <-timer.C: + case <-ctx.Done(): + return + } } else { ps.InitProposalBlockParts(blockMeta.BlockID.PartSetHeader) } @@ -582,7 +589,12 @@ OUTER_LOOP: // if height and round don't match, sleep if (rs.Height != prs.Height) || (rs.Round != prs.Round) { - time.Sleep(r.state.config.PeerGossipSleepDuration) + timer.Reset(r.state.config.PeerGossipSleepDuration) + select { + case <-timer.C: + case <-ctx.Done(): + return + } continue OUTER_LOOP } @@ -633,21 +645,29 @@ OUTER_LOOP: } // nothing to do -- sleep - time.Sleep(r.state.config.PeerGossipSleepDuration) + timer.Reset(r.state.config.PeerGossipSleepDuration) + select { + case <-timer.C: + case <-ctx.Done(): + return + } continue OUTER_LOOP } } // pickSendVote picks a vote and sends it to the peer. It will return true if // there is a vote to send and false otherwise. -func (r *Reactor) pickSendVote(ps *PeerState, votes types.VoteSetReader) bool { +func (r *Reactor) pickSendVote(ctx context.Context, ps *PeerState, votes types.VoteSetReader) bool { if vote, ok := ps.PickVoteToSend(votes); ok { r.Logger.Debug("sending vote message", "ps", ps, "vote", vote) - r.voteCh.Out <- p2p.Envelope{ + select { + case <-ctx.Done(): + case r.voteCh.Out <- p2p.Envelope{ To: ps.peerID, Message: &tmcons.Vote{ Vote: vote.ToProto(), }, + }: } ps.SetHasVote(vote) @@ -657,12 +677,17 @@ func (r *Reactor) pickSendVote(ps *PeerState, votes types.VoteSetReader) bool { return false } -func (r *Reactor) gossipVotesForHeight(rs *cstypes.RoundState, prs *cstypes.PeerRoundState, ps *PeerState) bool { +func (r *Reactor) gossipVotesForHeight( + ctx context.Context, + rs *cstypes.RoundState, + prs *cstypes.PeerRoundState, + ps *PeerState, +) bool { logger := r.Logger.With("height", prs.Height).With("peer", ps.peerID) // if there are lastCommits to send... 
if prs.Step == cstypes.RoundStepNewHeight { - if r.pickSendVote(ps, rs.LastCommit) { + if r.pickSendVote(ctx, ps, rs.LastCommit) { logger.Debug("picked rs.LastCommit to send") return true } @@ -671,7 +696,7 @@ func (r *Reactor) gossipVotesForHeight(rs *cstypes.RoundState, prs *cstypes.Peer // if there are POL prevotes to send... if prs.Step <= cstypes.RoundStepPropose && prs.Round != -1 && prs.Round <= rs.Round && prs.ProposalPOLRound != -1 { if polPrevotes := rs.Votes.Prevotes(prs.ProposalPOLRound); polPrevotes != nil { - if r.pickSendVote(ps, polPrevotes) { + if r.pickSendVote(ctx, ps, polPrevotes) { logger.Debug("picked rs.Prevotes(prs.ProposalPOLRound) to send", "round", prs.ProposalPOLRound) return true } @@ -680,7 +705,7 @@ func (r *Reactor) gossipVotesForHeight(rs *cstypes.RoundState, prs *cstypes.Peer // if there are prevotes to send... if prs.Step <= cstypes.RoundStepPrevoteWait && prs.Round != -1 && prs.Round <= rs.Round { - if r.pickSendVote(ps, rs.Votes.Prevotes(prs.Round)) { + if r.pickSendVote(ctx, ps, rs.Votes.Prevotes(prs.Round)) { logger.Debug("picked rs.Prevotes(prs.Round) to send", "round", prs.Round) return true } @@ -688,7 +713,7 @@ func (r *Reactor) gossipVotesForHeight(rs *cstypes.RoundState, prs *cstypes.Peer // if there are precommits to send... if prs.Step <= cstypes.RoundStepPrecommitWait && prs.Round != -1 && prs.Round <= rs.Round { - if r.pickSendVote(ps, rs.Votes.Precommits(prs.Round)) { + if r.pickSendVote(ctx, ps, rs.Votes.Precommits(prs.Round)) { logger.Debug("picked rs.Precommits(prs.Round) to send", "round", prs.Round) return true } @@ -696,7 +721,7 @@ func (r *Reactor) gossipVotesForHeight(rs *cstypes.RoundState, prs *cstypes.Peer // if there are prevotes to send...(which are needed because of validBlock mechanism) if prs.Round != -1 && prs.Round <= rs.Round { - if r.pickSendVote(ps, rs.Votes.Prevotes(prs.Round)) { + if r.pickSendVote(ctx, ps, rs.Votes.Prevotes(prs.Round)) { logger.Debug("picked rs.Prevotes(prs.Round) to send", "round", prs.Round) return true } @@ -705,7 +730,7 @@ func (r *Reactor) gossipVotesForHeight(rs *cstypes.RoundState, prs *cstypes.Peer // if there are POLPrevotes to send... if prs.ProposalPOLRound != -1 { if polPrevotes := rs.Votes.Prevotes(prs.ProposalPOLRound); polPrevotes != nil { - if r.pickSendVote(ps, polPrevotes) { + if r.pickSendVote(ctx, ps, polPrevotes) { logger.Debug("picked rs.Prevotes(prs.ProposalPOLRound) to send", "round", prs.ProposalPOLRound) return true } @@ -715,7 +740,7 @@ func (r *Reactor) gossipVotesForHeight(rs *cstypes.RoundState, prs *cstypes.Peer return false } -func (r *Reactor) gossipVotesRoutine(ps *PeerState) { +func (r *Reactor) gossipVotesRoutine(ctx context.Context, ps *PeerState) { logger := r.Logger.With("peer", ps.peerID) defer ps.broadcastWG.Done() @@ -723,6 +748,9 @@ func (r *Reactor) gossipVotesRoutine(ps *PeerState) { // XXX: simple hack to throttle logs upon sleep logThrottle := 0 + timer := time.NewTimer(0) + defer timer.Stop() + OUTER_LOOP: for { if !r.IsRunning() { @@ -730,6 +758,8 @@ OUTER_LOOP: } select { + case <-ctx.Done(): + return case <-ps.closer.Done(): // The peer is marked for removal via a PeerUpdate as the doneCh was // explicitly closed to signal we should exit. 
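Editor's note: the gossip routines above replace time.Sleep with a reusable timer plus a select on ctx.Done(), so a canceled context interrupts the wait. A minimal, self-contained sketch of that pattern follows; the helper name and the drain of the timer's initial fire are my own additions, not taken from the patch.

package main

import (
	"context"
	"fmt"
	"time"
)

// sleepCtx waits for d using the caller's reusable timer, or returns early
// (false) if ctx is canceled before the timer fires.
func sleepCtx(ctx context.Context, timer *time.Timer, d time.Duration) bool {
	timer.Reset(d)
	select {
	case <-timer.C:
		return true
	case <-ctx.Done():
		return false
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
	defer cancel()

	timer := time.NewTimer(0)
	defer timer.Stop()
	<-timer.C // drain the initial fire so the first Reset behaves predictably

	for i := 0; ; i++ {
		if !sleepCtx(ctx, timer, 100*time.Millisecond) {
			fmt.Println("context canceled after", i, "iterations")
			return
		}
		fmt.Println("tick", i)
	}
}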
@@ -750,14 +780,14 @@ OUTER_LOOP: // if height matches, then send LastCommit, Prevotes, and Precommits if rs.Height == prs.Height { - if r.gossipVotesForHeight(rs, prs, ps) { + if r.gossipVotesForHeight(ctx, rs, prs, ps) { continue OUTER_LOOP } } // special catchup logic -- if peer is lagging by height 1, send LastCommit if prs.Height != 0 && rs.Height == prs.Height+1 { - if r.pickSendVote(ps, rs.LastCommit) { + if r.pickSendVote(ctx, ps, rs.LastCommit) { logger.Debug("picked rs.LastCommit to send", "height", prs.Height) continue OUTER_LOOP } @@ -769,7 +799,7 @@ OUTER_LOOP: // Load the block commit for prs.Height, which contains precommit // signatures for prs.Height. if commit := r.state.blockStore.LoadBlockCommit(prs.Height); commit != nil { - if r.pickSendVote(ps, commit) { + if r.pickSendVote(ctx, ps, commit) { logger.Debug("picked Catchup commit to send", "height", prs.Height) continue OUTER_LOOP } @@ -790,16 +820,24 @@ OUTER_LOOP: logThrottle = 1 } - time.Sleep(r.state.config.PeerGossipSleepDuration) + timer.Reset(r.state.config.PeerGossipSleepDuration) + select { + case <-ctx.Done(): + return + case <-timer.C: + } continue OUTER_LOOP } } // NOTE: `queryMaj23Routine` has a simple crude design since it only comes // into play for liveness when there's a signature DDoS attack happening. -func (r *Reactor) queryMaj23Routine(ps *PeerState) { +func (r *Reactor) queryMaj23Routine(ctx context.Context, ps *PeerState) { defer ps.broadcastWG.Done() + timer := time.NewTimer(0) + defer timer.Stop() + OUTER_LOOP: for { if !r.IsRunning() { @@ -807,11 +845,12 @@ OUTER_LOOP: } select { + case <-ctx.Done(): + return case <-ps.closer.Done(): // The peer is marked for removal via a PeerUpdate as the doneCh was // explicitly closed to signal we should exit. return - default: } @@ -832,7 +871,12 @@ OUTER_LOOP: }, } - time.Sleep(r.state.config.PeerQueryMaj23SleepDuration) + timer.Reset(r.state.config.PeerQueryMaj23SleepDuration) + select { + case <-timer.C: + case <-ctx.Done(): + return + } } } } @@ -854,7 +898,12 @@ OUTER_LOOP: }, } - time.Sleep(r.state.config.PeerQueryMaj23SleepDuration) + select { + case <-timer.C: + timer.Reset(r.state.config.PeerQueryMaj23SleepDuration) + case <-ctx.Done(): + return + } } } } @@ -876,7 +925,12 @@ OUTER_LOOP: }, } - time.Sleep(r.state.config.PeerQueryMaj23SleepDuration) + timer.Reset(r.state.config.PeerQueryMaj23SleepDuration) + select { + case <-timer.C: + case <-ctx.Done(): + return + } } } } @@ -901,12 +955,23 @@ OUTER_LOOP: }, } - time.Sleep(r.state.config.PeerQueryMaj23SleepDuration) + timer.Reset(r.state.config.PeerQueryMaj23SleepDuration) + select { + case <-timer.C: + case <-ctx.Done(): + return + } } } } - time.Sleep(r.state.config.PeerQueryMaj23SleepDuration) + timer.Reset(r.state.config.PeerQueryMaj23SleepDuration) + select { + case <-timer.C: + case <-ctx.Done(): + return + } + continue OUTER_LOOP } } @@ -916,7 +981,7 @@ OUTER_LOOP: // be the case, and we spawn all the relevant goroutine to broadcast messages to // the peer. During peer removal, we remove the peer for our set of peers and // signal to all spawned goroutines to gracefully exit in a non-blocking manner. 
-func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { +func (r *Reactor) processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpdate) { r.Logger.Debug("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) r.mtx.Lock() @@ -952,14 +1017,14 @@ func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { ps.SetRunning(true) // start goroutines for this peer - go r.gossipDataRoutine(ps) - go r.gossipVotesRoutine(ps) - go r.queryMaj23Routine(ps) + go r.gossipDataRoutine(ctx, ps) + go r.gossipVotesRoutine(ctx, ps) + go r.queryMaj23Routine(ctx, ps) // Send our state to the peer. If we're block-syncing, broadcast a // RoundStepMessage later upon SwitchToConsensus(). if !r.waitSync { - go r.sendNewRoundStepMessage(ps.peerID) + go r.sendNewRoundStepMessage(ctx, ps.peerID) } } @@ -1266,11 +1331,11 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err // execution will result in a PeerError being sent on the StateChannel. When // the reactor is stopped, we will catch the signal and close the p2p Channel // gracefully. -func (r *Reactor) processStateCh() { - defer r.stateCh.Close() - +func (r *Reactor) processStateCh(ctx context.Context) { for { select { + case <-ctx.Done(): + return case envelope := <-r.stateCh.In: if err := r.handleMessage(r.stateCh.ID, envelope); err != nil { r.Logger.Error("failed to process message", "ch_id", r.stateCh.ID, "envelope", envelope, "err", err) @@ -1292,11 +1357,11 @@ func (r *Reactor) processStateCh() { // execution will result in a PeerError being sent on the DataChannel. When // the reactor is stopped, we will catch the signal and close the p2p Channel // gracefully. -func (r *Reactor) processDataCh() { - defer r.dataCh.Close() - +func (r *Reactor) processDataCh(ctx context.Context) { for { select { + case <-ctx.Done(): + return case envelope := <-r.dataCh.In: if err := r.handleMessage(r.dataCh.ID, envelope); err != nil { r.Logger.Error("failed to process message", "ch_id", r.dataCh.ID, "envelope", envelope, "err", err) @@ -1318,11 +1383,11 @@ func (r *Reactor) processDataCh() { // execution will result in a PeerError being sent on the VoteChannel. When // the reactor is stopped, we will catch the signal and close the p2p Channel // gracefully. -func (r *Reactor) processVoteCh() { - defer r.voteCh.Close() - +func (r *Reactor) processVoteCh(ctx context.Context) { for { select { + case <-ctx.Done(): + return case envelope := <-r.voteCh.In: if err := r.handleMessage(r.voteCh.ID, envelope); err != nil { r.Logger.Error("failed to process message", "ch_id", r.voteCh.ID, "envelope", envelope, "err", err) @@ -1344,11 +1409,11 @@ func (r *Reactor) processVoteCh() { // execution will result in a PeerError being sent on the VoteSetBitsChannel. // When the reactor is stopped, we will catch the signal and close the p2p // Channel gracefully. -func (r *Reactor) processVoteSetBitsCh() { - defer r.voteSetBitsCh.Close() - +func (r *Reactor) processVoteSetBitsCh(ctx context.Context) { for { select { + case <-ctx.Done(): + return case envelope := <-r.voteSetBitsCh.In: if err := r.handleMessage(r.voteSetBitsCh.ID, envelope); err != nil { r.Logger.Error("failed to process message", "ch_id", r.voteSetBitsCh.ID, "envelope", envelope, "err", err) @@ -1368,13 +1433,15 @@ func (r *Reactor) processVoteSetBitsCh() { // processPeerUpdates initiates a blocking process where we listen for and handle // PeerUpdate messages. When the reactor is stopped, we will catch the signal and // close the p2p PeerUpdatesCh gracefully. 
-func (r *Reactor) processPeerUpdates() { +func (r *Reactor) processPeerUpdates(ctx context.Context) { defer r.peerUpdates.Close() for { select { + case <-ctx.Done(): + return case peerUpdate := <-r.peerUpdates.Updates(): - r.processPeerUpdate(peerUpdate) + r.processPeerUpdate(ctx, peerUpdate) case <-r.closeCh: r.Logger.Debug("stopped listening on peer updates channel; closing...") @@ -1383,7 +1450,7 @@ func (r *Reactor) processPeerUpdates() { } } -func (r *Reactor) peerStatsRoutine() { +func (r *Reactor) peerStatsRoutine(ctx context.Context) { for { if !r.IsRunning() { r.Logger.Info("stopping peerStatsRoutine") @@ -1415,6 +1482,8 @@ func (r *Reactor) peerStatsRoutine() { }) } } + case <-ctx.Done(): + return case <-r.closeCh: return } diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index 9e7c498cd..f0bb8b53c 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -76,10 +76,10 @@ func setup( blocksyncSubs: make(map[types.NodeID]eventbus.Subscription, numNodes), } - rts.stateChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(StateChannel, size)) - rts.dataChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(DataChannel, size)) - rts.voteChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(VoteChannel, size)) - rts.voteSetBitsChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(VoteSetBitsChannel, size)) + rts.stateChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(StateChannel, size)) + rts.dataChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(DataChannel, size)) + rts.voteChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(VoteChannel, size)) + rts.voteSetBitsChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc(VoteSetBitsChannel, size)) ctx, cancel := context.WithCancel(ctx) // Canceled during cleanup (see below). @@ -134,7 +134,7 @@ func setup( require.Len(t, rts.reactors, numNodes) // start the in-memory network and connect all peers with each other - rts.network.Start(t) + rts.network.Start(ctx, t) t.Cleanup(func() { cancel() diff --git a/internal/consensus/replay_file.go b/internal/consensus/replay_file.go index ef5e88730..d54c77fe3 100644 --- a/internal/consensus/replay_file.go +++ b/internal/consensus/replay_file.go @@ -183,17 +183,6 @@ func (pb *playback) replayReset(ctx context.Context, count int, newStepSub event func (cs *State) startForReplay() { cs.Logger.Error("Replay commands are disabled until someone updates them and writes tests") - /* TODO:! - // since we replay tocks we just ignore ticks - go func() { - for { - select { - case <-cs.tickChan: - case <-cs.Quit: - return - } - } - }()*/ } // console function for parsing input and running commands. 
The integer diff --git a/internal/consensus/state.go b/internal/consensus/state.go index f71d08649..2c358a213 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -131,7 +131,7 @@ type State struct { // some functions can be overwritten for testing decideProposal func(height int64, round int32) - doPrevote func(height int64, round int32) + doPrevote func(ctx context.Context, height int64, round int32) setProposal func(proposal *types.Proposal) error // closed when we finish shutting down @@ -594,8 +594,9 @@ func (cs *State) scheduleTimeout(duration time.Duration, height int64, round int } // send a msg into the receiveRoutine regarding our own proposal, block part, or vote -func (cs *State) sendInternalMessage(mi msgInfo) { +func (cs *State) sendInternalMessage(ctx context.Context, mi msgInfo) { select { + case <-ctx.Done(): case cs.internalMsgQueue <- mi: default: // NOTE: using the go-routine means our votes can @@ -603,7 +604,12 @@ func (cs *State) sendInternalMessage(mi msgInfo) { // TODO: use CList here for strict determinism and // attempt push to internalMsgQueue in receiveRoutine cs.Logger.Debug("internal msg queue is full; using a go-routine") - go func() { cs.internalMsgQueue <- mi }() + go func() { + select { + case <-ctx.Done(): + case cs.internalMsgQueue <- mi: + } + }() } } @@ -1219,11 +1225,11 @@ func (cs *State) defaultDecideProposal(height int64, round int32) { proposal.Signature = p.Signature // send proposal and block parts on internal msg queue - cs.sendInternalMessage(msgInfo{&ProposalMessage{proposal}, ""}) + cs.sendInternalMessage(ctx, msgInfo{&ProposalMessage{proposal}, ""}) for i := 0; i < int(blockParts.Total()); i++ { part := blockParts.GetPart(i) - cs.sendInternalMessage(msgInfo{&BlockPartMessage{cs.Height, cs.Round, part}, ""}) + cs.sendInternalMessage(ctx, msgInfo{&BlockPartMessage{cs.Height, cs.Round, part}, ""}) } cs.Logger.Debug("signed proposal", "height", height, "round", round, "proposal", proposal) @@ -1312,26 +1318,26 @@ func (cs *State) enterPrevote(ctx context.Context, height int64, round int32) { logger.Debug("entering prevote step", "current", fmt.Sprintf("%v/%v/%v", cs.Height, cs.Round, cs.Step)) // Sign and broadcast vote as necessary - cs.doPrevote(height, round) + cs.doPrevote(ctx, height, round) // Once `addVote` hits any +2/3 prevotes, we will go to PrevoteWait // (so we have more time to try and collect +2/3 prevotes for a single block) } -func (cs *State) defaultDoPrevote(height int64, round int32) { +func (cs *State) defaultDoPrevote(ctx context.Context, height int64, round int32) { logger := cs.Logger.With("height", height, "round", round) // If a block is locked, prevote that. if cs.LockedBlock != nil { logger.Debug("prevote step; already locked on a block; prevoting locked block") - cs.signAddVote(tmproto.PrevoteType, cs.LockedBlock.Hash(), cs.LockedBlockParts.Header()) + cs.signAddVote(ctx, tmproto.PrevoteType, cs.LockedBlock.Hash(), cs.LockedBlockParts.Header()) return } // If ProposalBlock is nil, prevote nil. if cs.ProposalBlock == nil { logger.Debug("prevote step: ProposalBlock is nil") - cs.signAddVote(tmproto.PrevoteType, nil, types.PartSetHeader{}) + cs.signAddVote(ctx, tmproto.PrevoteType, nil, types.PartSetHeader{}) return } @@ -1340,7 +1346,7 @@ func (cs *State) defaultDoPrevote(height int64, round int32) { if err != nil { // ProposalBlock is invalid, prevote nil. 
logger.Error("prevote step: ProposalBlock is invalid", "err", err) - cs.signAddVote(tmproto.PrevoteType, nil, types.PartSetHeader{}) + cs.signAddVote(ctx, tmproto.PrevoteType, nil, types.PartSetHeader{}) return } @@ -1348,7 +1354,7 @@ func (cs *State) defaultDoPrevote(height int64, round int32) { // NOTE: the proposal signature is validated when it is received, // and the proposal block parts are validated as they are received (against the merkle hash in the proposal) logger.Debug("prevote step: ProposalBlock is valid") - cs.signAddVote(tmproto.PrevoteType, cs.ProposalBlock.Hash(), cs.ProposalBlockParts.Header()) + cs.signAddVote(ctx, tmproto.PrevoteType, cs.ProposalBlock.Hash(), cs.ProposalBlockParts.Header()) } // Enter: any +2/3 prevotes at next round. @@ -1418,7 +1424,7 @@ func (cs *State) enterPrecommit(ctx context.Context, height int64, round int32) logger.Debug("precommit step; no +2/3 prevotes during enterPrecommit; precommitting nil") } - cs.signAddVote(tmproto.PrecommitType, nil, types.PartSetHeader{}) + cs.signAddVote(ctx, tmproto.PrecommitType, nil, types.PartSetHeader{}) return } @@ -1448,7 +1454,7 @@ func (cs *State) enterPrecommit(ctx context.Context, height int64, round int32) } } - cs.signAddVote(tmproto.PrecommitType, nil, types.PartSetHeader{}) + cs.signAddVote(ctx, tmproto.PrecommitType, nil, types.PartSetHeader{}) return } @@ -1463,7 +1469,7 @@ func (cs *State) enterPrecommit(ctx context.Context, height int64, round int32) logger.Error("failed publishing event relock", "err", err) } - cs.signAddVote(tmproto.PrecommitType, blockID.Hash, blockID.PartSetHeader) + cs.signAddVote(ctx, tmproto.PrecommitType, blockID.Hash, blockID.PartSetHeader) return } @@ -1484,7 +1490,7 @@ func (cs *State) enterPrecommit(ctx context.Context, height int64, round int32) logger.Error("failed publishing event lock", "err", err) } - cs.signAddVote(tmproto.PrecommitType, blockID.Hash, blockID.PartSetHeader) + cs.signAddVote(ctx, tmproto.PrecommitType, blockID.Hash, blockID.PartSetHeader) return } @@ -1506,7 +1512,7 @@ func (cs *State) enterPrecommit(ctx context.Context, height int64, round int32) logger.Error("failed publishing event unlock", "err", err) } - cs.signAddVote(tmproto.PrecommitType, nil, types.PartSetHeader{}) + cs.signAddVote(ctx, tmproto.PrecommitType, nil, types.PartSetHeader{}) } // Enter: any +2/3 precommits for next round. 
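Editor's note: sendInternalMessage (see the state.go hunks above) now threads ctx through both the direct send and the spill-over goroutine used when the internal queue is full. A standalone sketch of that send-or-spill pattern, using illustrative names rather than the consensus state's real types:

package main

import (
	"context"
	"fmt"
)

type msgInfo struct{ payload string }

// sendInternal tries a non-blocking send; if the queue is full it hands the
// send off to a goroutine that still honors cancellation, so shutdown can
// never leak a sender blocked on a full queue.
func sendInternal(ctx context.Context, queue chan<- msgInfo, mi msgInfo) {
	select {
	case <-ctx.Done():
	case queue <- mi:
	default:
		go func() {
			select {
			case <-ctx.Done():
			case queue <- mi:
			}
		}()
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	queue := make(chan msgInfo, 1)
	sendInternal(ctx, queue, msgInfo{"vote-1"}) // fits in the buffer
	sendInternal(ctx, queue, msgInfo{"vote-2"}) // buffer full: spills to a goroutine

	fmt.Println((<-queue).payload)
	fmt.Println((<-queue).payload) // delivered by the background sender
}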
@@ -2292,7 +2298,7 @@ func (cs *State) voteTime() time.Time { } // sign the vote and publish on internalMsgQueue -func (cs *State) signAddVote(msgType tmproto.SignedMsgType, hash []byte, header types.PartSetHeader) *types.Vote { +func (cs *State) signAddVote(ctx context.Context, msgType tmproto.SignedMsgType, hash []byte, header types.PartSetHeader) *types.Vote { if cs.privValidator == nil { // the node does not have a key return nil } @@ -2311,7 +2317,7 @@ func (cs *State) signAddVote(msgType tmproto.SignedMsgType, hash []byte, header // TODO: pass pubKey to signVote vote, err := cs.signVote(msgType, hash, header) if err == nil { - cs.sendInternalMessage(msgInfo{&VoteMessage{vote}, ""}) + cs.sendInternalMessage(ctx, msgInfo{&VoteMessage{vote}, ""}) cs.Logger.Debug("signed and pushed vote", "height", cs.Height, "round", cs.Round, "vote", vote) return vote } diff --git a/internal/consensus/ticker.go b/internal/consensus/ticker.go index 6e323b2d0..84814f33d 100644 --- a/internal/consensus/ticker.go +++ b/internal/consensus/ticker.go @@ -122,7 +122,12 @@ func (t *timeoutTicker) timeoutRoutine(ctx context.Context) { // Determinism comes from playback in the receiveRoutine. // We can eliminate it by merging the timeoutRoutine into receiveRoutine // and managing the timeouts ourselves with a millisecond ticker - go func(toi timeoutInfo) { t.tockChan <- toi }(ti) + go func(toi timeoutInfo) { + select { + case t.tockChan <- toi: + case <-ctx.Done(): + } + }(ti) case <-ctx.Done(): return } diff --git a/internal/evidence/reactor.go b/internal/evidence/reactor.go index 89f60b749..5b093dc15 100644 --- a/internal/evidence/reactor.go +++ b/internal/evidence/reactor.go @@ -83,8 +83,8 @@ func NewReactor( // messages on that p2p channel accordingly. The caller must be sure to execute // OnStop to ensure the outbound p2p Channels are closed. No error is returned. func (r *Reactor) OnStart(ctx context.Context) error { - go r.processEvidenceCh() - go r.processPeerUpdates() + go r.processEvidenceCh(ctx) + go r.processPeerUpdates(ctx) return nil } @@ -109,7 +109,6 @@ func (r *Reactor) OnStop() { // Wait for all p2p Channels to be closed before returning. This ensures we // can easily reason about synchronization of all p2p Channels and ensure no // panics will occur. - <-r.evidenceCh.Done() <-r.peerUpdates.Done() // Close the evidence db @@ -183,11 +182,11 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err // processEvidenceCh implements a blocking event loop where we listen for p2p // Envelope messages from the evidenceCh. -func (r *Reactor) processEvidenceCh() { - defer r.evidenceCh.Close() - +func (r *Reactor) processEvidenceCh(ctx context.Context) { for { select { + case <-ctx.Done(): + return case envelope := <-r.evidenceCh.In: if err := r.handleMessage(r.evidenceCh.ID, envelope); err != nil { r.Logger.Error("failed to process message", "ch_id", r.evidenceCh.ID, "envelope", envelope, "err", err) @@ -215,7 +214,7 @@ func (r *Reactor) processEvidenceCh() { // connects/disconnects frequently from the broadcasting peer(s). 
// // REF: https://github.com/tendermint/tendermint/issues/4727 -func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { +func (r *Reactor) processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpdate) { r.Logger.Debug("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) r.mtx.Lock() @@ -241,7 +240,7 @@ func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { r.peerRoutines[peerUpdate.NodeID] = closer r.peerWG.Add(1) - go r.broadcastEvidenceLoop(peerUpdate.NodeID, closer) + go r.broadcastEvidenceLoop(ctx, peerUpdate.NodeID, closer) } case p2p.PeerStatusDown: @@ -259,14 +258,15 @@ func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { // processPeerUpdates initiates a blocking process where we listen for and handle // PeerUpdate messages. When the reactor is stopped, we will catch the signal and // close the p2p PeerUpdatesCh gracefully. -func (r *Reactor) processPeerUpdates() { +func (r *Reactor) processPeerUpdates(ctx context.Context) { defer r.peerUpdates.Close() for { select { case peerUpdate := <-r.peerUpdates.Updates(): - r.processPeerUpdate(peerUpdate) - + r.processPeerUpdate(ctx, peerUpdate) + case <-ctx.Done(): + return case <-r.closeCh: r.Logger.Debug("stopped listening on peer updates channel; closing...") return @@ -285,7 +285,7 @@ func (r *Reactor) processPeerUpdates() { // that the peer has already received or may not be ready for. // // REF: https://github.com/tendermint/tendermint/issues/4727 -func (r *Reactor) broadcastEvidenceLoop(peerID types.NodeID, closer *tmsync.Closer) { +func (r *Reactor) broadcastEvidenceLoop(ctx context.Context, peerID types.NodeID, closer *tmsync.Closer) { var next *clist.CElement defer func() { @@ -315,6 +315,8 @@ func (r *Reactor) broadcastEvidenceLoop(peerID types.NodeID, closer *tmsync.Clos continue } + case <-ctx.Done(): + return case <-closer.Done(): // The peer is marked for removal via a PeerUpdate as the doneCh was // explicitly closed to signal we should exit. @@ -337,11 +339,15 @@ func (r *Reactor) broadcastEvidenceLoop(peerID types.NodeID, closer *tmsync.Clos // and thus would not be able to process the evidence correctly. Also, the // peer may receive this piece of evidence multiple times if it added and // removed frequently from the broadcasting peer. 
- r.evidenceCh.Out <- p2p.Envelope{ + select { + case <-ctx.Done(): + return + case r.evidenceCh.Out <- p2p.Envelope{ To: peerID, Message: &tmproto.EvidenceList{ Evidence: []tmproto.Evidence{*evProto}, }, + }: } r.Logger.Debug("gossiped evidence to peer", "evidence", ev, "peer", peerID) diff --git a/internal/evidence/reactor_test.go b/internal/evidence/reactor_test.go index 7808e53fb..b30f9e9b1 100644 --- a/internal/evidence/reactor_test.go +++ b/internal/evidence/reactor_test.go @@ -64,7 +64,7 @@ func setup(ctx context.Context, t *testing.T, stateStores []sm.Store, chBuf uint } chDesc := &p2p.ChannelDescriptor{ID: evidence.EvidenceChannel, MessageType: new(tmproto.EvidenceList)} - rts.evidenceChannels = rts.network.MakeChannelsNoCleanup(t, chDesc) + rts.evidenceChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc) require.Len(t, rts.network.RandomNode().PeerManager.Peers(), 0) idx := 0 @@ -86,7 +86,7 @@ func setup(ctx context.Context, t *testing.T, stateStores []sm.Store, chBuf uint rts.peerChans[nodeID] = make(chan p2p.PeerUpdate) rts.peerUpdates[nodeID] = p2p.NewPeerUpdates(rts.peerChans[nodeID], 1) - rts.network.Nodes[nodeID].PeerManager.Register(rts.peerUpdates[nodeID]) + rts.network.Nodes[nodeID].PeerManager.Register(ctx, rts.peerUpdates[nodeID]) rts.nodes = append(rts.nodes, rts.network.Nodes[nodeID]) rts.reactors[nodeID] = evidence.NewReactor(logger, @@ -114,8 +114,8 @@ func setup(ctx context.Context, t *testing.T, stateStores []sm.Store, chBuf uint return rts } -func (rts *reactorTestSuite) start(t *testing.T) { - rts.network.Start(t) +func (rts *reactorTestSuite) start(ctx context.Context, t *testing.T) { + rts.network.Start(ctx, t) require.Len(t, rts.network.RandomNode().PeerManager.Peers(), rts.numStateStores-1, @@ -251,7 +251,7 @@ func TestReactorMultiDisconnect(t *testing.T) { require.Equal(t, primary.PeerManager.Status(secondary.NodeID), p2p.PeerStatusDown) - rts.start(t) + rts.start(ctx, t) require.Equal(t, primary.PeerManager.Status(secondary.NodeID), p2p.PeerStatusUp) // Ensure "disconnecting" the secondary peer from the primary more than once @@ -289,7 +289,7 @@ func TestReactorBroadcastEvidence(t *testing.T) { defer cancel() rts := setup(ctx, t, stateDBs, 0) - rts.start(t) + rts.start(ctx, t) // Create a series of fixtures where each suite contains a reactor and // evidence pool. In addition, we mark a primary suite and the rest are @@ -346,7 +346,7 @@ func TestReactorBroadcastEvidence_Lagging(t *testing.T) { defer cancel() rts := setup(ctx, t, []sm.Store{stateDB1, stateDB2}, 100) - rts.start(t) + rts.start(ctx, t) primary := rts.nodes[0] secondary := rts.nodes[1] @@ -396,7 +396,7 @@ func TestReactorBroadcastEvidence_Pending(t *testing.T) { // the secondary should have half the evidence as pending require.Equal(t, numEvidence/2, int(rts.pools[secondary.NodeID].Size())) - rts.start(t) + rts.start(ctx, t) // The secondary reactor should have received all the evidence ignoring the // already pending evidence. @@ -449,7 +449,7 @@ func TestReactorBroadcastEvidence_Committed(t *testing.T) { require.Equal(t, 0, int(rts.pools[secondary.NodeID].Size())) // start the network and ensure it's configured - rts.start(t) + rts.start(ctx, t) // The secondary reactor should have received all the evidence ignoring the // already committed evidence. 
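Editor's note: the test-suite changes above pass a context into network start-up and peer-manager registration so reactor goroutines are torn down with the test. A hedged sketch of that idea; the helper and the use of t.Cleanup are illustrative, not copied from the suite.

package example

import (
	"context"
	"testing"
	"time"
)

// testContext returns a context that is canceled automatically when the test
// finishes, which is the property the rts.start(ctx, t) call sites rely on.
func testContext(t *testing.T) context.Context {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	t.Cleanup(cancel)
	return ctx
}

func TestStartStopsGoroutines(t *testing.T) {
	ctx := testContext(t)

	// Stand-in for a reactor goroutine started by the harness: it exits as
	// soon as the test's context is canceled during cleanup.
	go func() {
		<-ctx.Done()
	}()
}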
@@ -480,7 +480,7 @@ func TestReactorBroadcastEvidence_FullyConnected(t *testing.T) { defer cancel() rts := setup(ctx, t, stateDBs, 0) - rts.start(t) + rts.start(ctx, t) evList := createEvidenceList(t, rts.pools[rts.network.RandomNode().NodeID], val, numEvidence) diff --git a/internal/inspect/inspect.go b/internal/inspect/inspect.go index eb75d94a3..6381ea888 100644 --- a/internal/inspect/inspect.go +++ b/internal/inspect/inspect.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "net/http" "github.com/tendermint/tendermint/config" "github.com/tendermint/tendermint/internal/eventbus" @@ -117,7 +118,7 @@ func startRPCServers(ctx context.Context, cfg *config.RPCConfig, logger log.Logg logger.Info("RPC HTTPS server starting", "address", listenerAddr, "certfile", certFile, "keyfile", keyFile) err := server.ListenAndServeTLS(tctx, certFile, keyFile) - if !errors.Is(err, net.ErrClosed) { + if !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) { return err } logger.Info("RPC HTTPS server stopped", "address", listenerAddr) @@ -128,7 +129,7 @@ func startRPCServers(ctx context.Context, cfg *config.RPCConfig, logger log.Logg g.Go(func() error { logger.Info("RPC HTTP server starting", "address", listenerAddr) err := server.ListenAndServe(tctx) - if !errors.Is(err, net.ErrClosed) { + if !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) { return err } logger.Info("RPC HTTP server stopped", "address", listenerAddr) diff --git a/internal/inspect/inspect_test.go b/internal/inspect/inspect_test.go index ff6ade0d0..a75777741 100644 --- a/internal/inspect/inspect_test.go +++ b/internal/inspect/inspect_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "os" + "runtime" "strings" "sync" "testing" @@ -52,13 +53,14 @@ func TestInspectRun(t *testing.T) { logger := testLogger.With(t.Name()) d, err := inspect.NewFromConfig(logger, cfg) require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) stoppedWG := &sync.WaitGroup{} stoppedWG.Add(1) go func() { + defer stoppedWG.Done() require.NoError(t, d.Run(ctx)) - stoppedWG.Done() }() + time.Sleep(100 * time.Millisecond) cancel() stoppedWG.Wait() }) @@ -88,16 +90,13 @@ func TestBlock(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(1) - startedWG := &sync.WaitGroup{} - startedWG.Add(1) go func() { - startedWG.Done() defer wg.Done() require.NoError(t, d.Run(ctx)) }() // FIXME: used to induce context switch. // Determine more deterministic method for prompting a context switch - startedWG.Wait() + runtime.Gosched() requireConnect(t, rpcConfig.ListenAddress, 20) cli, err := httpclient.New(rpcConfig.ListenAddress) require.NoError(t, err) diff --git a/internal/inspect/rpc/rpc.go b/internal/inspect/rpc/rpc.go index be83b1b1e..cfcddcd44 100644 --- a/internal/inspect/rpc/rpc.go +++ b/internal/inspect/rpc/rpc.go @@ -114,7 +114,8 @@ func (srv *Server) ListenAndServe(ctx context.Context) error { <-ctx.Done() listener.Close() }() - return server.Serve(listener, srv.Handler, srv.Logger, serverRPCConfig(srv.Config)) + + return server.Serve(ctx, listener, srv.Handler, srv.Logger, serverRPCConfig(srv.Config)) } // ListenAndServeTLS listens on the address specified in srv.Addr. 
ListenAndServeTLS handles @@ -128,7 +129,7 @@ func (srv *Server) ListenAndServeTLS(ctx context.Context, certFile, keyFile stri <-ctx.Done() listener.Close() }() - return server.ServeTLS(listener, srv.Handler, certFile, keyFile, srv.Logger, serverRPCConfig(srv.Config)) + return server.ServeTLS(ctx, listener, srv.Handler, certFile, keyFile, srv.Logger, serverRPCConfig(srv.Config)) } func serverRPCConfig(r *config.RPCConfig) *server.Config { diff --git a/internal/mempool/reactor.go b/internal/mempool/reactor.go index 2ddb44e7a..44f9c4df9 100644 --- a/internal/mempool/reactor.go +++ b/internal/mempool/reactor.go @@ -117,8 +117,8 @@ func (r *Reactor) OnStart(ctx context.Context) error { r.Logger.Info("tx broadcasting is disabled") } - go r.processMempoolCh() - go r.processPeerUpdates() + go r.processMempoolCh(ctx) + go r.processPeerUpdates(ctx) return nil } @@ -139,10 +139,6 @@ func (r *Reactor) OnStop() { // p2p Channels should execute Close(). close(r.closeCh) - // Wait for all p2p Channels to be closed before returning. This ensures we - // can easily reason about synchronization of all p2p Channels and ensure no - // panics will occur. - <-r.mempoolCh.Done() <-r.peerUpdates.Done() } @@ -209,9 +205,7 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err // processMempoolCh implements a blocking event loop where we listen for p2p // Envelope messages from the mempoolCh. -func (r *Reactor) processMempoolCh() { - defer r.mempoolCh.Close() - +func (r *Reactor) processMempoolCh(ctx context.Context) { for { select { case envelope := <-r.mempoolCh.In: @@ -222,7 +216,8 @@ func (r *Reactor) processMempoolCh() { Err: err, } } - + case <-ctx.Done(): + return case <-r.closeCh: r.Logger.Debug("stopped listening on mempool channel; closing...") return @@ -235,7 +230,7 @@ func (r *Reactor) processMempoolCh() { // goroutine or not. If not, we start one for the newly added peer. For down or // removed peers, we remove the peer from the mempool peer ID set and signal to // stop the tx broadcasting goroutine. -func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { +func (r *Reactor) processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpdate) { r.Logger.Debug("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) r.mtx.Lock() @@ -266,7 +261,7 @@ func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { r.ids.ReserveForPeer(peerUpdate.NodeID) // start a broadcast routine ensuring all txs are forwarded to the peer - go r.broadcastTxRoutine(peerUpdate.NodeID, closer) + go r.broadcastTxRoutine(ctx, peerUpdate.NodeID, closer) } } @@ -287,13 +282,15 @@ func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { // processPeerUpdates initiates a blocking process where we listen for and handle // PeerUpdate messages. When the reactor is stopped, we will catch the signal and // close the p2p PeerUpdatesCh gracefully. 
-func (r *Reactor) processPeerUpdates() { +func (r *Reactor) processPeerUpdates(ctx context.Context) { defer r.peerUpdates.Close() for { select { + case <-ctx.Done(): + return case peerUpdate := <-r.peerUpdates.Updates(): - r.processPeerUpdate(peerUpdate) + r.processPeerUpdate(ctx, peerUpdate) case <-r.closeCh: r.Logger.Debug("stopped listening on peer updates channel; closing...") @@ -302,7 +299,7 @@ func (r *Reactor) processPeerUpdates() { } } -func (r *Reactor) broadcastTxRoutine(peerID types.NodeID, closer *tmsync.Closer) { +func (r *Reactor) broadcastTxRoutine(ctx context.Context, peerID types.NodeID, closer *tmsync.Closer) { peerMempoolID := r.ids.GetForPeer(peerID) var nextGossipTx *clist.CElement @@ -325,7 +322,7 @@ func (r *Reactor) broadcastTxRoutine(peerID types.NodeID, closer *tmsync.Closer) }() for { - if !r.IsRunning() { + if !r.IsRunning() || ctx.Err() != nil { return } @@ -344,6 +341,9 @@ func (r *Reactor) broadcastTxRoutine(peerID types.NodeID, closer *tmsync.Closer) // explicitly closed to signal we should exit. return + case <-ctx.Done(): + return + case <-r.closeCh: // The reactor has signaled that we are stopped and thus we should // implicitly exit this peer's goroutine. @@ -367,11 +367,14 @@ func (r *Reactor) broadcastTxRoutine(peerID types.NodeID, closer *tmsync.Closer) if ok := r.mempool.txStore.TxHasPeer(memTx.hash, peerMempoolID); !ok { // Send the mempool tx to the corresponding peer. Note, the peer may be // behind and thus would not be able to process the mempool tx correctly. - r.mempoolCh.Out <- p2p.Envelope{ + select { + case r.mempoolCh.Out <- p2p.Envelope{ To: peerID, Message: &protomem.Txs{ Txs: [][]byte{memTx.tx}, }, + }: + case <-ctx.Done(): } r.Logger.Debug( "gossiped tx to peer", @@ -389,6 +392,9 @@ func (r *Reactor) broadcastTxRoutine(peerID types.NodeID, closer *tmsync.Closer) // explicitly closed to signal we should exit. return + case <-ctx.Done(): + return + case <-r.closeCh: // The reactor has signaled that we are stopped and thus we should // implicitly exit this peer's goroutine. 
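Editor's note: broadcastTxRoutine above wraps every send on the outbound mempool channel in a select that also watches ctx.Done(), so shutdown cannot leave the routine blocked on a full channel. A self-contained sketch of that send loop with stand-in types instead of the p2p ones:

package main

import (
	"context"
	"fmt"
)

type envelope struct {
	to  string
	msg string
}

// gossip forwards txs to a peer, but abandons the remaining sends as soon as
// the context is canceled.
func gossip(ctx context.Context, out chan<- envelope, txs []string, peer string) {
	for _, tx := range txs {
		select {
		case out <- envelope{to: peer, msg: tx}:
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	out := make(chan envelope) // unbuffered: sends block until consumed

	go gossip(ctx, out, []string{"tx1", "tx2", "tx3"}, "peer-a")

	fmt.Println(<-out) // consume one envelope
	cancel()           // then cancel: the goroutine exits instead of blocking on tx2/tx3
}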
diff --git a/internal/mempool/reactor_test.go b/internal/mempool/reactor_test.go index 62cacfd10..86a3b4db4 100644 --- a/internal/mempool/reactor_test.go +++ b/internal/mempool/reactor_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/fortytw2/leaktest" "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/abci/example/kvstore" abci "github.com/tendermint/tendermint/abci/types" @@ -55,7 +56,7 @@ func setupReactors(ctx context.Context, t *testing.T, numNodes int, chBuf uint) } chDesc := GetChannelDescriptor(cfg.Mempool) - rts.mempoolChannels = rts.network.MakeChannelsNoCleanup(t, chDesc) + rts.mempoolChannels = rts.network.MakeChannelsNoCleanup(ctx, t, chDesc) for nodeID := range rts.network.Nodes { rts.kvstores[nodeID] = kvstore.NewApplication() @@ -65,7 +66,7 @@ func setupReactors(ctx context.Context, t *testing.T, numNodes int, chBuf uint) rts.peerChans[nodeID] = make(chan p2p.PeerUpdate) rts.peerUpdates[nodeID] = p2p.NewPeerUpdates(rts.peerChans[nodeID], 1) - rts.network.Nodes[nodeID].PeerManager.Register(rts.peerUpdates[nodeID]) + rts.network.Nodes[nodeID].PeerManager.Register(ctx, rts.peerUpdates[nodeID]) rts.reactors[nodeID] = NewReactor( rts.logger.With("nodeID", nodeID), @@ -93,12 +94,14 @@ func setupReactors(ctx context.Context, t *testing.T, numNodes int, chBuf uint) } }) + t.Cleanup(leaktest.Check(t)) + return rts } -func (rts *reactorTestSuite) start(t *testing.T) { +func (rts *reactorTestSuite) start(ctx context.Context, t *testing.T) { t.Helper() - rts.network.Start(t) + rts.network.Start(ctx, t) require.Len(t, rts.network.RandomNode().PeerManager.Peers(), len(rts.nodes)-1, @@ -108,21 +111,11 @@ func (rts *reactorTestSuite) start(t *testing.T) { func (rts *reactorTestSuite) assertMempoolChannelsDrained(t *testing.T) { t.Helper() - rts.stop(t) - for _, mch := range rts.mempoolChannels { require.Empty(t, mch.Out, "checking channel %q (len=%d)", mch.ID, len(mch.Out)) } } -func (rts *reactorTestSuite) stop(t *testing.T) { - for id, r := range rts.reactors { - require.NoError(t, r.Stop(), "stopping reactor %s", id) - r.Wait() - require.False(t, r.IsRunning(), "reactor %s did not stop", id) - } -} - func (rts *reactorTestSuite) waitForTxns(t *testing.T, txs []types.Tx, ids ...types.NodeID) { t.Helper() @@ -170,11 +163,11 @@ func TestReactorBroadcastDoesNotPanic(t *testing.T) { primaryMempool.insertTx(firstTx) // run the router - rts.start(t) + rts.start(ctx, t) closer := tmsync.NewCloser() primaryReactor.peerWG.Add(1) - go primaryReactor.broadcastTxRoutine(secondary, closer) + go primaryReactor.broadcastTxRoutine(ctx, secondary, closer) wg := &sync.WaitGroup{} for i := 0; i < 50; i++ { @@ -206,13 +199,11 @@ func TestReactorBroadcastTxs(t *testing.T) { txs := checkTxs(ctx, t, rts.reactors[primary].mempool, numTxs, UnknownPeerID) // run the router - rts.start(t) + rts.start(ctx, t) // Wait till all secondary suites (reactor) received all mempool txs from the // primary suite (node). rts.waitForTxns(t, convertTex(txs), secondaries...) 
- - rts.stop(t) } // regression test for https://github.com/tendermint/tendermint/issues/5408 @@ -228,7 +219,7 @@ func TestReactorConcurrency(t *testing.T) { primary := rts.nodes[0] secondary := rts.nodes[1] - rts.start(t) + rts.start(ctx, t) var wg sync.WaitGroup @@ -292,7 +283,7 @@ func TestReactorNoBroadcastToSender(t *testing.T) { peerID := uint16(1) _ = checkTxs(ctx, t, rts.mempools[primary], numTxs, peerID) - rts.start(t) + rts.start(ctx, t) time.Sleep(100 * time.Millisecond) @@ -328,7 +319,7 @@ func TestReactor_MaxTxBytes(t *testing.T) { ) require.NoError(t, err) - rts.start(t) + rts.start(ctx, t) rts.reactors[primary].mempool.Flush() rts.reactors[secondary].mempool.Flush() @@ -421,7 +412,7 @@ func TestBroadcastTxForPeerStopsWhenPeerStops(t *testing.T) { primary := rts.nodes[0] secondary := rts.nodes[1] - rts.start(t) + rts.start(ctx, t) // disconnect peer rts.peerChans[primary] <- p2p.PeerUpdate{ diff --git a/internal/p2p/conn/connection.go b/internal/p2p/conn/connection.go index 94f248a8c..67d7a42f0 100644 --- a/internal/p2p/conn/connection.go +++ b/internal/p2p/conn/connection.go @@ -221,8 +221,8 @@ func (c *MConnection) OnStart(ctx context.Context) error { c.quitSendRoutine = make(chan struct{}) c.doneSendRoutine = make(chan struct{}) c.quitRecvRoutine = make(chan struct{}) - go c.sendRoutine() - go c.recvRoutine() + go c.sendRoutine(ctx) + go c.recvRoutine(ctx) return nil } @@ -332,7 +332,7 @@ func (c *MConnection) Send(chID ChannelID, msgBytes []byte) bool { } // sendRoutine polls for packets to send from channels. -func (c *MConnection) sendRoutine() { +func (c *MConnection) sendRoutine(ctx context.Context) { defer c._recover() protoWriter := protoio.NewDelimitedWriter(c.bufConnWriter) @@ -382,6 +382,8 @@ FOR_LOOP: } c.sendMonitor.Update(_n) c.flush() + case <-ctx.Done(): + break FOR_LOOP case <-c.quitSendRoutine: break FOR_LOOP case <-c.send: @@ -469,7 +471,7 @@ func (c *MConnection) sendPacketMsg() bool { // After a whole message has been assembled, it's pushed to onReceive(). // Blocks depending on how the connection is throttled. // Otherwise, it never blocks. -func (c *MConnection) recvRoutine() { +func (c *MConnection) recvRoutine(ctx context.Context) { defer c._recover() protoReader := protoio.NewDelimitedReader(c.bufConnReader, c._maxPacketMsgSize) @@ -502,6 +504,7 @@ FOR_LOOP: // stopServices was invoked and we are shutting down // receiving is excpected to fail since we will close the connection select { + case <-ctx.Done(): case <-c.quitRecvRoutine: break FOR_LOOP default: diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index 6ee253b3c..6fc5d7c11 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -67,14 +67,14 @@ func MakeNetwork(ctx context.Context, t *testing.T, opts NetworkOptions) *Networ // Start starts the network by setting up a list of node addresses to dial in // addition to creating a peer update subscription for each node. Finally, all // nodes are connected to each other. -func (n *Network) Start(t *testing.T) { +func (n *Network) Start(ctx context.Context, t *testing.T) { // Set up a list of node addresses to dial, and a peer update subscription // for each node. 
dialQueue := []p2p.NodeAddress{} subs := map[types.NodeID]*p2p.PeerUpdates{} for _, node := range n.Nodes { dialQueue = append(dialQueue, node.NodeAddress) - subs[node.NodeID] = node.PeerManager.Subscribe() + subs[node.NodeID] = node.PeerManager.Subscribe(ctx) defer subs[node.NodeID].Close() } @@ -93,6 +93,8 @@ func (n *Network) Start(t *testing.T) { require.True(t, added) select { + case <-ctx.Done(): + require.Fail(t, "operation canceled") case peerUpdate := <-sourceSub.Updates(): require.Equal(t, p2p.PeerUpdate{ NodeID: targetNode.NodeID, @@ -104,6 +106,8 @@ func (n *Network) Start(t *testing.T) { } select { + case <-ctx.Done(): + require.Fail(t, "operation canceled") case peerUpdate := <-targetSub.Updates(): require.Equal(t, p2p.PeerUpdate{ NodeID: sourceNode.NodeID, @@ -135,12 +139,13 @@ func (n *Network) NodeIDs() []types.NodeID { // MakeChannels makes a channel on all nodes and returns them, automatically // doing error checks and cleanups. func (n *Network) MakeChannels( + ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) map[types.NodeID]*p2p.Channel { channels := map[types.NodeID]*p2p.Channel{} for _, node := range n.Nodes { - channels[node.NodeID] = node.MakeChannel(t, chDesc) + channels[node.NodeID] = node.MakeChannel(ctx, t, chDesc) } return channels } @@ -149,12 +154,13 @@ func (n *Network) MakeChannels( // automatically doing error checks. The caller must ensure proper cleanup of // all the channels. func (n *Network) MakeChannelsNoCleanup( + ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) map[types.NodeID]*p2p.Channel { channels := map[types.NodeID]*p2p.Channel{} for _, node := range n.Nodes { - channels[node.NodeID] = node.MakeChannelNoCleanup(t, chDesc) + channels[node.NodeID] = node.MakeChannelNoCleanup(ctx, t, chDesc) } return channels } @@ -181,14 +187,14 @@ func (n *Network) Peers(id types.NodeID) []*Node { // Remove removes a node from the network, stopping it and waiting for all other // nodes to pick up the disconnection. -func (n *Network) Remove(t *testing.T, id types.NodeID) { +func (n *Network) Remove(ctx context.Context, t *testing.T, id types.NodeID) { require.Contains(t, n.Nodes, id) node := n.Nodes[id] delete(n.Nodes, id) subs := []*p2p.PeerUpdates{} for _, peer := range n.Nodes { - sub := peer.PeerManager.Subscribe() + sub := peer.PeerManager.Subscribe(ctx) defer sub.Close() subs = append(subs, sub) } @@ -243,6 +249,7 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) require.NoError(t, err) router, err := p2p.NewRouter( + ctx, n.logger, p2p.NopMetrics(), nodeInfo, @@ -279,15 +286,17 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) // test cleanup, it also checks that the channel is empty, to make sure // all expected messages have been asserted. func (n *Node) MakeChannel( + ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) *p2p.Channel { - channel, err := n.Router.OpenChannel(chDesc) + ctx, cancel := context.WithCancel(ctx) + channel, err := n.Router.OpenChannel(ctx, chDesc) require.NoError(t, err) require.Contains(t, n.Router.NodeInfo().Channels, byte(chDesc.ID)) t.Cleanup(func() { RequireEmpty(t, channel) - channel.Close() + cancel() }) return channel } @@ -295,10 +304,11 @@ func (n *Node) MakeChannel( // MakeChannelNoCleanup opens a channel, with automatic error handling. The // caller must ensure proper cleanup of the channel. 
func (n *Node) MakeChannelNoCleanup( + ctx context.Context, t *testing.T, chDesc *p2p.ChannelDescriptor, ) *p2p.Channel { - channel, err := n.Router.OpenChannel(chDesc) + channel, err := n.Router.OpenChannel(ctx, chDesc) require.NoError(t, err) return channel } @@ -307,9 +317,8 @@ func (n *Node) MakeChannelNoCleanup( // It checks that all updates have been consumed during cleanup. func (n *Node) MakePeerUpdates(ctx context.Context, t *testing.T) *p2p.PeerUpdates { t.Helper() - sub := n.PeerManager.Subscribe() + sub := n.PeerManager.Subscribe(ctx) t.Cleanup(func() { - t.Helper() RequireNoUpdates(ctx, t, sub) sub.Close() }) @@ -320,11 +329,10 @@ func (n *Node) MakePeerUpdates(ctx context.Context, t *testing.T) *p2p.PeerUpdat // MakePeerUpdatesNoRequireEmpty opens a peer update subscription, with automatic cleanup. // It does *not* check that all updates have been consumed, but will // close the update channel. -func (n *Node) MakePeerUpdatesNoRequireEmpty(t *testing.T) *p2p.PeerUpdates { - sub := n.PeerManager.Subscribe() - t.Cleanup(func() { - sub.Close() - }) +func (n *Node) MakePeerUpdatesNoRequireEmpty(ctx context.Context, t *testing.T) *p2p.PeerUpdates { + sub := n.PeerManager.Subscribe(ctx) + + t.Cleanup(sub.Close) return sub } diff --git a/internal/p2p/p2ptest/require.go b/internal/p2p/p2ptest/require.go index 106063bbd..3a7731829 100644 --- a/internal/p2p/p2ptest/require.go +++ b/internal/p2p/p2ptest/require.go @@ -31,13 +31,8 @@ func RequireReceive(t *testing.T, channel *p2p.Channel, expect p2p.Envelope) { defer timer.Stop() select { - case e, ok := <-channel.In: - require.True(t, ok, "channel %v is closed", channel.ID) + case e := <-channel.In: require.Equal(t, expect, e) - - case <-channel.Done(): - require.Fail(t, "channel %v is closed", channel.ID) - case <-timer.C: require.Fail(t, "timed out waiting for message", "%v on channel %v", expect, channel.ID) } @@ -52,17 +47,13 @@ func RequireReceiveUnordered(t *testing.T, channel *p2p.Channel, expect []p2p.En actual := []p2p.Envelope{} for { select { - case e, ok := <-channel.In: - require.True(t, ok, "channel %v is closed", channel.ID) + case e := <-channel.In: actual = append(actual, e) if len(actual) == len(expect) { require.ElementsMatch(t, expect, actual) return } - case <-channel.Done(): - require.Fail(t, "channel %v is closed", channel.ID) - case <-timer.C: require.ElementsMatch(t, expect, actual) return diff --git a/internal/p2p/peermanager.go b/internal/p2p/peermanager.go index 8c37cc1ff..0ab0128ca 100644 --- a/internal/p2p/peermanager.go +++ b/internal/p2p/peermanager.go @@ -513,7 +513,7 @@ func (m *PeerManager) TryDialNext() (NodeAddress, error) { // for dialing again when appropriate (possibly after a retry timeout). // // FIXME: This should probably delete or mark bad addresses/peers after some time. -func (m *PeerManager) DialFailed(address NodeAddress) error { +func (m *PeerManager) DialFailed(ctx context.Context, address NodeAddress) error { m.mtx.Lock() defer m.mtx.Unlock() @@ -553,6 +553,7 @@ func (m *PeerManager) DialFailed(address NodeAddress) error { case <-timer.C: m.dialWaker.Wake() case <-m.closeCh: + case <-ctx.Done(): } }() } else { @@ -835,7 +836,7 @@ func (m *PeerManager) Advertise(peerID types.NodeID, limit uint16) []NodeAddress // Subscribe subscribes to peer updates. The caller must consume the peer // updates in a timely fashion and close the subscription when done, otherwise // the PeerManager will halt. 
-func (m *PeerManager) Subscribe() *PeerUpdates { +func (m *PeerManager) Subscribe(ctx context.Context) *PeerUpdates { // FIXME: We use a size 1 buffer here. When we broadcast a peer update // we have to loop over all of the subscriptions, and we want to avoid // having to block and wait for a context switch before continuing on @@ -843,7 +844,7 @@ func (m *PeerManager) Subscribe() *PeerUpdates { // compounding. Limiting it to 1 means that the subscribers are still // reasonably in sync. However, this should probably be benchmarked. peerUpdates := NewPeerUpdates(make(chan PeerUpdate, 1), 1) - m.Register(peerUpdates) + m.Register(ctx, peerUpdates) return peerUpdates } @@ -855,7 +856,7 @@ func (m *PeerManager) Subscribe() *PeerUpdates { // The caller must consume the peer updates from this PeerUpdates // instance in a timely fashion and close the subscription when done, // otherwise the PeerManager will halt. -func (m *PeerManager) Register(peerUpdates *PeerUpdates) { +func (m *PeerManager) Register(ctx context.Context, peerUpdates *PeerUpdates) { m.mtx.Lock() m.subscriptions[peerUpdates] = peerUpdates m.mtx.Unlock() @@ -867,6 +868,8 @@ func (m *PeerManager) Register(peerUpdates *PeerUpdates) { return case <-m.closeCh: return + case <-ctx.Done(): + return case pu := <-peerUpdates.routerUpdatesCh: m.processPeerEvent(pu) } @@ -880,6 +883,7 @@ func (m *PeerManager) Register(peerUpdates *PeerUpdates) { delete(m.subscriptions, peerUpdates) m.mtx.Unlock() case <-m.closeCh: + case <-ctx.Done(): } }() } diff --git a/internal/p2p/peermanager_scoring_test.go b/internal/p2p/peermanager_scoring_test.go index edb5fc6fc..fe23767c4 100644 --- a/internal/p2p/peermanager_scoring_test.go +++ b/internal/p2p/peermanager_scoring_test.go @@ -1,6 +1,7 @@ package p2p import ( + "context" "strings" "testing" "time" @@ -29,6 +30,9 @@ func TestPeerScoring(t *testing.T) { require.NoError(t, err) require.True(t, added) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + t.Run("Synchronous", func(t *testing.T) { // update the manager and make sure it's correct require.EqualValues(t, 0, peerManager.Scores()[id]) @@ -53,7 +57,7 @@ func TestPeerScoring(t *testing.T) { }) t.Run("AsynchronousIncrement", func(t *testing.T) { start := peerManager.Scores()[id] - pu := peerManager.Subscribe() + pu := peerManager.Subscribe(ctx) defer pu.Close() pu.SendUpdate(PeerUpdate{ NodeID: id, @@ -67,7 +71,7 @@ func TestPeerScoring(t *testing.T) { }) t.Run("AsynchronousDecrement", func(t *testing.T) { start := peerManager.Scores()[id] - pu := peerManager.Subscribe() + pu := peerManager.Subscribe(ctx) defer pu.Close() pu.SendUpdate(PeerUpdate{ NodeID: id, diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index 28efe63dd..cf1b0707e 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -343,7 +343,7 @@ func TestPeerManager_DialNext_Retry(t *testing.T) { require.Fail(t, "unexpected retry") } - require.NoError(t, peerManager.DialFailed(a)) + require.NoError(t, peerManager.DialFailed(ctx, a)) } } @@ -401,17 +401,20 @@ func TestPeerManager_DialNext_WakeOnDialFailed(t *testing.T) { require.Zero(t, dial) // Spawn a goroutine to fail a's dial attempt. + sig := make(chan struct{}) go func() { + defer close(sig) time.Sleep(200 * time.Millisecond) - require.NoError(t, peerManager.DialFailed(a)) + require.NoError(t, peerManager.DialFailed(ctx, a)) }() // This should make b available for dialing (not a, retries are disabled). 
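With Subscribe and Register now taking a context, a subscription's goroutine is released when the caller's context is canceled rather than only on PeerManager.Close. A sketch of the caller-side pattern, assuming it lives inside the tendermint module (the p2p package is internal); the consumer function is illustrative:

```go
package example

import (
	"context"

	"github.com/tendermint/tendermint/internal/p2p"
)

// consumePeerUpdates subscribes with the caller's context; canceling ctx both
// ends this loop and releases the goroutine Register spawns for the subscription.
func consumePeerUpdates(ctx context.Context, peerManager *p2p.PeerManager) {
	sub := peerManager.Subscribe(ctx)
	defer sub.Close()

	for {
		select {
		case <-ctx.Done():
			return
		case pu := <-sub.Updates():
			_ = pu // handle the PeerUpdate
		}
	}
}
```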
- ctx, cancel = context.WithTimeout(ctx, 3*time.Second) - defer cancel() - dial, err = peerManager.DialNext(ctx) + opctx, opcancel := context.WithTimeout(ctx, 3*time.Second) + defer opcancel() + dial, err = peerManager.DialNext(opctx) require.NoError(t, err) require.Equal(t, b, dial) + <-sig } func TestPeerManager_DialNext_WakeOnDialFailedRetry(t *testing.T) { @@ -431,7 +434,7 @@ func TestPeerManager_DialNext_WakeOnDialFailedRetry(t *testing.T) { dial, err := peerManager.TryDialNext() require.NoError(t, err) require.Equal(t, a, dial) - require.NoError(t, peerManager.DialFailed(dial)) + require.NoError(t, peerManager.DialFailed(ctx, dial)) failed := time.Now() // The retry timer should unblock DialNext and make a available again after @@ -680,6 +683,9 @@ func TestPeerManager_TryDialNext_DialingConnected(t *testing.T) { } func TestPeerManager_TryDialNext_Multiple(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + aID := types.NodeID(strings.Repeat("a", 40)) bID := types.NodeID(strings.Repeat("b", 40)) addresses := []p2p.NodeAddress{ @@ -704,7 +710,7 @@ func TestPeerManager_TryDialNext_Multiple(t *testing.T) { address, err := peerManager.TryDialNext() require.NoError(t, err) require.NotZero(t, address) - require.NoError(t, peerManager.DialFailed(address)) + require.NoError(t, peerManager.DialFailed(ctx, address)) dial = append(dial, address) } require.ElementsMatch(t, dial, addresses) @@ -729,13 +735,16 @@ func TestPeerManager_DialFailed(t *testing.T) { require.NoError(t, err) require.True(t, added) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Dialing and then calling DialFailed with a different address (same // NodeID) should unmark as dialing and allow us to dial the other address // again, but not register the failed address. dial, err := peerManager.TryDialNext() require.NoError(t, err) require.Equal(t, a, dial) - require.NoError(t, peerManager.DialFailed(p2p.NodeAddress{ + require.NoError(t, peerManager.DialFailed(ctx, p2p.NodeAddress{ Protocol: "tcp", NodeID: aID, Hostname: "localhost"})) require.Equal(t, []p2p.NodeAddress{a}, peerManager.Addresses(aID)) @@ -744,15 +753,18 @@ func TestPeerManager_DialFailed(t *testing.T) { require.Equal(t, a, dial) // Calling DialFailed on same address twice should be fine. - require.NoError(t, peerManager.DialFailed(a)) - require.NoError(t, peerManager.DialFailed(a)) + require.NoError(t, peerManager.DialFailed(ctx, a)) + require.NoError(t, peerManager.DialFailed(ctx, a)) // DialFailed on an unknown peer shouldn't error or add it. - require.NoError(t, peerManager.DialFailed(b)) + require.NoError(t, peerManager.DialFailed(ctx, b)) require.Equal(t, []types.NodeID{aID}, peerManager.Peers()) } func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} c := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("c", 40))} @@ -792,7 +804,7 @@ func TestPeerManager_DialFailed_UnreservePeer(t *testing.T) { require.Empty(t, dial) // Failing b's dial will now make c available for dialing. 
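The WakeOnDialFailed rewrite above derives an operation-scoped timeout instead of reassigning the test context, and closes a sig channel so the helper goroutine is known to have finished before the test returns. The same shape in isolation; failDial and dialNext are stand-ins for the peer-manager calls:

```go
package example

import (
	"context"
	"time"
)

// dialNextSketch: a helper goroutine fails the dial after a delay, the main
// flow waits on an operation-scoped timeout derived from ctx, and <-sig
// guarantees the goroutine has exited before we return.
func dialNextSketch(ctx context.Context, failDial, dialNext func(context.Context) error) error {
	sig := make(chan struct{})
	go func() {
		defer close(sig)
		time.Sleep(200 * time.Millisecond)
		_ = failDial(ctx)
	}()

	opctx, opcancel := context.WithTimeout(ctx, 3*time.Second)
	defer opcancel()
	err := dialNext(opctx)

	<-sig
	return err
}
```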
- require.NoError(t, peerManager.DialFailed(b)) + require.NoError(t, peerManager.DialFailed(ctx, b)) dial, err = peerManager.TryDialNext() require.NoError(t, err) require.Equal(t, c, dial) @@ -1274,10 +1286,13 @@ func TestPeerManager_Ready(t *testing.T) { a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - sub := peerManager.Subscribe() + sub := peerManager.Subscribe(ctx) defer sub.Close() // Connecting to a should still have it as status down. @@ -1488,7 +1503,10 @@ func TestPeerManager_Disconnected(t *testing.T) { peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - sub := peerManager.Subscribe() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sub := peerManager.Subscribe(ctx) defer sub.Close() // Disconnecting an unknown peer does nothing. @@ -1573,13 +1591,16 @@ func TestPeerManager_Errored(t *testing.T) { } func TestPeerManager_Subscribe(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) // This tests all subscription events for full peer lifecycles. - sub := peerManager.Subscribe() + sub := peerManager.Subscribe(ctx) defer sub.Close() added, err := peerManager.Add(a) @@ -1629,17 +1650,20 @@ func TestPeerManager_Subscribe(t *testing.T) { require.Equal(t, a, dial) require.Empty(t, sub.Updates()) - require.NoError(t, peerManager.DialFailed(a)) + require.NoError(t, peerManager.DialFailed(ctx, a)) require.Empty(t, sub.Updates()) } func TestPeerManager_Subscribe_Close(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - sub := peerManager.Subscribe() + sub := peerManager.Subscribe(ctx) defer sub.Close() added, err := peerManager.Add(a) @@ -1659,6 +1683,9 @@ func TestPeerManager_Subscribe_Close(t *testing.T) { } func TestPeerManager_Subscribe_Broadcast(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + t.Cleanup(leaktest.Check(t)) a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} @@ -1666,11 +1693,11 @@ func TestPeerManager_Subscribe_Broadcast(t *testing.T) { peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) - s1 := peerManager.Subscribe() + s1 := peerManager.Subscribe(ctx) defer s1.Close() - s2 := peerManager.Subscribe() + s2 := peerManager.Subscribe(ctx) defer s2.Close() - s3 := peerManager.Subscribe() + s3 := peerManager.Subscribe(ctx) defer s3.Close() // Connecting to a peer should send updates on all subscriptions. @@ -1705,6 +1732,9 @@ func TestPeerManager_Close(t *testing.T) { // leaktest will check that spawned goroutines are closed. 
t.Cleanup(leaktest.CheckTimeout(t, 1*time.Second)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{ @@ -1714,7 +1744,7 @@ func TestPeerManager_Close(t *testing.T) { // This subscription isn't closed, but PeerManager.Close() // should reap the spawned goroutine. - _ = peerManager.Subscribe() + _ = peerManager.Subscribe(ctx) // This dial failure will start a retry timer for 10 seconds, which // should be reaped. @@ -1724,7 +1754,7 @@ func TestPeerManager_Close(t *testing.T) { dial, err := peerManager.TryDialNext() require.NoError(t, err) require.Equal(t, a, dial) - require.NoError(t, peerManager.DialFailed(a)) + require.NoError(t, peerManager.DialFailed(ctx, a)) // This should clean up the goroutines. peerManager.Close() diff --git a/internal/p2p/pex/reactor.go b/internal/p2p/pex/reactor.go index 53cd079bc..609d86e4c 100644 --- a/internal/p2p/pex/reactor.go +++ b/internal/p2p/pex/reactor.go @@ -153,17 +153,12 @@ func (r *Reactor) OnStop() { // p2p Channels should execute Close(). close(r.closeCh) - // Wait for all p2p Channels to be closed before returning. This ensures we - // can easily reason about synchronization of all p2p Channels and ensure no - // panics will occur. - <-r.pexCh.Done() <-r.peerUpdates.Done() } // processPexCh implements a blocking event loop where we listen for p2p // Envelope messages from the pexCh. func (r *Reactor) processPexCh(ctx context.Context) { - defer r.pexCh.Close() timer := time.NewTimer(0) defer timer.Stop() for { diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index 63b182fc0..5a061d76d 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -37,7 +37,7 @@ func TestReactorBasic(t *testing.T) { MockNodes: 1, TotalNodes: 2, }) - testNet.connectAll(t) + testNet.connectAll(ctx, t) testNet.start(ctx, t) // assert that the mock node receives a request from the real node @@ -59,7 +59,7 @@ func TestReactorConnectFullNetwork(t *testing.T) { // make every node be only connected with one other node (it actually ends up // being two because of two way connections but oh well) - testNet.connectN(t, 1) + testNet.connectN(ctx, t, 1) testNet.start(ctx, t) // assert that all nodes add each other in the network @@ -106,7 +106,7 @@ func TestReactorSendsResponseWithoutRequest(t *testing.T) { MockNodes: 1, TotalNodes: 3, }) - testNet.connectAll(t) + testNet.connectAll(ctx, t) testNet.start(ctx, t) // firstNode sends the secondNode an unrequested response @@ -116,7 +116,7 @@ func TestReactorSendsResponseWithoutRequest(t *testing.T) { testNet.sendResponse(t, firstNode, secondNode, []int{thirdNode}) // secondNode should evict the firstNode - testNet.listenForPeerUpdate(t, secondNode, firstNode, p2p.PeerStatusDown, shortWait) + testNet.listenForPeerUpdate(ctx, t, secondNode, firstNode, p2p.PeerStatusDown, shortWait) } func TestReactorNeverSendsTooManyPeers(t *testing.T) { @@ -127,7 +127,7 @@ func TestReactorNeverSendsTooManyPeers(t *testing.T) { MockNodes: 1, TotalNodes: 2, }) - testNet.connectAll(t) + testNet.connectAll(ctx, t) testNet.start(ctx, t) testNet.addNodes(ctx, t, 110) @@ -199,7 +199,7 @@ func TestReactorSmallPeerStoreInALargeNetwork(t *testing.T) { MaxConnected: 3, BufferSize: 8, }) - testNet.connectN(t, 1) + testNet.connectN(ctx, t, 1) testNet.start(ctx, t) // test that all nodes reach full capacity @@ 
-221,7 +221,7 @@ func TestReactorLargePeerStoreInASmallNetwork(t *testing.T) { MaxConnected: 25, BufferSize: 5, }) - testNet.connectN(t, 1) + testNet.connectN(ctx, t, 1) testNet.start(ctx, t) // assert that all nodes add each other in the network @@ -238,7 +238,7 @@ func TestReactorWithNetworkGrowth(t *testing.T) { TotalNodes: 5, BufferSize: 5, }) - testNet.connectAll(t) + testNet.connectAll(ctx, t) testNet.start(ctx, t) // assert that all nodes add each other in the network @@ -254,7 +254,7 @@ func TestReactorWithNetworkGrowth(t *testing.T) { require.True(t, testNet.reactors[node].IsRunning()) // we connect all new nodes to a single entry point and check that the // node can distribute the addresses to all the others - testNet.connectPeers(t, 0, i) + testNet.connectPeers(ctx, t, 0, i) } require.Len(t, testNet.reactors, 15) @@ -297,7 +297,6 @@ func setupSingle(ctx context.Context, t *testing.T) *singleTestReactor { reactor := pex.NewReactor(log.TestingLogger(), peerManager, pexCh, peerUpdates) require.NoError(t, reactor.Start(ctx)) t.Cleanup(func() { - pexCh.Close() peerUpdates.Close() reactor.Wait() }) @@ -370,13 +369,13 @@ func setupNetwork(ctx context.Context, t *testing.T, opts testOptions) *reactorT // NOTE: we don't assert that the channels get drained after stopping the // reactor - rts.pexChannels = rts.network.MakeChannelsNoCleanup(t, pex.ChannelDescriptor()) + rts.pexChannels = rts.network.MakeChannelsNoCleanup(ctx, t, pex.ChannelDescriptor()) idx := 0 for nodeID := range rts.network.Nodes { rts.peerChans[nodeID] = make(chan p2p.PeerUpdate, chBuf) rts.peerUpdates[nodeID] = p2p.NewPeerUpdates(rts.peerChans[nodeID], chBuf) - rts.network.Nodes[nodeID].PeerManager.Register(rts.peerUpdates[nodeID]) + rts.network.Nodes[nodeID].PeerManager.Register(ctx, rts.peerUpdates[nodeID]) // the first nodes in the array are always mock nodes if idx < opts.MockNodes { @@ -402,11 +401,9 @@ func setupNetwork(ctx context.Context, t *testing.T, opts testOptions) *reactorT reactor.Wait() require.False(t, reactor.IsRunning()) } - rts.pexChannels[nodeID].Close() rts.peerUpdates[nodeID].Close() } for _, nodeID := range rts.mocks { - rts.pexChannels[nodeID].Close() rts.peerUpdates[nodeID].Close() } }) @@ -434,10 +431,10 @@ func (r *reactorTestSuite) addNodes(ctx context.Context, t *testing.T, nodes int }) r.network.Nodes[node.NodeID] = node nodeID := node.NodeID - r.pexChannels[nodeID] = node.MakeChannelNoCleanup(t, pex.ChannelDescriptor()) + r.pexChannels[nodeID] = node.MakeChannelNoCleanup(ctx, t, pex.ChannelDescriptor()) r.peerChans[nodeID] = make(chan p2p.PeerUpdate, r.opts.BufferSize) r.peerUpdates[nodeID] = p2p.NewPeerUpdates(r.peerChans[nodeID], r.opts.BufferSize) - r.network.Nodes[nodeID].PeerManager.Register(r.peerUpdates[nodeID]) + r.network.Nodes[nodeID].PeerManager.Register(ctx, r.peerUpdates[nodeID]) r.reactors[nodeID] = pex.NewReactor( r.logger.With("nodeID", nodeID), r.network.Nodes[nodeID].PeerManager, @@ -537,17 +534,21 @@ func (r *reactorTestSuite) listenForResponse( } func (r *reactorTestSuite) listenForPeerUpdate( + ctx context.Context, t *testing.T, onNode, withNode int, status p2p.PeerStatus, waitPeriod time.Duration, ) { on, with := r.checkNodePair(t, onNode, withNode) - sub := r.network.Nodes[on].PeerManager.Subscribe() + sub := r.network.Nodes[on].PeerManager.Subscribe(ctx) defer sub.Close() timesUp := time.After(waitPeriod) for { select { + case <-ctx.Done(): + require.Fail(t, "operation canceled") + return case peerUpdate := <-sub.Updates(): if peerUpdate.NodeID == with { 
require.Equal(t, status, peerUpdate.Status) @@ -612,25 +613,25 @@ func (r *reactorTestSuite) requireNumberOfPeers( ) } -func (r *reactorTestSuite) connectAll(t *testing.T) { - r.connectN(t, r.total-1) +func (r *reactorTestSuite) connectAll(ctx context.Context, t *testing.T) { + r.connectN(ctx, t, r.total-1) } // connects all nodes with n other nodes -func (r *reactorTestSuite) connectN(t *testing.T, n int) { +func (r *reactorTestSuite) connectN(ctx context.Context, t *testing.T, n int) { if n >= r.total { require.Fail(t, "connectN: n must be less than the size of the network - 1") } for i := 0; i < r.total; i++ { for j := 0; j < n; j++ { - r.connectPeers(t, i, (i+j+1)%r.total) + r.connectPeers(ctx, t, i, (i+j+1)%r.total) } } } // connects node1 to node2 -func (r *reactorTestSuite) connectPeers(t *testing.T, sourceNode, targetNode int) { +func (r *reactorTestSuite) connectPeers(ctx context.Context, t *testing.T, sourceNode, targetNode int) { t.Helper() node1, node2 := r.checkNodePair(t, sourceNode, targetNode) r.logger.Info("connecting peers", "sourceNode", sourceNode, "targetNode", targetNode) @@ -647,9 +648,9 @@ func (r *reactorTestSuite) connectPeers(t *testing.T, sourceNode, targetNode int return } - sourceSub := n1.PeerManager.Subscribe() + sourceSub := n1.PeerManager.Subscribe(ctx) defer sourceSub.Close() - targetSub := n2.PeerManager.Subscribe() + targetSub := n2.PeerManager.Subscribe(ctx) defer targetSub.Close() sourceAddress := n1.NodeAddress diff --git a/internal/p2p/pqueue.go b/internal/p2p/pqueue.go index 11cdbd130..b43bb806f 100644 --- a/internal/p2p/pqueue.go +++ b/internal/p2p/pqueue.go @@ -2,6 +2,7 @@ package p2p import ( "container/heap" + "context" "sort" "strconv" "time" @@ -140,8 +141,8 @@ func (s *pqScheduler) closed() <-chan struct{} { } // start starts non-blocking process that starts the priority queue scheduler. -func (s *pqScheduler) start() { - go s.process() +func (s *pqScheduler) start(ctx context.Context) { + go s.process(ctx) } // process starts a block process where we listen for Envelopes to enqueue. If @@ -153,7 +154,7 @@ func (s *pqScheduler) start() { // // After we attempt to enqueue the incoming Envelope, if the priority queue is // non-empty, we pop the top Envelope and send it on the dequeueCh. -func (s *pqScheduler) process() { +func (s *pqScheduler) process(ctx context.Context) { defer s.done.Close() for { @@ -267,7 +268,8 @@ func (s *pqScheduler) process() { return } } - + case <-ctx.Done(): + return case <-s.closer.Done(): return } diff --git a/internal/p2p/pqueue_test.go b/internal/p2p/pqueue_test.go index ffa7e39a8..181e6e7f7 100644 --- a/internal/p2p/pqueue_test.go +++ b/internal/p2p/pqueue_test.go @@ -1,6 +1,7 @@ package p2p import ( + "context" "testing" "time" @@ -24,7 +25,10 @@ func TestCloseWhileDequeueFull(t *testing.T) { } } - go pqueue.process() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go pqueue.process(ctx) // sleep to allow context switch for process() to run time.Sleep(10 * time.Millisecond) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 29646e327..09397e974 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -62,8 +62,6 @@ type Channel struct { Error chan<- PeerError // peer error reporting messageType proto.Message // the channel's message type, used for unmarshaling - closeCh chan struct{} - closeOnce sync.Once } // NewChannel creates a new channel. 
It is primarily for internal and test @@ -81,26 +79,9 @@ func NewChannel( In: inCh, Out: outCh, Error: errCh, - closeCh: make(chan struct{}), } } -// Close closes the channel. Future sends on Out and Error will panic. The In -// channel remains open to avoid having to synchronize Router senders, which -// should use Done() to detect channel closure instead. -func (c *Channel) Close() { - c.closeOnce.Do(func() { - close(c.closeCh) - close(c.Out) - close(c.Error) - }) -} - -// Done returns a channel that's closed when Channel.Close() is called. -func (c *Channel) Done() <-chan struct{} { - return c.closeCh -} - // Wrapper is a Protobuf message that can contain a variety of inner messages // (e.g. via oneof fields). If a Channel's message type implements Wrapper, the // Router will automatically wrap outbound messages and unwrap inbound messages, @@ -272,6 +253,7 @@ type Router struct { // listening on appropriate interfaces, and will be closed by the Router when it // stops. func NewRouter( + ctx context.Context, logger log.Logger, metrics *Metrics, nodeInfo types.NodeInfo, @@ -310,7 +292,7 @@ func NewRouter( router.BaseService = service.NewBaseService(logger, "router", router) - qf, err := router.createQueueFactory() + qf, err := router.createQueueFactory(ctx) if err != nil { return nil, err } @@ -328,7 +310,7 @@ func NewRouter( return router, nil } -func (r *Router) createQueueFactory() (func(int) queue, error) { +func (r *Router) createQueueFactory(ctx context.Context) (func(int) queue, error) { switch r.options.QueueType { case queueTypeFifo: return newFIFOQueue, nil @@ -340,7 +322,7 @@ func (r *Router) createQueueFactory() (func(int) queue, error) { } q := newPQScheduler(r.logger, r.metrics, r.chDescs, uint(size)/2, uint(size)/2, defaultCapacity) - q.start() + q.start(ctx) return q }, nil @@ -355,7 +337,7 @@ func (r *Router) createQueueFactory() (func(int) queue, error) { // implement Wrapper to automatically (un)wrap multiple message types in a // wrapper message. The caller may provide a size to make the channel buffered, // which internally makes the inbound, outbound, and error channel buffered. -func (r *Router) OpenChannel(chDesc *ChannelDescriptor) (*Channel, error) { +func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (*Channel, error) { r.channelMtx.Lock() defer r.channelMtx.Unlock() @@ -396,7 +378,7 @@ func (r *Router) OpenChannel(chDesc *ChannelDescriptor) (*Channel, error) { queue.close() }() - r.routeChannel(id, outCh, errCh, wrapper) + r.routeChannel(ctx, id, outCh, errCh, wrapper) }() return channel, nil @@ -408,6 +390,7 @@ func (r *Router) OpenChannel(chDesc *ChannelDescriptor) (*Channel, error) { // closed, or the Router is stopped. wrapper is an optional message wrapper // for messages, see Wrapper for details. func (r *Router) routeChannel( + ctx context.Context, chID ChannelID, outCh <-chan Envelope, errCh <-chan PeerError, @@ -504,7 +487,8 @@ func (r *Router) routeChannel( r.logger.Error("peer error, evicting", "peer", peerError.NodeID, "err", peerError.Err) r.peerManager.Errored(peerError.NodeID, peerError.Err) - + case <-ctx.Done(): + return case <-r.stopCh: return } @@ -561,9 +545,9 @@ func (r *Router) dialSleep(ctx context.Context) { // acceptPeers accepts inbound connections from peers on the given transport, // and spawns goroutines that route messages to/from them. 
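With Channel.Close and Channel.Done gone, a channel's lifetime is bound to the context passed to Router.OpenChannel, and callers close a channel by canceling that context, as the updated p2ptest.MakeChannel does. A sketch of that usage; openScopedChannel is an illustrative helper, not part of the patch:

```go
package example

import (
	"context"

	"github.com/tendermint/tendermint/internal/p2p"
)

// openScopedChannel opens a channel on a child context; calling the returned
// CancelFunc takes the place of the removed channel.Close().
func openScopedChannel(ctx context.Context, r *p2p.Router, chDesc *p2p.ChannelDescriptor) (*p2p.Channel, context.CancelFunc, error) {
	chctx, cancel := context.WithCancel(ctx)
	ch, err := r.OpenChannel(chctx, chDesc)
	if err != nil {
		cancel()
		return nil, nil, err
	}
	return ch, cancel, nil
}
```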
-func (r *Router) acceptPeers(transport Transport) { +func (r *Router) acceptPeers(ctx context.Context, transport Transport) { r.logger.Debug("starting accept routine", "transport", transport) - ctx := r.stopCtx() + for { conn, err := transport.Accept() switch err { @@ -640,13 +624,12 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) { return } - r.routePeer(peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels)) + r.routePeer(ctx, peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels)) } // dialPeers maintains outbound connections to peers by dialing them. -func (r *Router) dialPeers() { +func (r *Router) dialPeers(ctx context.Context) { r.logger.Debug("starting dial routine") - ctx := r.stopCtx() addresses := make(chan NodeAddress) wg := &sync.WaitGroup{} @@ -709,7 +692,7 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { return case err != nil: r.logger.Error("failed to dial peer", "peer", address, "err", err) - if err = r.peerManager.DialFailed(address); err != nil { + if err = r.peerManager.DialFailed(ctx, address); err != nil { r.logger.Error("failed to report dial failure", "peer", address, "err", err) } return @@ -722,7 +705,7 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { return case err != nil: r.logger.Error("failed to handshake with peer", "peer", address, "err", err) - if err = r.peerManager.DialFailed(address); err != nil { + if err = r.peerManager.DialFailed(ctx, address); err != nil { r.logger.Error("failed to report dial failure", "peer", address, "err", err) } conn.Close() @@ -737,7 +720,7 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) { } // routePeer (also) calls connection close - go r.routePeer(address.NodeID, conn, toChannelIDs(peerInfo.Channels)) + go r.routePeer(ctx, address.NodeID, conn, toChannelIDs(peerInfo.Channels)) } func (r *Router) getOrMakeQueue(peerID types.NodeID, channels channelIDs) queue { @@ -852,7 +835,7 @@ func (r *Router) runWithPeerMutex(fn func() error) error { // routePeer routes inbound and outbound messages between a peer and the reactor // channels. It will close the given connection and send queue when done, or if // they are closed elsewhere it will cause this method to shut down and return. -func (r *Router) routePeer(peerID types.NodeID, conn Connection, channels channelIDs) { +func (r *Router) routePeer(ctx context.Context, peerID types.NodeID, conn Connection, channels channelIDs) { r.metrics.Peers.Add(1) r.peerManager.Ready(peerID) @@ -874,27 +857,46 @@ func (r *Router) routePeer(peerID types.NodeID, conn Connection, channels channe errCh := make(chan error, 2) go func() { - errCh <- r.receivePeer(peerID, conn) + select { + case errCh <- r.receivePeer(peerID, conn): + case <-ctx.Done(): + } }() go func() { - errCh <- r.sendPeer(peerID, conn, sendQueue) + select { + case errCh <- r.sendPeer(peerID, conn, sendQueue): + case <-ctx.Done(): + } }() - err := <-errCh + var err error + select { + case err = <-errCh: + case <-ctx.Done(): + } + _ = conn.Close() sendQueue.close() - if e := <-errCh; err == nil { + select { + case <-ctx.Done(): + case e := <-errCh: // The first err was nil, so we update it with the second err, which may // or may not be nil. 
+ if err == nil { + err = e + } + } + + // if the context was canceled + if e := ctx.Err(); err == nil && e != nil { err = e } switch err { case nil, io.EOF: r.logger.Info("peer disconnected", "peer", peerID, "endpoint", conn) - default: r.logger.Error("peer failure", "peer", peerID, "endpoint", conn, "err", err) } @@ -988,9 +990,8 @@ func (r *Router) sendPeer(peerID types.NodeID, conn Connection, peerQueue queue) } // evictPeers evicts connected peers as requested by the peer manager. -func (r *Router) evictPeers() { +func (r *Router) evictPeers(ctx context.Context) { r.logger.Debug("starting evict routine") - ctx := r.stopCtx() for { peerID, err := r.peerManager.EvictNext(ctx) @@ -1040,11 +1041,11 @@ func (r *Router) OnStart(ctx context.Context) error { "transports", len(r.transports), ) - go r.dialPeers() - go r.evictPeers() + go r.dialPeers(ctx) + go r.evictPeers(ctx) for _, transport := range r.transports { - go r.acceptPeers(transport) + go r.acceptPeers(ctx, transport) } return nil @@ -1087,18 +1088,6 @@ func (r *Router) OnStop() { } } -// stopCtx returns a new context that is canceled when the router stops. -func (r *Router) stopCtx() context.Context { - ctx, cancel := context.WithCancel(context.Background()) - - go func() { - <-r.stopCh - cancel() - }() - - return ctx -} - type channelIDs map[ChannelID]struct{} func toChannelIDs(bytes []byte) channelIDs { diff --git a/internal/p2p/router_init_test.go b/internal/p2p/router_init_test.go index c8bef696a..b2a8fe1a0 100644 --- a/internal/p2p/router_init_test.go +++ b/internal/p2p/router_init_test.go @@ -1,6 +1,7 @@ package p2p import ( + "context" "os" "testing" @@ -10,6 +11,9 @@ import ( ) func TestRouter_ConstructQueueFactory(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + t.Run("ValidateOptionsPopulatesDefaultQueue", func(t *testing.T) { opts := RouterOptions{} require.NoError(t, opts.Validate()) @@ -18,21 +22,21 @@ func TestRouter_ConstructQueueFactory(t *testing.T) { t.Run("Default", func(t *testing.T) { require.Zero(t, os.Getenv("TM_P2P_QUEUE")) opts := RouterOptions{} - r, err := NewRouter(log.NewNopLogger(), nil, types.NodeInfo{}, nil, nil, nil, nil, opts) + r, err := NewRouter(ctx, log.NewNopLogger(), nil, types.NodeInfo{}, nil, nil, nil, nil, opts) require.NoError(t, err) _, ok := r.queueFactory(1).(*fifoQueue) require.True(t, ok) }) t.Run("Fifo", func(t *testing.T) { opts := RouterOptions{QueueType: queueTypeFifo} - r, err := NewRouter(log.NewNopLogger(), nil, types.NodeInfo{}, nil, nil, nil, nil, opts) + r, err := NewRouter(ctx, log.NewNopLogger(), nil, types.NodeInfo{}, nil, nil, nil, nil, opts) require.NoError(t, err) _, ok := r.queueFactory(1).(*fifoQueue) require.True(t, ok) }) t.Run("Priority", func(t *testing.T) { opts := RouterOptions{QueueType: queueTypePriority} - r, err := NewRouter(log.NewNopLogger(), nil, types.NodeInfo{}, nil, nil, nil, nil, opts) + r, err := NewRouter(ctx, log.NewNopLogger(), nil, types.NodeInfo{}, nil, nil, nil, nil, opts) require.NoError(t, err) q, ok := r.queueFactory(1).(*pqScheduler) require.True(t, ok) @@ -40,7 +44,7 @@ func TestRouter_ConstructQueueFactory(t *testing.T) { }) t.Run("NonExistant", func(t *testing.T) { opts := RouterOptions{QueueType: "fast"} - _, err := NewRouter(log.NewNopLogger(), nil, types.NodeInfo{}, nil, nil, nil, nil, opts) + _, err := NewRouter(ctx, log.NewNopLogger(), nil, types.NodeInfo{}, nil, nil, nil, nil, opts) require.Error(t, err) require.Contains(t, err.Error(), "fast") }) @@ -48,7 +52,7 @@ func 
TestRouter_ConstructQueueFactory(t *testing.T) { r := &Router{} require.Zero(t, r.options.QueueType) - fn, err := r.createQueueFactory() + fn, err := r.createQueueFactory(ctx) require.Error(t, err) require.Nil(t, fn) }) diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index c77e9e44d..8a4c9e4bc 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -27,7 +27,7 @@ import ( "github.com/tendermint/tendermint/types" ) -func echoReactor(channel *p2p.Channel) { +func echoReactor(ctx context.Context, channel *p2p.Channel) { for { select { case envelope := <-channel.In: @@ -37,7 +37,7 @@ func echoReactor(channel *p2p.Channel) { Message: &p2ptest.Message{Value: value}, } - case <-channel.Done(): + case <-ctx.Done(): return } } @@ -53,13 +53,13 @@ func TestRouter_Network(t *testing.T) { network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 8}) local := network.RandomNode() peers := network.Peers(local.NodeID) - channels := network.MakeChannels(t, chDesc) + channels := network.MakeChannels(ctx, t, chDesc) - network.Start(t) + network.Start(ctx, t) channel := channels[local.NodeID] for _, peer := range peers { - go echoReactor(channels[peer.NodeID]) + go echoReactor(ctx, channels[peer.NodeID]) } // Sending a message to each peer should work. @@ -86,7 +86,7 @@ func TestRouter_Network(t *testing.T) { // We then submit an error for a peer, and watch it get disconnected and // then reconnected as the router retries it. - peerUpdates := local.MakePeerUpdatesNoRequireEmpty(t) + peerUpdates := local.MakePeerUpdatesNoRequireEmpty(ctx, t) channel.Error <- p2p.PeerError{ NodeID: peers[0].NodeID, Err: errors.New("boom"), @@ -100,12 +100,16 @@ func TestRouter_Network(t *testing.T) { func TestRouter_Channel_Basic(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Set up a router with no transports (so no peers). peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) defer peerManager.Close() router, err := p2p.NewRouter( + ctx, log.TestingLogger(), p2p.NopMetrics(), selfInfo, @@ -117,33 +121,34 @@ func TestRouter_Channel_Basic(t *testing.T) { ) require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - require.NoError(t, router.Start(ctx)) t.Cleanup(router.Wait) // Opening a channel should work. - channel, err := router.OpenChannel(chDesc) + chctx, chcancel := context.WithCancel(ctx) + defer chcancel() + + channel, err := router.OpenChannel(chctx, chDesc) require.NoError(t, err) require.Contains(t, router.NodeInfo().Channels, byte(chDesc.ID)) + require.NotNil(t, channel) // Opening the same channel again should fail. - _, err = router.OpenChannel(chDesc) + _, err = router.OpenChannel(ctx, chDesc) require.Error(t, err) // Opening a different channel should work. chDesc2 := &p2p.ChannelDescriptor{ID: 2, MessageType: &p2ptest.Message{}} - _, err = router.OpenChannel(chDesc2) + _, err = router.OpenChannel(ctx, chDesc2) require.NoError(t, err) require.Contains(t, router.NodeInfo().Channels, byte(chDesc2.ID)) // Closing the channel, then opening it again should be fine. - channel.Close() - time.Sleep(100 * time.Millisecond) // yes yes, but Close() is async... + chcancel() + time.Sleep(200 * time.Millisecond) // yes yes, but Close() is async... 
- channel, err = router.OpenChannel(chDesc) + channel, err = router.OpenChannel(ctx, chDesc) require.NoError(t, err) // We should be able to send on the channel, even though there are no peers. @@ -172,11 +177,11 @@ func TestRouter_Channel_SendReceive(t *testing.T) { ids := network.NodeIDs() aID, bID, cID := ids[0], ids[1], ids[2] - channels := network.MakeChannels(t, chDesc) + channels := network.MakeChannels(ctx, t, chDesc) a, b, c := channels[aID], channels[bID], channels[cID] - otherChannels := network.MakeChannels(t, p2ptest.MakeChannelDesc(9)) + otherChannels := network.MakeChannels(ctx, t, p2ptest.MakeChannelDesc(9)) - network.Start(t) + network.Start(ctx, t) // Sending a message a->b should work, and not send anything // further to a, b, or c. @@ -208,7 +213,7 @@ func TestRouter_Channel_SendReceive(t *testing.T) { p2ptest.RequireEmpty(t, a, b, c) // Removing b and sending to it should be dropped. - network.Remove(t, bID) + network.Remove(ctx, t, bID) p2ptest.RequireSend(t, a, p2p.Envelope{To: bID, Message: &p2ptest.Message{Value: "nob"}}) p2ptest.RequireEmpty(t, a, b, c) @@ -234,10 +239,10 @@ func TestRouter_Channel_Broadcast(t *testing.T) { ids := network.NodeIDs() aID, bID, cID, dID := ids[0], ids[1], ids[2], ids[3] - channels := network.MakeChannels(t, chDesc) + channels := network.MakeChannels(ctx, t, chDesc) a, b, c, d := channels[aID], channels[bID], channels[cID], channels[dID] - network.Start(t) + network.Start(ctx, t) // Sending a broadcast from b should work. p2ptest.RequireSend(t, b, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "foo"}}) @@ -247,7 +252,7 @@ func TestRouter_Channel_Broadcast(t *testing.T) { p2ptest.RequireEmpty(t, a, b, c, d) // Removing one node from the network shouldn't prevent broadcasts from working. - network.Remove(t, dID) + network.Remove(ctx, t, dID) p2ptest.RequireSend(t, a, p2p.Envelope{Broadcast: true, Message: &p2ptest.Message{Value: "bar"}}) p2ptest.RequireReceive(t, b, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) p2ptest.RequireReceive(t, c, p2p.Envelope{From: aID, Message: &p2ptest.Message{Value: "bar"}}) @@ -273,10 +278,10 @@ func TestRouter_Channel_Wrapper(t *testing.T) { RecvMessageCapacity: 10, } - channels := network.MakeChannels(t, chDesc) + channels := network.MakeChannels(ctx, t, chDesc) a, b := channels[aID], channels[bID] - network.Start(t) + network.Start(ctx, t) // Since wrapperMessage implements p2p.Wrapper and handles Message, it // should automatically wrap and unwrap sent messages -- we prepend the @@ -332,11 +337,11 @@ func TestRouter_Channel_Error(t *testing.T) { // Create a test network and open a channel on all nodes. network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 3}) - network.Start(t) + network.Start(ctx, t) ids := network.NodeIDs() aID, bID := ids[0], ids[1] - channels := network.MakeChannels(t, chDesc) + channels := network.MakeChannels(ctx, t, chDesc) a := channels[aID] // Erroring b should cause it to be disconnected. It will reconnect shortly after. 
@@ -407,10 +412,11 @@ func TestRouter_AcceptPeers(t *testing.T) { require.NoError(t, err) defer peerManager.Close() - sub := peerManager.Subscribe() + sub := peerManager.Subscribe(ctx) defer sub.Close() router, err := p2p.NewRouter( + ctx, log.TestingLogger(), p2p.NopMetrics(), selfInfo, @@ -467,6 +473,7 @@ func TestRouter_AcceptPeers_Error(t *testing.T) { defer peerManager.Close() router, err := p2p.NewRouter( + ctx, log.TestingLogger(), p2p.NopMetrics(), selfInfo, @@ -488,6 +495,9 @@ func TestRouter_AcceptPeers_Error(t *testing.T) { func TestRouter_AcceptPeers_ErrorEOF(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Set up a mock transport that returns io.EOF once, which should prevent // the router from calling Accept again. mockTransport := &mocks.Transport{} @@ -502,6 +512,7 @@ func TestRouter_AcceptPeers_ErrorEOF(t *testing.T) { defer peerManager.Close() router, err := p2p.NewRouter( + ctx, log.TestingLogger(), p2p.NopMetrics(), selfInfo, @@ -513,9 +524,6 @@ func TestRouter_AcceptPeers_ErrorEOF(t *testing.T) { ) require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - require.NoError(t, router.Start(ctx)) time.Sleep(time.Second) require.NoError(t, router.Stop()) @@ -557,6 +565,7 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { defer peerManager.Close() router, err := p2p.NewRouter( + ctx, log.TestingLogger(), p2p.NopMetrics(), selfInfo, @@ -659,10 +668,11 @@ func TestRouter_DialPeers(t *testing.T) { added, err := peerManager.Add(address) require.NoError(t, err) require.True(t, added) - sub := peerManager.Subscribe() + sub := peerManager.Subscribe(ctx) defer sub.Close() router, err := p2p.NewRouter( + ctx, log.TestingLogger(), p2p.NopMetrics(), selfInfo, @@ -750,6 +760,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { require.True(t, added) router, err := p2p.NewRouter( + ctx, log.TestingLogger(), p2p.NopMetrics(), selfInfo, @@ -822,10 +833,11 @@ func TestRouter_EvictPeers(t *testing.T) { require.NoError(t, err) defer peerManager.Close() - sub := peerManager.Subscribe() + sub := peerManager.Subscribe(ctx) defer sub.Close() router, err := p2p.NewRouter( + ctx, log.TestingLogger(), p2p.NopMetrics(), selfInfo, @@ -890,6 +902,7 @@ func TestRouter_ChannelCompatability(t *testing.T) { defer peerManager.Close() router, err := p2p.NewRouter( + ctx, log.TestingLogger(), p2p.NopMetrics(), selfInfo, @@ -943,10 +956,11 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { require.NoError(t, err) defer peerManager.Close() - sub := peerManager.Subscribe() + sub := peerManager.Subscribe(ctx) defer sub.Close() router, err := p2p.NewRouter( + ctx, log.TestingLogger(), p2p.NopMetrics(), selfInfo, @@ -964,7 +978,7 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { Status: p2p.PeerStatusUp, }) - channel, err := router.OpenChannel(chDesc) + channel, err := router.OpenChannel(ctx, chDesc) require.NoError(t, err) channel.Out <- p2p.Envelope{ diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index 0580ce1bf..b5c89502f 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -277,7 +277,12 @@ func (c *mConnConnection) Handshake( }() var err error mconn, peerInfo, peerKey, err = c.handshake(ctx, nodeInfo, privKey) - errCh <- err + + select { + case errCh <- err: + case <-ctx.Done(): + } + }() select { @@ -315,21 +320,39 @@ func (c *mConnConnection) handshake( return nil, types.NodeInfo{}, nil, err 
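The Handshake and handshake changes here and just below guard every errCh send with ctx.Done() and pair the I/O goroutines with a WaitGroup, so neither half can block forever once the context is canceled. A standalone sketch of that structure; read and write are stand-ins for the protoio calls:

```go
package example

import (
	"context"
	"sync"
)

// runPairSketch runs the read and write halves concurrently. Each half reports
// its error only while someone can still receive it, and the caller checks
// ctx.Err() once both have finished.
func runPairSketch(ctx context.Context, read, write func() error) error {
	var wg sync.WaitGroup
	errCh := make(chan error, 2)

	for _, fn := range []func() error{read, write} {
		fn := fn
		wg.Add(1)
		go func() {
			defer wg.Done()
			select {
			case errCh <- fn():
			case <-ctx.Done():
			}
		}()
	}
	wg.Wait()

	close(errCh) // safe: all senders have returned
	for err := range errCh {
		if err != nil {
			return err
		}
	}
	return ctx.Err()
}
```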
} + wg := &sync.WaitGroup{} var pbPeerInfo p2pproto.NodeInfo errCh := make(chan error, 2) + wg.Add(1) go func() { + defer wg.Done() _, err := protoio.NewDelimitedWriter(secretConn).WriteMsg(nodeInfo.ToProto()) - errCh <- err - }() - go func() { - _, err := protoio.NewDelimitedReader(secretConn, types.MaxNodeInfoSize()).ReadMsg(&pbPeerInfo) - errCh <- err - }() - for i := 0; i < cap(errCh); i++ { - if err = <-errCh; err != nil { - return nil, types.NodeInfo{}, nil, err + select { + case errCh <- err: + case <-ctx.Done(): } + + }() + wg.Add(1) + go func() { + defer wg.Done() + _, err := protoio.NewDelimitedReader(secretConn, types.MaxNodeInfoSize()).ReadMsg(&pbPeerInfo) + select { + case errCh <- err: + case <-ctx.Done(): + } + }() + + wg.Wait() + + if err, ok := <-errCh; ok && err != nil { + return nil, types.NodeInfo{}, nil, err } + + if err := ctx.Err(); err != nil { + return nil, types.NodeInfo{}, nil, err + } + peerInfo, err := types.NodeInfoFromProto(&pbPeerInfo) if err != nil { return nil, types.NodeInfo{}, nil, err diff --git a/internal/rpc/core/events.go b/internal/rpc/core/events.go index d8e09f35e..73ca8a755 100644 --- a/internal/rpc/core/events.go +++ b/internal/rpc/core/events.go @@ -56,15 +56,18 @@ func (env *Environment) Subscribe(ctx *rpctypes.Context, query string) (*coretyp // Capture the current ID, since it can change in the future. subscriptionID := ctx.JSONReq.ID go func() { + opctx, opcancel := context.WithCancel(context.Background()) + defer opcancel() + for { - msg, err := sub.Next(context.Background()) + msg, err := sub.Next(opctx) if errors.Is(err, tmpubsub.ErrUnsubscribed) { // The subscription was removed by the client. return } else if errors.Is(err, tmpubsub.ErrTerminated) { // The subscription was terminated by the publisher. resp := rpctypes.RPCServerError(subscriptionID, err) - ok := ctx.WSConn.TryWriteRPCResponse(resp) + ok := ctx.WSConn.TryWriteRPCResponse(opctx, resp) if !ok { env.Logger.Info("Unable to write response (slow client)", "to", addr, "subscriptionID", subscriptionID, "err", err) @@ -78,7 +81,7 @@ func (env *Environment) Subscribe(ctx *rpctypes.Context, query string) (*coretyp Data: msg.Data(), Events: msg.Events(), }) - wctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + wctx, cancel := context.WithTimeout(opctx, 10*time.Second) err = ctx.WSConn.WriteRPCResponse(wctx, resp) cancel() if err != nil { diff --git a/internal/statesync/reactor.go b/internal/statesync/reactor.go index 6566f823b..10edd3827 100644 --- a/internal/statesync/reactor.go +++ b/internal/statesync/reactor.go @@ -211,11 +211,11 @@ func NewReactor( // The caller must be sure to execute OnStop to ensure the outbound p2p Channels are // closed. No error is returned. func (r *Reactor) OnStart(ctx context.Context) error { - go r.processCh(r.snapshotCh, "snapshot") - go r.processCh(r.chunkCh, "chunk") - go r.processCh(r.blockCh, "light block") - go r.processCh(r.paramsCh, "consensus params") - go r.processPeerUpdates() + go r.processCh(ctx, r.snapshotCh, "snapshot") + go r.processCh(ctx, r.chunkCh, "chunk") + go r.processCh(ctx, r.blockCh, "light block") + go r.processCh(ctx, r.paramsCh, "consensus params") + go r.processPeerUpdates(ctx) return nil } @@ -232,14 +232,7 @@ func (r *Reactor) OnStop() { // p2p Channels should execute Close(). close(r.closeCh) - // Wait for all p2p Channels to be closed before returning. This ensures we - // can easily reason about synchronization of all p2p Channels and ensure no - // panics will occur. 
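The events.go change above gives the subscription writer one long-lived operation context and derives each write's 10-second deadline from it, so canceling the operation also abandons an in-flight write. The derivation in isolation; write stands in for WSConn.WriteRPCResponse:

```go
package example

import (
	"context"
	"time"
)

// writeWithDeadline derives each write's deadline from the long-lived
// operation context, so canceling the operation abandons the write too.
func writeWithDeadline(opctx context.Context, write func(context.Context) error) error {
	wctx, cancel := context.WithTimeout(opctx, 10*time.Second)
	defer cancel()
	return write(wctx)
}
```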
<-r.peerUpdates.Done() - <-r.snapshotCh.Done() - <-r.chunkCh.Done() - <-r.blockCh.Done() - <-r.paramsCh.Done() } // Sync runs a state sync, fetching snapshots and providing chunks to the @@ -273,7 +266,7 @@ func (r *Reactor) Sync(ctx context.Context) (sm.State, error) { r.stateProvider, r.snapshotCh.Out, r.chunkCh.Out, - r.snapshotCh.Done(), + ctx.Done(), r.tempDir, r.metrics, ) @@ -380,6 +373,8 @@ func (r *Reactor) backfill( go func() { for { select { + case <-ctx.Done(): + return case height := <-queue.nextHeight(): // pop the next peer of the list to send a request to peer := r.peers.Pop(ctx) @@ -815,11 +810,11 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err // encountered during message execution will result in a PeerError being sent on // the respective channel. When the reactor is stopped, we will catch the signal // and close the p2p Channel gracefully. -func (r *Reactor) processCh(ch *p2p.Channel, chName string) { - defer ch.Close() - +func (r *Reactor) processCh(ctx context.Context, ch *p2p.Channel, chName string) { for { select { + case <-ctx.Done(): + return case envelope := <-ch.In: if err := r.handleMessage(ch.ID, envelope); err != nil { r.Logger.Error("failed to process message", @@ -883,11 +878,13 @@ func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) { // processPeerUpdates initiates a blocking process where we listen for and handle // PeerUpdate messages. When the reactor is stopped, we will catch the signal and // close the p2p PeerUpdatesCh gracefully. -func (r *Reactor) processPeerUpdates() { +func (r *Reactor) processPeerUpdates(ctx context.Context) { defer r.peerUpdates.Close() for { select { + case <-ctx.Done(): + return case peerUpdate := <-r.peerUpdates.Updates(): r.processPeerUpdate(peerUpdate) diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index 8dc2d6038..82ec0f68d 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -172,7 +172,7 @@ func setup( stateProvider, rts.snapshotOutCh, rts.chunkOutCh, - rts.snapshotChannel.Done(), + ctx.Done(), "", rts.reactor.metrics, ) diff --git a/libs/cli/helper.go b/libs/cli/helper.go index 37fe34fc9..76f3c9043 100644 --- a/libs/cli/helper.go +++ b/libs/cli/helper.go @@ -69,7 +69,10 @@ func RunCaptureWithArgs(cmd Executable, args []string, env map[string]string) (s var buf bytes.Buffer // io.Copy will end when we call reader.Close() below io.Copy(&buf, reader) //nolint:errcheck //ignore error - stdC <- buf.String() + select { + case <-cmd.Context().Done(): + case stdC <- buf.String(): + } }() return &stdC } diff --git a/libs/cli/setup.go b/libs/cli/setup.go index e4955dcf4..8e11bac93 100644 --- a/libs/cli/setup.go +++ b/libs/cli/setup.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "os" "path/filepath" @@ -22,6 +23,7 @@ const ( // wrap if desired before the test type Executable interface { Execute() error + Context() context.Context } // PrepareBaseCmd is meant for tendermint and other servers diff --git a/libs/os/os.go b/libs/os/os.go index 02b98c52a..0fd50da5b 100644 --- a/libs/os/os.go +++ b/libs/os/os.go @@ -1,6 +1,7 @@ package os import ( + "context" "errors" "fmt" "io" @@ -15,13 +16,14 @@ type logger interface { // TrapSignal catches SIGTERM and SIGINT, executes the cleanup function, // and exits with code 0. 
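In the TrapSignal hunk that follows, the signal context is derived from the caller's context via signal.NotifyContext, so cleanup runs on SIGTERM/SIGINT or when the parent context ends. A caller-side sketch matching the updated priv_val_server and light-proxy call sites; run and its body are illustrative:

```go
package example

import (
	"context"

	"github.com/tendermint/tendermint/libs/log"
	tmos "github.com/tendermint/tendermint/libs/os"
)

// run wires shutdown to both OS signals and the parent context.
func run(ctx context.Context, logger log.Logger) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Cleanup fires on SIGTERM/SIGINT or when ctx is canceled.
	tmos.TrapSignal(ctx, logger, func() {
		logger.Debug("closing listeners")
		// close servers and listeners here
	})

	// ... start services; TrapSignal handles the exit path ...
}
```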
-func TrapSignal(logger logger, cb func()) { - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) +func TrapSignal(ctx context.Context, logger logger, cb func()) { + opctx, opcancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) go func() { - sig := <-c - logger.Info(fmt.Sprintf("captured %v, exiting...", sig)) + defer opcancel() + defer opcancel() + <-opctx.Done() + logger.Info("captured signal, exiting...") if cb != nil { cb() } diff --git a/libs/os/os_test.go b/libs/os/os_test.go index 22d739ad7..eec3e723b 100644 --- a/libs/os/os_test.go +++ b/libs/os/os_test.go @@ -2,6 +2,7 @@ package os_test import ( "bytes" + "context" "fmt" "os" "os/exec" @@ -108,7 +109,10 @@ func (ml mockLogger) Info(msg string, keyvals ...interface{}) {} func killer() { logger := mockLogger{} - tmos.TrapSignal(logger, func() { _, _ = fmt.Fprintf(os.Stderr, "exiting") }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tmos.TrapSignal(ctx, logger, func() { _, _ = fmt.Fprintf(os.Stderr, "exiting") }) time.Sleep(1 * time.Second) p, err := os.FindProcess(os.Getpid()) diff --git a/light/client.go b/light/client.go index cc606f496..866de7627 100644 --- a/light/client.go +++ b/light/client.go @@ -168,7 +168,8 @@ func NewClient( primary provider.Provider, witnesses []provider.Provider, trustedStore store.Store, - options ...Option) (*Client, error) { + options ...Option, +) (*Client, error) { // Check whether the trusted store already has a trusted block. If so, then create // a new client from the trusted store instead of the trust options. @@ -1023,7 +1024,11 @@ func (c *Client) findNewPrimary(ctx context.Context, height int64, remove bool) defer wg.Done() lb, err := c.witnesses[witnessIndex].LightBlock(subctx, height) - witnessResponsesC <- witnessResponse{lb, witnessIndex, err} + select { + case witnessResponsesC <- witnessResponse{lb, witnessIndex, err}: + case <-ctx.Done(): + } + }(index, witnessResponsesC) } diff --git a/light/client_test.go b/light/client_test.go index c7c974ee5..5b97d3394 100644 --- a/light/client_test.go +++ b/light/client_test.go @@ -1030,63 +1030,4 @@ func TestClientEnsureValidHeadersAndValSets(t *testing.T) { mockBadNode.AssertExpectations(t) }) } - -} - -func TestClientHandlesContexts(t *testing.T) { - mockNode := &provider_mocks.Provider{} - mockNode.On("LightBlock", - mock.MatchedBy(func(ctx context.Context) bool { return ctx.Err() == nil }), - int64(1)).Return(l1, nil) - mockNode.On("LightBlock", - mock.MatchedBy(func(ctx context.Context) bool { return ctx.Err() == context.DeadlineExceeded }), - mock.Anything).Return(nil, context.DeadlineExceeded) - - mockNode.On("LightBlock", - mock.MatchedBy(func(ctx context.Context) bool { return ctx.Err() == context.Canceled }), - mock.Anything).Return(nil, context.Canceled) - - // instantiate the light client with a timeout - ctxTimeOut, cancel := context.WithTimeout(ctx, 1*time.Nanosecond) - defer cancel() - _, err := light.NewClient( - ctxTimeOut, - chainID, - trustOptions, - mockNode, - []provider.Provider{mockNode, mockNode}, - dbs.New(dbm.NewMemDB()), - ) - require.Error(t, ctxTimeOut.Err()) - require.Error(t, err) - require.True(t, errors.Is(err, context.DeadlineExceeded)) - - // instantiate the client for real - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockNode, - []provider.Provider{mockNode, mockNode}, - dbs.New(dbm.NewMemDB()), - ) - require.NoError(t, err) - - // verify a block with a timeout - ctxTimeOutBlock, cancel := context.WithTimeout(ctx, 
1*time.Nanosecond) - defer cancel() - _, err = c.VerifyLightBlockAtHeight(ctxTimeOutBlock, 100, bTime.Add(100*time.Minute)) - require.Error(t, ctxTimeOutBlock.Err()) - require.Error(t, err) - require.True(t, errors.Is(err, context.DeadlineExceeded)) - - // verify a block with a cancel - ctxCancel, cancel := context.WithCancel(ctx) - cancel() - _, err = c.VerifyLightBlockAtHeight(ctxCancel, 100, bTime.Add(100*time.Minute)) - require.Error(t, ctxCancel.Err()) - require.Error(t, err) - require.True(t, errors.Is(err, context.Canceled)) - mockNode.AssertExpectations(t) - } diff --git a/light/proxy/proxy.go b/light/proxy/proxy.go index f8c183308..9119d2f2a 100644 --- a/light/proxy/proxy.go +++ b/light/proxy/proxy.go @@ -57,6 +57,7 @@ func (p *Proxy) ListenAndServe(ctx context.Context) error { p.Listener = listener return rpcserver.Serve( + ctx, listener, mux, p.Logger, @@ -75,6 +76,7 @@ func (p *Proxy) ListenAndServeTLS(ctx context.Context, certFile, keyFile string) p.Listener = listener return rpcserver.ServeTLS( + ctx, listener, mux, certFile, diff --git a/light/rpc/client.go b/light/rpc/client.go index 0060b7b74..7496d60a1 100644 --- a/light/rpc/client.go +++ b/light/rpc/client.go @@ -48,7 +48,8 @@ type Client struct { prt *merkle.ProofRuntime keyPathFn KeyPathFunc - quitCh chan struct{} + closers []func() + quitCh chan struct{} } var _ rpcclient.Client = (*Client)(nil) @@ -103,7 +104,9 @@ func NewClient(next rpcclient.Client, lc LightClient, opts ...Option) *Client { func (c *Client) OnStart(ctx context.Context) error { if !c.next.IsRunning() { - return c.next.Start(ctx) + nctx, ncancel := context.WithCancel(ctx) + c.closers = append(c.closers, ncancel) + return c.next.Start(nctx) } go func() { @@ -115,10 +118,8 @@ func (c *Client) OnStart(ctx context.Context) error { } func (c *Client) OnStop() { - if c.next.IsRunning() { - if err := c.next.Stop(); err != nil { - c.Logger.Error("Error stopping on next", "err", err) - } + for _, closer := range c.closers { + closer() } } @@ -614,7 +615,10 @@ func (c *Client) RegisterOpDecoder(typ string, dec merkle.OpDecoder) { // a subscriber, but does not verify responses (UNSAFE)! // TODO: verify data func (c *Client) SubscribeWS(ctx *rpctypes.Context, query string) (*coretypes.ResultSubscribe, error) { - out, err := c.next.Subscribe(context.Background(), ctx.RemoteAddr(), query) + bctx, bcancel := context.WithCancel(context.Background()) + c.closers = append(c.closers, bcancel) + + out, err := c.next.Subscribe(bctx, ctx.RemoteAddr(), query) if err != nil { return nil, err } @@ -625,12 +629,12 @@ func (c *Client) SubscribeWS(ctx *rpctypes.Context, query string) (*coretypes.Re case resultEvent := <-out: // We should have a switch here that performs a validation // depending on the event's type. 
- ctx.WSConn.TryWriteRPCResponse( + ctx.WSConn.TryWriteRPCResponse(bctx, rpctypes.NewRPCSuccessResponse( rpctypes.JSONRPCStringID(fmt.Sprintf("%v#event", ctx.JSONReq.ID)), resultEvent, )) - case <-c.quitCh: + case <-bctx.Done(): return } } diff --git a/node/node.go b/node/node.go index b9cdea4e6..10ade3ce7 100644 --- a/node/node.go +++ b/node/node.go @@ -92,7 +92,9 @@ func newDefaultNode( return nil, fmt.Errorf("failed to load or gen node key %s: %w", cfg.NodeKeyFile(), err) } if cfg.Mode == config.ModeSeed { - return makeSeedNode(cfg, + return makeSeedNode( + ctx, + cfg, config.DefaultDBProvider, nodeKey, defaultGenesisDocProviderFunc(cfg), @@ -280,7 +282,7 @@ func makeNode( makeCloser(closers)) } - router, err := createRouter(logger, nodeMetrics.p2p, nodeInfo, nodeKey, + router, err := createRouter(ctx, logger, nodeMetrics.p2p, nodeInfo, nodeKey, peerManager, cfg, proxyApp) if err != nil { return nil, combineCloseError( @@ -288,14 +290,14 @@ func makeNode( makeCloser(closers)) } - mpReactor, mp, err := createMempoolReactor( + mpReactor, mp, err := createMempoolReactor(ctx, cfg, proxyApp, state, nodeMetrics.mempool, peerManager, router, logger, ) if err != nil { return nil, combineCloseError(err, makeCloser(closers)) } - evReactor, evPool, err := createEvidenceReactor( + evReactor, evPool, err := createEvidenceReactor(ctx, cfg, dbProvider, stateDB, blockStore, peerManager, router, logger, ) if err != nil { @@ -324,7 +326,7 @@ func makeNode( // Create the blockchain reactor. Note, we do not start block sync if we're // doing a state sync first. - bcReactor, err := createBlockchainReactor( + bcReactor, err := createBlockchainReactor(ctx, logger, state, blockExec, blockStore, csReactor, peerManager, router, blockSync && !stateSync, nodeMetrics.consensus, ) @@ -350,7 +352,7 @@ func makeNode( channels := make(map[p2p.ChannelID]*p2p.Channel, len(ssChDesc)) for idx := range ssChDesc { chd := ssChDesc[idx] - ch, err := router.OpenChannel(chd) + ch, err := router.OpenChannel(ctx, chd) if err != nil { return nil, err } @@ -369,7 +371,7 @@ func makeNode( channels[statesync.ChunkChannel], channels[statesync.LightBlockChannel], channels[statesync.ParamsChannel], - peerManager.Subscribe(), + peerManager.Subscribe(ctx), stateStore, blockStore, cfg.StateSync.TempDir, @@ -378,7 +380,7 @@ func makeNode( var pexReactor service.Service if cfg.P2P.PexReactor { - pexReactor, err = createPEXReactor(logger, peerManager, router) + pexReactor, err = createPEXReactor(ctx, logger, peerManager, router) if err != nil { return nil, combineCloseError(err, makeCloser(closers)) } @@ -441,7 +443,9 @@ func makeNode( } // makeSeedNode returns a new seed node, containing only p2p, pex reactor -func makeSeedNode(cfg *config.Config, +func makeSeedNode( + ctx context.Context, + cfg *config.Config, dbProvider config.DBProvider, nodeKey types.NodeKey, genesisDocProvider genesisDocProvider, @@ -476,7 +480,7 @@ func makeSeedNode(cfg *config.Config, closer) } - router, err := createRouter(logger, p2pMetrics, nodeInfo, nodeKey, + router, err := createRouter(ctx, logger, p2pMetrics, nodeInfo, nodeKey, peerManager, cfg, nil) if err != nil { return nil, combineCloseError( @@ -484,7 +488,7 @@ func makeSeedNode(cfg *config.Config, closer) } - pexReactor, err := createPEXReactor(logger, peerManager, router) + pexReactor, err := createPEXReactor(ctx, logger, peerManager, router) if err != nil { return nil, combineCloseError(err, closer) } @@ -510,12 +514,25 @@ func makeSeedNode(cfg *config.Config, // OnStart starts the Node. 
It implements service.Service. func (n *nodeImpl) OnStart(ctx context.Context) error { if n.config.RPC.PprofListenAddress != "" { - // this service is not cleaned up (I believe that we'd - // need to have another thread and a potentially a - // context to get this functionality.) + rpcCtx, rpcCancel := context.WithCancel(ctx) + srv := &http.Server{Addr: n.config.RPC.PprofListenAddress, Handler: nil} + go func() { + select { + case <-ctx.Done(): + sctx, scancel := context.WithTimeout(context.Background(), time.Second) + defer scancel() + _ = srv.Shutdown(sctx) + case <-rpcCtx.Done(): + } + }() + go func() { n.Logger.Info("Starting pprof server", "laddr", n.config.RPC.PprofListenAddress) - n.Logger.Error("pprof server error", "err", http.ListenAndServe(n.config.RPC.PprofListenAddress, nil)) + + if err := srv.ListenAndServe(); err != nil { + n.Logger.Error("pprof server error", "err", err) + rpcCancel() + } }() } @@ -538,7 +555,7 @@ func (n *nodeImpl) OnStart(ctx context.Context) error { if n.config.Instrumentation.Prometheus && n.config.Instrumentation.PrometheusListenAddr != "" { - n.prometheusSrv = n.startPrometheusServer(n.config.Instrumentation.PrometheusListenAddr) + n.prometheusSrv = n.startPrometheusServer(ctx, n.config.Instrumentation.PrometheusListenAddr) } // Start the transport. @@ -784,6 +801,7 @@ func (n *nodeImpl) startRPC(ctx context.Context) ([]net.Listener, error) { if n.config.RPC.IsTLSEnabled() { go func() { if err := rpcserver.ServeTLS( + ctx, listener, rootHandler, n.config.RPC.CertFile(), @@ -797,6 +815,7 @@ func (n *nodeImpl) startRPC(ctx context.Context) ([]net.Listener, error) { } else { go func() { if err := rpcserver.Serve( + ctx, listener, rootHandler, rpcLogger, @@ -815,7 +834,7 @@ func (n *nodeImpl) startRPC(ctx context.Context) ([]net.Listener, error) { // startPrometheusServer starts a Prometheus HTTP server, listening for metrics // collectors on addr. 
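The pprof hunk above and the Prometheus hunk that follows share one shape: a watcher goroutine shuts the http.Server down with a short grace period once the context is canceled, while a second goroutine serves. A self-contained sketch of that shape; the address and the one-second grace period are illustrative, not values from the patch:

```go
package main

import (
	"context"
	"log"
	"net/http"
	"time"
)

// serveWithContext runs srv until ctx is canceled, then shuts it down with a
// one-second grace period.
func serveWithContext(ctx context.Context, srv *http.Server) {
	go func() {
		<-ctx.Done()
		sctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		_ = srv.Shutdown(sctx) // makes ListenAndServe return http.ErrServerClosed
	}()

	if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		log.Printf("http server error: %v", err)
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	srv := &http.Server{Addr: "127.0.0.1:6060", Handler: http.DefaultServeMux}
	serveWithContext(ctx, srv) // returns shortly after the context expires
}
```

The sketch filters http.ErrServerClosed so an orderly, context-driven shutdown is not reported as a failure; the hunks above log whatever ListenAndServe returns.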
-func (n *nodeImpl) startPrometheusServer(addr string) *http.Server { +func (n *nodeImpl) startPrometheusServer(ctx context.Context, addr string) *http.Server { srv := &http.Server{ Addr: addr, Handler: promhttp.InstrumentMetricHandler( @@ -825,12 +844,25 @@ func (n *nodeImpl) startPrometheusServer(addr string) *http.Server { ), ), } + + promCtx, promCancel := context.WithCancel(ctx) go func() { - if err := srv.ListenAndServe(); err != http.ErrServerClosed { - // Error starting or closing listener: - n.Logger.Error("Prometheus HTTP server ListenAndServe", "err", err) + select { + case <-ctx.Done(): + sctx, scancel := context.WithTimeout(context.Background(), time.Second) + defer scancel() + _ = srv.Shutdown(sctx) + case <-promCtx.Done(): } }() + + go func() { + if err := srv.ListenAndServe(); err != nil { + n.Logger.Error("Prometheus HTTP server ListenAndServe", "err", err) + promCancel() + } + }() + return srv } diff --git a/node/node_test.go b/node/node_test.go index 16f8c44aa..d9806c9f1 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -528,7 +528,8 @@ func TestNodeNewSeedNode(t *testing.T) { nodeKey, err := types.LoadOrGenNodeKey(cfg.NodeKeyFile()) require.NoError(t, err) - ns, err := makeSeedNode(cfg, + ns, err := makeSeedNode(ctx, + cfg, config.DefaultDBProvider, nodeKey, defaultGenesisDocProviderFunc(cfg), diff --git a/node/public.go b/node/public.go index 87007bdfc..0d6f1d93e 100644 --- a/node/public.go +++ b/node/public.go @@ -68,7 +68,7 @@ func New( config.DefaultDBProvider, logger) case config.ModeSeed: - return makeSeedNode(conf, config.DefaultDBProvider, nodeKey, genProvider, logger) + return makeSeedNode(ctx, conf, config.DefaultDBProvider, nodeKey, genProvider, logger) default: return nil, fmt.Errorf("%q is not a valid mode", conf.Mode) } diff --git a/node/setup.go b/node/setup.go index ca47a9c25..910eefad6 100644 --- a/node/setup.go +++ b/node/setup.go @@ -190,6 +190,7 @@ func onlyValidatorIsUs(state sm.State, pubKey crypto.PubKey) bool { } func createMempoolReactor( + ctx context.Context, cfg *config.Config, proxyApp proxy.AppConns, state sm.State, @@ -201,7 +202,7 @@ func createMempoolReactor( logger = logger.With("module", "mempool") - ch, err := router.OpenChannel(mempool.GetChannelDescriptor(cfg.Mempool)) + ch, err := router.OpenChannel(ctx, mempool.GetChannelDescriptor(cfg.Mempool)) if err != nil { return nil, nil, err } @@ -222,7 +223,7 @@ func createMempoolReactor( peerManager, mp, ch, - peerManager.Subscribe(), + peerManager.Subscribe(ctx), ) if cfg.Consensus.WaitForTxs() { @@ -233,6 +234,7 @@ func createMempoolReactor( } func createEvidenceReactor( + ctx context.Context, cfg *config.Config, dbProvider config.DBProvider, stateDB dbm.DB, @@ -253,7 +255,7 @@ func createEvidenceReactor( return nil, nil, fmt.Errorf("creating evidence pool: %w", err) } - ch, err := router.OpenChannel(evidence.GetChannelDescriptor()) + ch, err := router.OpenChannel(ctx, evidence.GetChannelDescriptor()) if err != nil { return nil, nil, fmt.Errorf("creating evidence channel: %w", err) } @@ -261,7 +263,7 @@ func createEvidenceReactor( evidenceReactor := evidence.NewReactor( logger, ch, - peerManager.Subscribe(), + peerManager.Subscribe(ctx), evidencePool, ) @@ -269,6 +271,7 @@ func createEvidenceReactor( } func createBlockchainReactor( + ctx context.Context, logger log.Logger, state sm.State, blockExec *sm.BlockExecutor, @@ -282,12 +285,12 @@ func createBlockchainReactor( logger = logger.With("module", "blockchain") - ch, err := router.OpenChannel(blocksync.GetChannelDescriptor()) + 
ch, err := router.OpenChannel(ctx, blocksync.GetChannelDescriptor()) if err != nil { return nil, err } - peerUpdates := peerManager.Subscribe() + peerUpdates := peerManager.Subscribe(ctx) reactor, err := blocksync.NewReactor( logger, state.Copy(), blockExec, blockStore, csReactor, @@ -338,7 +341,7 @@ func createConsensusReactor( channels := make(map[p2p.ChannelID]*p2p.Channel, len(csChDesc)) for idx := range csChDesc { chd := csChDesc[idx] - ch, err := router.OpenChannel(chd) + ch, err := router.OpenChannel(ctx, chd) if err != nil { return nil, nil, err } @@ -353,7 +356,7 @@ func createConsensusReactor( channels[consensus.DataChannel], channels[consensus.VoteChannel], channels[consensus.VoteSetBitsChannel], - peerManager.Subscribe(), + peerManager.Subscribe(ctx), waitSync, consensus.ReactorMetrics(csMetrics), ) @@ -450,6 +453,7 @@ func createPeerManager( } func createRouter( + ctx context.Context, logger log.Logger, p2pMetrics *p2p.Metrics, nodeInfo types.NodeInfo, @@ -468,6 +472,7 @@ func createRouter( } return p2p.NewRouter( + ctx, p2pLogger, p2pMetrics, nodeInfo, @@ -480,17 +485,18 @@ func createRouter( } func createPEXReactor( + ctx context.Context, logger log.Logger, peerManager *p2p.PeerManager, router *p2p.Router, ) (service.Service, error) { - channel, err := router.OpenChannel(pex.ChannelDescriptor()) + channel, err := router.OpenChannel(ctx, pex.ChannelDescriptor()) if err != nil { return nil, err } - return pex.NewReactor(logger, peerManager, channel, peerManager.Subscribe()), nil + return pex.NewReactor(logger, peerManager, channel, peerManager.Subscribe(ctx)), nil } func makeNodeInfo( diff --git a/rpc/client/helpers.go b/rpc/client/helpers.go index 58e48dbba..c7908fec2 100644 --- a/rpc/client/helpers.go +++ b/rpc/client/helpers.go @@ -104,7 +104,6 @@ type RunState struct { mu sync.Mutex name string isRunning bool - quit chan struct{} } // NewRunState returns a new unstarted run state tracker with the given logging @@ -120,7 +119,7 @@ func NewRunState(name string, logger log.Logger) *RunState { } // Start sets the state to running, or reports an error. -func (r *RunState) Start() error { +func (r *RunState) Start(context.Context) error { r.mu.Lock() defer r.mu.Unlock() if r.isRunning { @@ -129,7 +128,6 @@ func (r *RunState) Start() error { } r.Logger.Info("starting client", "client", r.name) r.isRunning = true - r.quit = make(chan struct{}) return nil } @@ -143,27 +141,12 @@ func (r *RunState) Stop() error { } r.Logger.Info("stopping client", "client", r.name) r.isRunning = false - close(r.quit) return nil } -// SetLogger updates the log sink. -func (r *RunState) SetLogger(logger log.Logger) { - r.mu.Lock() - defer r.mu.Unlock() - r.Logger = logger -} - // IsRunning reports whether the state is running. func (r *RunState) IsRunning() bool { r.mu.Lock() defer r.mu.Unlock() return r.isRunning } - -// Quit returns a channel that is closed when a call to Stop succeeds. -func (r *RunState) Quit() <-chan struct{} { - r.mu.Lock() - defer r.mu.Unlock() - return r.quit -} diff --git a/rpc/client/http/ws.go b/rpc/client/http/ws.go index e4c2a14ed..dda8e4f46 100644 --- a/rpc/client/http/ws.go +++ b/rpc/client/http/ws.go @@ -86,17 +86,17 @@ func newWsEvents(remote string, wso WSOptions) (*wsEvents, error) { // resubscribe immediately w.redoSubscriptionsAfter(0 * time.Second) }) - w.ws.SetLogger(w.Logger) + w.ws.Logger = w.Logger return w, nil } // Start starts the websocket client and the event loop. 
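Throughout setup.go and the RPC client changes above, lifetimes now hang off the context handed to Subscribe, OpenChannel, or Start rather than off a quit channel. A rough sketch of a context-scoped subscription; the update type and producer loop are invented for illustration and do not correspond to the PeerManager API:

```go
package main

import (
	"context"
	"fmt"
)

type update struct{ id int }

// subscribe returns a stream of updates; the producer goroutine exits as soon
// as the supplied context is canceled, so callers "unsubscribe" by canceling
// rather than by calling a Stop or Quit method.
func subscribe(ctx context.Context) <-chan update {
	out := make(chan update)
	go func() {
		defer close(out)
		for i := 0; ; i++ {
			select {
			case <-ctx.Done():
				return
			case out <- update{id: i}:
			}
		}
	}()
	return out
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	updates := subscribe(ctx)

	fmt.Println("first update:", <-updates)
	cancel()

	// Drain whatever was produced before the goroutine observed cancellation;
	// the loop ends when the producer closes the channel.
	for range updates {
	}
}
```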
func (w *wsEvents) Start(ctx context.Context) error { - if err := w.ws.Start(); err != nil { + if err := w.ws.Start(ctx); err != nil { return err } - go w.eventListener() + go w.eventListener(ctx) return nil } @@ -216,7 +216,7 @@ func isErrAlreadySubscribed(err error) bool { return strings.Contains(err.Error(), pubsub.ErrAlreadySubscribed.Error()) } -func (w *wsEvents) eventListener() { +func (w *wsEvents) eventListener(ctx context.Context) { for { select { case resp, ok := <-w.ws.ResponsesCh: @@ -258,11 +258,11 @@ func (w *wsEvents) eventListener() { if ok { select { case out.res <- *result: - case <-w.Quit(): + case <-ctx.Done(): return } } - case <-w.Quit(): + case <-ctx.Done(): return } } diff --git a/rpc/client/interface.go b/rpc/client/interface.go index e23d2f563..9b2a600cc 100644 --- a/rpc/client/interface.go +++ b/rpc/client/interface.go @@ -37,9 +37,6 @@ type Client interface { // Start the client. Start must report an error if the client is running. Start(context.Context) error - // Stop the client. Stop must report an error if the client is not running. - Stop() error - // IsRunning reports whether the client is running. IsRunning() bool diff --git a/rpc/client/rpc_test.go b/rpc/client/rpc_test.go index b2526176a..ff6c0d57c 100644 --- a/rpc/client/rpc_test.go +++ b/rpc/client/rpc_test.go @@ -40,7 +40,7 @@ func getHTTPClient(t *testing.T, conf *config.Config) *rpchttp.HTTP { c, err := rpchttp.NewWithClient(rpcAddr, http.DefaultClient) require.NoError(t, err) - c.SetLogger(log.TestingLogger()) + c.Logger = log.TestingLogger() t.Cleanup(func() { if c.IsRunning() { require.NoError(t, c.Stop()) @@ -59,7 +59,7 @@ func getHTTPClientWithTimeout(t *testing.T, conf *config.Config, timeout time.Du c, err := rpchttp.NewWithClient(rpcAddr, http.DefaultClient) require.NoError(t, err) - c.SetLogger(log.TestingLogger()) + c.Logger = log.TestingLogger() t.Cleanup(func() { http.DefaultClient.Timeout = 0 if c.IsRunning() { @@ -471,16 +471,13 @@ func TestClientMethodCalls(t *testing.T) { t.Run("Events", func(t *testing.T) { // start for this test it if it wasn't already running if !c.IsRunning() { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + // if so, then we start it, listen, and stop it. err := c.Start(ctx) require.Nil(t, err) - t.Cleanup(func() { - if err := c.Stop(); err != nil { - t.Error(err) - } - }) } - t.Run("Header", func(t *testing.T) { evt, err := client.WaitForOneEvent(c, types.EventNewBlockHeaderValue, waitForEventTimeout) require.Nil(t, err, "%d: %+v", i, err) diff --git a/rpc/jsonrpc/client/ws_client.go b/rpc/jsonrpc/client/ws_client.go index 8d8f9d18d..51891712f 100644 --- a/rpc/jsonrpc/client/ws_client.go +++ b/rpc/jsonrpc/client/ws_client.go @@ -24,6 +24,7 @@ type WSOptions struct { ReadWait time.Duration // deadline for any read op WriteWait time.Duration // deadline for any write op PingPeriod time.Duration // frequency with which pings are sent + SkipMetrics bool // do not keep metrics for ping/pong latency } // DefaultWSOptions returns default WS options. 
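The new SkipMetrics option is consumed a little further down, where the constructor installs metrics.NilTimer instead of metrics.NewTimer; judging by the leaktest additions in the tests, the point is to keep background metrics work out of short-lived test clients. A hedged sketch of that option pattern, with simplified local types standing in for the real WSClient and the go-metrics interfaces:

```go
package main

import (
	"fmt"
	"time"
)

// Timer is a local stand-in for the metrics timer used to record
// ping/pong latency.
type Timer interface{ Update(time.Duration) }

type recordingTimer struct{}

func (recordingTimer) Update(d time.Duration) { fmt.Println("recorded", d) }

// nopTimer drops every observation, in the spirit of metrics.NilTimer.
type nopTimer struct{}

func (nopTimer) Update(time.Duration) {}

type WSOptions struct{ SkipMetrics bool }

type WSClient struct{ PingPongLatencyTimer Timer }

// NewWSWithOptions picks the timer implementation from the options instead of
// always allocating a real one.
func NewWSWithOptions(opts WSOptions) *WSClient {
	c := &WSClient{}
	if opts.SkipMetrics {
		c.PingPongLatencyTimer = nopTimer{}
	} else {
		c.PingPongLatencyTimer = recordingTimer{}
	}
	return c
}

func main() {
	NewWSWithOptions(WSOptions{SkipMetrics: true}).PingPongLatencyTimer.Update(3 * time.Millisecond)  // dropped
	NewWSWithOptions(WSOptions{SkipMetrics: false}).PingPongLatencyTimer.Update(3 * time.Millisecond) // printed
}
```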
@@ -117,8 +118,6 @@ func NewWSWithOptions(remoteAddr, endpoint string, opts WSOptions) (*WSClient, e Address: parsedURL.GetTrimmedHostWithPath(), Dialer: dialFn, Endpoint: endpoint, - PingPongLatencyTimer: metrics.NewTimer(), - maxReconnectAttempts: opts.MaxReconnectAttempts, readWait: opts.ReadWait, writeWait: opts.WriteWait, @@ -127,6 +126,14 @@ func NewWSWithOptions(remoteAddr, endpoint string, opts WSOptions) (*WSClient, e // sentIDs: make(map[types.JSONRPCIntID]bool), } + + switch opts.SkipMetrics { + case true: + c.PingPongLatencyTimer = metrics.NilTimer{} + case false: + c.PingPongLatencyTimer = metrics.NewTimer() + } + return c, nil } @@ -143,8 +150,8 @@ func (c *WSClient) String() string { } // Start dials the specified service address and starts the I/O routines. -func (c *WSClient) Start() error { - if err := c.RunState.Start(); err != nil { +func (c *WSClient) Start(ctx context.Context) error { + if err := c.RunState.Start(ctx); err != nil { return err } err := c.dial() @@ -162,8 +169,8 @@ func (c *WSClient) Start() error { // channel is unbuffered. c.backlog = make(chan rpctypes.RPCRequest, 1) - c.startReadWriteRoutines() - go c.reconnectRoutine() + c.startReadWriteRoutines(ctx) + go c.reconnectRoutine(ctx) return nil } @@ -173,6 +180,7 @@ func (c *WSClient) Stop() error { if err := c.RunState.Stop(); err != nil { return err } + // only close user-facing channels when we can't write to them c.wg.Wait() close(c.ResponsesCh) @@ -253,7 +261,7 @@ func (c *WSClient) dial() error { // reconnect tries to redial up to maxReconnectAttempts with exponential // backoff. -func (c *WSClient) reconnect() error { +func (c *WSClient) reconnect(ctx context.Context) error { attempt := uint(0) c.mtx.Lock() @@ -265,13 +273,21 @@ func (c *WSClient) reconnect() error { c.mtx.Unlock() }() + timer := time.NewTimer(0) + defer timer.Stop() + for { // nolint:gosec // G404: Use of weak random number generator jitter := time.Duration(mrand.Float64() * float64(time.Second)) // 1s == (1e9 ns) backoffDuration := jitter + ((1 << attempt) * time.Second) c.Logger.Info("reconnecting", "attempt", attempt+1, "backoff_duration", backoffDuration) - time.Sleep(backoffDuration) + timer.Reset(backoffDuration) + select { + case <-ctx.Done(): + return nil + case <-timer.C: + } err := c.dial() if err != nil { @@ -292,11 +308,11 @@ func (c *WSClient) reconnect() error { } } -func (c *WSClient) startReadWriteRoutines() { +func (c *WSClient) startReadWriteRoutines(ctx context.Context) { c.wg.Add(2) c.readRoutineQuit = make(chan struct{}) - go c.readRoutine() - go c.writeRoutine() + go c.readRoutine(ctx) + go c.writeRoutine(ctx) } func (c *WSClient) processBacklog() error { @@ -320,13 +336,15 @@ func (c *WSClient) processBacklog() error { return nil } -func (c *WSClient) reconnectRoutine() { +func (c *WSClient) reconnectRoutine(ctx context.Context) { for { select { + case <-ctx.Done(): + return case originalError := <-c.reconnectAfter: // wait until writeRoutine and readRoutine finish c.wg.Wait() - if err := c.reconnect(); err != nil { + if err := c.reconnect(ctx); err != nil { c.Logger.Error("failed to reconnect", "err", err, "original_err", originalError) if err = c.Stop(); err != nil { c.Logger.Error("failed to stop conn", "error", err) @@ -338,6 +356,8 @@ func (c *WSClient) reconnectRoutine() { LOOP: for { select { + case <-ctx.Done(): + return case <-c.reconnectAfter: default: break LOOP @@ -345,18 +365,15 @@ func (c *WSClient) reconnectRoutine() { } err := c.processBacklog() if err == nil { - c.startReadWriteRoutines() 
+ c.startReadWriteRoutines(ctx) } - - case <-c.Quit(): - return } } } // The client ensures that there is at most one writer to a connection by // executing all writes from this goroutine. -func (c *WSClient) writeRoutine() { +func (c *WSClient) writeRoutine(ctx context.Context) { var ticker *time.Ticker if c.pingPeriod > 0 { // ticker with a predefined period @@ -408,7 +425,7 @@ func (c *WSClient) writeRoutine() { c.Logger.Debug("sent ping") case <-c.readRoutineQuit: return - case <-c.Quit(): + case <-ctx.Done(): if err := c.conn.WriteMessage( websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), @@ -422,7 +439,7 @@ func (c *WSClient) writeRoutine() { // The client ensures that there is at most one reader to a connection by // executing all reads from this goroutine. -func (c *WSClient) readRoutine() { +func (c *WSClient) readRoutine(ctx context.Context) { defer func() { c.conn.Close() // err != nil { @@ -494,7 +511,8 @@ func (c *WSClient) readRoutine() { c.Logger.Info("got response", "id", response.ID, "result", response.Result) select { - case <-c.Quit(): + case <-ctx.Done(): + return case c.ResponsesCh <- response: } } diff --git a/rpc/jsonrpc/client/ws_client_test.go b/rpc/jsonrpc/client/ws_client_test.go index 208313e79..d1d6c1fed 100644 --- a/rpc/jsonrpc/client/ws_client_test.go +++ b/rpc/jsonrpc/client/ws_client_test.go @@ -5,10 +5,11 @@ import ( "encoding/json" "net/http" "net/http/httptest" - "sync" + "runtime" "testing" "time" + "github.com/fortytw2/leaktest" "github.com/gorilla/websocket" "github.com/stretchr/testify/require" @@ -64,25 +65,26 @@ func (h *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func TestWSClientReconnectsAfterReadFailure(t *testing.T) { - var wg sync.WaitGroup + t.Cleanup(leaktest.Check(t)) // start server h := &myHandler{} s := httptest.NewServer(h) defer s.Close() - c := startClient(t, "//"+s.Listener.Addr().String()) - defer c.Stop() // nolint:errcheck // ignore for tests + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - wg.Add(1) - go callWgDoneOnResult(t, c, &wg) + c := startClient(ctx, t, "//"+s.Listener.Addr().String()) + + go handleResponses(ctx, t, c) h.mtx.Lock() h.closeConnAfterRead = true h.mtx.Unlock() // results in WS read error, no send retry because write succeeded - call(t, "a", c) + call(ctx, t, "a", c) // expect to reconnect almost immediately time.Sleep(10 * time.Millisecond) @@ -91,23 +93,23 @@ func TestWSClientReconnectsAfterReadFailure(t *testing.T) { h.mtx.Unlock() // should succeed - call(t, "b", c) - - wg.Wait() + call(ctx, t, "b", c) } func TestWSClientReconnectsAfterWriteFailure(t *testing.T) { - var wg sync.WaitGroup + t.Cleanup(leaktest.Check(t)) // start server h := &myHandler{} s := httptest.NewServer(h) + defer s.Close() - c := startClient(t, "//"+s.Listener.Addr().String()) - defer c.Stop() // nolint:errcheck // ignore for tests + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - wg.Add(2) - go callWgDoneOnResult(t, c, &wg) + c := startClient(ctx, t, "//"+s.Listener.Addr().String()) + + go handleResponses(ctx, t, c) // hacky way to abort the connection before write if err := c.conn.Close(); err != nil { @@ -115,30 +117,32 @@ func TestWSClientReconnectsAfterWriteFailure(t *testing.T) { } // results in WS write error, the client should resend on reconnect - call(t, "a", c) + call(ctx, t, "a", c) // expect to reconnect almost immediately time.Sleep(10 * time.Millisecond) // should succeed - call(t, "b", c) 
- - wg.Wait() + call(ctx, t, "b", c) } func TestWSClientReconnectFailure(t *testing.T) { + t.Cleanup(leaktest.Check(t)) + // start server h := &myHandler{} s := httptest.NewServer(h) - c := startClient(t, "//"+s.Listener.Addr().String()) - defer c.Stop() // nolint:errcheck // ignore for tests + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := startClient(ctx, t, "//"+s.Listener.Addr().String()) go func() { for { select { case <-c.ResponsesCh: - case <-c.Quit(): + case <-ctx.Done(): return } } @@ -152,9 +156,9 @@ func TestWSClientReconnectFailure(t *testing.T) { // results in WS write error // provide timeout to avoid blocking - ctx, cancel := context.WithTimeout(context.Background(), wsCallTimeout) + cctx, cancel := context.WithTimeout(ctx, wsCallTimeout) defer cancel() - if err := c.Call(ctx, "a", make(map[string]interface{})); err != nil { + if err := c.Call(cctx, "a", make(map[string]interface{})); err != nil { t.Error(err) } @@ -164,7 +168,7 @@ func TestWSClientReconnectFailure(t *testing.T) { done := make(chan struct{}) go func() { // client should block on this - call(t, "b", c) + call(ctx, t, "b", c) close(done) }() @@ -178,44 +182,68 @@ func TestWSClientReconnectFailure(t *testing.T) { } func TestNotBlockingOnStop(t *testing.T) { - timeout := 2 * time.Second + t.Cleanup(leaktest.Check(t)) + + timeout := 3 * time.Second s := httptest.NewServer(&myHandler{}) - c := startClient(t, "//"+s.Listener.Addr().String()) - c.Call(context.Background(), "a", make(map[string]interface{})) // nolint:errcheck // ignore for tests + defer s.Close() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := startClient(ctx, t, "//"+s.Listener.Addr().String()) + c.Call(ctx, "a", make(map[string]interface{})) // nolint:errcheck // ignore for tests // Let the readRoutine get around to blocking time.Sleep(time.Second) passCh := make(chan struct{}) go func() { // Unless we have a non-blocking write to ResponsesCh from readRoutine // this blocks forever ont the waitgroup - err := c.Stop() - require.NoError(t, err) - passCh <- struct{}{} + cancel() + require.NoError(t, c.Stop()) + select { + case <-ctx.Done(): + case passCh <- struct{}{}: + } }() + + runtime.Gosched() // hacks: force context switch + select { case <-passCh: // Pass case <-time.After(timeout): - t.Fatalf("WSClient did failed to stop within %v seconds - is one of the read/write routines blocking?", - timeout.Seconds()) + if c.IsRunning() { + t.Fatalf("WSClient did failed to stop within %v seconds - is one of the read/write routines blocking?", + timeout.Seconds()) + } } } -func startClient(t *testing.T, addr string) *WSClient { - c, err := NewWS(addr, "/websocket") +func startClient(ctx context.Context, t *testing.T, addr string) *WSClient { + t.Helper() + opts := DefaultWSOptions() + opts.SkipMetrics = true + c, err := NewWSWithOptions(addr, "/websocket", opts) + require.Nil(t, err) - err = c.Start() + err = c.Start(ctx) require.Nil(t, err) - c.SetLogger(log.TestingLogger()) + c.Logger = log.TestingLogger() return c } -func call(t *testing.T, method string, c *WSClient) { - err := c.Call(context.Background(), method, make(map[string]interface{})) - require.NoError(t, err) +func call(ctx context.Context, t *testing.T, method string, c *WSClient) { + t.Helper() + + err := c.Call(ctx, method, make(map[string]interface{})) + if ctx.Err() == nil { + require.NoError(t, err) + } } -func callWgDoneOnResult(t *testing.T, c *WSClient, wg *sync.WaitGroup) { +func handleResponses(ctx 
context.Context, t *testing.T, c *WSClient) { + t.Helper() + for { select { case resp := <-c.ResponsesCh: @@ -224,9 +252,9 @@ func callWgDoneOnResult(t *testing.T, c *WSClient, wg *sync.WaitGroup) { return } if resp.Result != nil { - wg.Done() + return } - case <-c.Quit(): + case <-ctx.Done(): return } } diff --git a/rpc/jsonrpc/jsonrpc_test.go b/rpc/jsonrpc/jsonrpc_test.go index 5013590b6..3426c48b8 100644 --- a/rpc/jsonrpc/jsonrpc_test.go +++ b/rpc/jsonrpc/jsonrpc_test.go @@ -35,10 +35,6 @@ const ( testVal = "acbd" ) -var ( - ctx = context.Background() -) - type ResultEcho struct { Value string `json:"value"` } @@ -85,13 +81,16 @@ func EchoDataBytesResult(ctx *rpctypes.Context, v tmbytes.HexBytes) (*ResultEcho } func TestMain(m *testing.M) { - setup() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + setup(ctx) code := m.Run() os.Exit(code) } // launch unix and tcp servers -func setup() { +func setup(ctx context.Context) { logger := log.MustNewDefaultLogger(log.LogFormatPlain, log.LogLevelInfo, false) cmd := exec.Command("rm", "-f", unixSocket) @@ -115,7 +114,7 @@ func setup() { panic(err) } go func() { - if err := server.Serve(listener1, mux, tcpLogger, config); err != nil { + if err := server.Serve(ctx, listener1, mux, tcpLogger, config); err != nil { panic(err) } }() @@ -131,7 +130,7 @@ func setup() { panic(err) } go func() { - if err := server.Serve(listener2, mux2, unixLogger, config); err != nil { + if err := server.Serve(ctx, listener2, mux2, unixLogger, config); err != nil { panic(err) } }() @@ -140,7 +139,7 @@ func setup() { time.Sleep(time.Second * 2) } -func echoViaHTTP(cl client.Caller, val string) (string, error) { +func echoViaHTTP(ctx context.Context, cl client.Caller, val string) (string, error) { params := map[string]interface{}{ "arg": val, } @@ -151,7 +150,7 @@ func echoViaHTTP(cl client.Caller, val string) (string, error) { return result.Value, nil } -func echoIntViaHTTP(cl client.Caller, val int) (int, error) { +func echoIntViaHTTP(ctx context.Context, cl client.Caller, val int) (int, error) { params := map[string]interface{}{ "arg": val, } @@ -162,7 +161,7 @@ func echoIntViaHTTP(cl client.Caller, val int) (int, error) { return result.Value, nil } -func echoBytesViaHTTP(cl client.Caller, bytes []byte) ([]byte, error) { +func echoBytesViaHTTP(ctx context.Context, cl client.Caller, bytes []byte) ([]byte, error) { params := map[string]interface{}{ "arg": bytes, } @@ -173,7 +172,7 @@ func echoBytesViaHTTP(cl client.Caller, bytes []byte) ([]byte, error) { return result.Value, nil } -func echoDataBytesViaHTTP(cl client.Caller, bytes tmbytes.HexBytes) (tmbytes.HexBytes, error) { +func echoDataBytesViaHTTP(ctx context.Context, cl client.Caller, bytes tmbytes.HexBytes) (tmbytes.HexBytes, error) { params := map[string]interface{}{ "arg": bytes, } @@ -184,24 +183,24 @@ func echoDataBytesViaHTTP(cl client.Caller, bytes tmbytes.HexBytes) (tmbytes.Hex return result.Value, nil } -func testWithHTTPClient(t *testing.T, cl client.HTTPClient) { +func testWithHTTPClient(ctx context.Context, t *testing.T, cl client.HTTPClient) { val := testVal - got, err := echoViaHTTP(cl, val) + got, err := echoViaHTTP(ctx, cl, val) require.Nil(t, err) assert.Equal(t, got, val) val2 := randBytes(t) - got2, err := echoBytesViaHTTP(cl, val2) + got2, err := echoBytesViaHTTP(ctx, cl, val2) require.Nil(t, err) assert.Equal(t, got2, val2) val3 := tmbytes.HexBytes(randBytes(t)) - got3, err := echoDataBytesViaHTTP(cl, val3) + got3, err := echoDataBytesViaHTTP(ctx, cl, val3) 
require.Nil(t, err) assert.Equal(t, got3, val3) val4 := mrand.Intn(10000) - got4, err := echoIntViaHTTP(cl, val4) + got4, err := echoIntViaHTTP(ctx, cl, val4) require.Nil(t, err) assert.Equal(t, got4, val4) } @@ -265,55 +264,70 @@ func testWithWSClient(t *testing.T, cl *client.WSClient) { //------------- func TestServersAndClientsBasic(t *testing.T) { + bctx, cancel := context.WithCancel(context.Background()) + defer cancel() + serverAddrs := [...]string{tcpAddr, unixAddr} for _, addr := range serverAddrs { - cl1, err := client.NewURI(addr) - require.Nil(t, err) - fmt.Printf("=== testing server on %s using URI client", addr) - testWithHTTPClient(t, cl1) + t.Run(addr, func(t *testing.T) { + ctx, cancel := context.WithCancel(bctx) + defer cancel() - cl2, err := client.New(addr) - require.Nil(t, err) - fmt.Printf("=== testing server on %s using JSONRPC client", addr) - testWithHTTPClient(t, cl2) + cl1, err := client.NewURI(addr) + require.Nil(t, err) + fmt.Printf("=== testing server on %s using URI client", addr) + testWithHTTPClient(ctx, t, cl1) - cl3, err := client.NewWS(addr, websocketEndpoint) - require.Nil(t, err) - cl3.SetLogger(log.TestingLogger()) - err = cl3.Start() - require.Nil(t, err) - fmt.Printf("=== testing server on %s using WS client", addr) - testWithWSClient(t, cl3) - err = cl3.Stop() - require.NoError(t, err) + cl2, err := client.New(addr) + require.Nil(t, err) + fmt.Printf("=== testing server on %s using JSONRPC client", addr) + testWithHTTPClient(ctx, t, cl2) + + cl3, err := client.NewWS(addr, websocketEndpoint) + require.Nil(t, err) + cl3.Logger = log.TestingLogger() + err = cl3.Start(ctx) + require.Nil(t, err) + fmt.Printf("=== testing server on %s using WS client", addr) + testWithWSClient(t, cl3) + cancel() + }) } } func TestHexStringArg(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cl, err := client.NewURI(tcpAddr) require.Nil(t, err) // should NOT be handled as hex val := "0xabc" - got, err := echoViaHTTP(cl, val) + got, err := echoViaHTTP(ctx, cl, val) require.Nil(t, err) assert.Equal(t, got, val) } func TestQuotedStringArg(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() cl, err := client.NewURI(tcpAddr) require.Nil(t, err) // should NOT be unquoted val := "\"abc\"" - got, err := echoViaHTTP(cl, val) + got, err := echoViaHTTP(ctx, cl, val) require.Nil(t, err) assert.Equal(t, got, val) } func TestWSNewWSRPCFunc(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cl, err := client.NewWS(tcpAddr, websocketEndpoint) require.Nil(t, err) - cl.SetLogger(log.TestingLogger()) - err = cl.Start() + cl.Logger = log.TestingLogger() + err = cl.Start(ctx) require.Nil(t, err) t.Cleanup(func() { if err := cl.Stop(); err != nil { @@ -340,11 +354,14 @@ func TestWSNewWSRPCFunc(t *testing.T) { } func TestWSHandlesArrayParams(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cl, err := client.NewWS(tcpAddr, websocketEndpoint) require.Nil(t, err) - cl.SetLogger(log.TestingLogger()) - err = cl.Start() - require.Nil(t, err) + + cl.Logger = log.TestingLogger() + require.Nil(t, cl.Start(ctx)) t.Cleanup(func() { if err := cl.Stop(); err != nil { t.Error(err) @@ -370,10 +387,13 @@ func TestWSHandlesArrayParams(t *testing.T) { // TestWSClientPingPong checks that a client & server exchange pings // & pongs so connection stays alive. 
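The test updates above all follow the same recipe: each test derives its own context, hands it to Start, and lets defer cancel() (plus leaktest where goroutines are involved) replace the explicit Stop/Cleanup bookkeeping. A minimal sketch of that recipe, with a made-up fakeClient standing in for WSClient:

```go
package example

import (
	"context"
	"testing"
)

// fakeClient stands in for a context-aware client: Start spawns work that
// ends when the supplied context does.
type fakeClient struct{ done chan struct{} }

func (c *fakeClient) Start(ctx context.Context) error {
	c.done = make(chan struct{})
	go func() {
		defer close(c.done)
		<-ctx.Done()
	}()
	return nil
}

func TestClientStopsWithContext(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // replaces the old t.Cleanup(func() { c.Stop() }) bookkeeping

	c := &fakeClient{}
	if err := c.Start(ctx); err != nil {
		t.Fatal(err)
	}

	cancel() // stopping the test's context stops the client...
	<-c.done // ...and its goroutine exits, which keeps leak checkers happy
}
```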
func TestWSClientPingPong(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cl, err := client.NewWS(tcpAddr, websocketEndpoint) require.Nil(t, err) - cl.SetLogger(log.TestingLogger()) - err = cl.Start() + cl.Logger = log.TestingLogger() + err = cl.Start(ctx) require.Nil(t, err) t.Cleanup(func() { if err := cl.Stop(); err != nil { diff --git a/rpc/jsonrpc/server/http_server.go b/rpc/jsonrpc/server/http_server.go index 49e1e510e..5d6b3a355 100644 --- a/rpc/jsonrpc/server/http_server.go +++ b/rpc/jsonrpc/server/http_server.go @@ -3,6 +3,7 @@ package server import ( "bufio" + "context" "encoding/json" "errors" "fmt" @@ -50,7 +51,13 @@ func DefaultConfig() *Config { // body size to config.MaxBodyBytes. // // NOTE: This function blocks - you may want to call it in a go-routine. -func Serve(listener net.Listener, handler http.Handler, logger log.Logger, config *Config) error { +func Serve( + ctx context.Context, + listener net.Listener, + handler http.Handler, + logger log.Logger, + config *Config, +) error { logger.Info(fmt.Sprintf("Starting RPC HTTP server on %s", listener.Addr())) s := &http.Server{ Handler: RecoverAndLogHandler(maxBytesHandler{h: handler, n: config.MaxBodyBytes}, logger), @@ -58,9 +65,23 @@ func Serve(listener net.Listener, handler http.Handler, logger log.Logger, confi WriteTimeout: config.WriteTimeout, MaxHeaderBytes: config.MaxHeaderBytes, } - err := s.Serve(listener) - logger.Info("RPC HTTP server stopped", "err", err) - return err + sig := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + sctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _ = s.Shutdown(sctx) + case <-sig: + } + }() + + if err := s.Serve(listener); err != nil { + logger.Info("RPC HTTP server stopped", "err", err) + close(sig) + return err + } + return nil } // Serve creates a http.Server and calls ServeTLS with the given listener, @@ -69,6 +90,7 @@ func Serve(listener net.Listener, handler http.Handler, logger log.Logger, confi // // NOTE: This function blocks - you may want to call it in a go-routine. func ServeTLS( + ctx context.Context, listener net.Listener, handler http.Handler, certFile, keyFile string, @@ -83,10 +105,23 @@ func ServeTLS( WriteTimeout: config.WriteTimeout, MaxHeaderBytes: config.MaxHeaderBytes, } - err := s.ServeTLS(listener, certFile, keyFile) + sig := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + sctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _ = s.Shutdown(sctx) + case <-sig: + } + }() - logger.Error("RPC HTTPS server stopped", "err", err) - return err + if err := s.ServeTLS(listener, certFile, keyFile); err != nil { + logger.Error("RPC HTTPS server stopped", "err", err) + close(sig) + return err + } + return nil } // WriteRPCResponseHTTPError marshals res as JSON (with indent) and writes it diff --git a/rpc/jsonrpc/server/http_server_test.go b/rpc/jsonrpc/server/http_server_test.go index ff2776bb4..983e2eb65 100644 --- a/rpc/jsonrpc/server/http_server_test.go +++ b/rpc/jsonrpc/server/http_server_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/tls" "errors" "fmt" @@ -27,6 +28,9 @@ type sampleResult struct { func TestMaxOpenConnections(t *testing.T) { const max = 5 // max simultaneous connections + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Start the server. 
var open int32 mux := http.NewServeMux() @@ -42,7 +46,7 @@ func TestMaxOpenConnections(t *testing.T) { l, err := Listen("tcp://127.0.0.1:0", max) require.NoError(t, err) defer l.Close() - go Serve(l, mux, log.TestingLogger(), config) //nolint:errcheck // ignore for tests + go Serve(ctx, l, mux, log.TestingLogger(), config) //nolint:errcheck // ignore for tests // Make N GET calls to the server. attempts := max * 2 @@ -80,10 +84,12 @@ func TestServeTLS(t *testing.T) { fmt.Fprint(w, "some body") }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + chErr := make(chan error, 1) go func() { - // FIXME This goroutine leaks - chErr <- ServeTLS(ln, mux, "test.crt", "test.key", log.TestingLogger(), DefaultConfig()) + chErr <- ServeTLS(ctx, ln, mux, "test.crt", "test.key", log.TestingLogger(), DefaultConfig()) }() select { diff --git a/rpc/jsonrpc/server/ws_handler.go b/rpc/jsonrpc/server/ws_handler.go index 2271d03f8..a6fe0c594 100644 --- a/rpc/jsonrpc/server/ws_handler.go +++ b/rpc/jsonrpc/server/ws_handler.go @@ -87,14 +87,16 @@ func (wm *WebsocketManager) WebsocketHandler(w http.ResponseWriter, r *http.Requ // register connection logger := wm.logger.With("remote", wsConn.RemoteAddr()) - con := newWSConnection(wsConn, wm.funcMap, logger, wm.wsConnOptions...) - wm.logger.Info("New websocket connection", "remote", con.remoteAddr) - err = con.Start() // BLOCKING - if err != nil { + conn := newWSConnection(wsConn, wm.funcMap, logger, wm.wsConnOptions...) + wm.logger.Info("New websocket connection", "remote", conn.remoteAddr) + + // starting the conn is blocking + if err = conn.Start(r.Context()); err != nil { wm.logger.Error("Failed to start connection", "err", err) return } - if err := con.Stop(); err != nil { + + if err := conn.Stop(); err != nil { wm.logger.Error("error while stopping connection", "error", err) } } @@ -220,16 +222,16 @@ func ReadLimit(readLimit int64) func(*wsConnection) { } // Start starts the client service routines and blocks until there is an error. -func (wsc *wsConnection) Start() error { - if err := wsc.RunState.Start(); err != nil { +func (wsc *wsConnection) Start(ctx context.Context) error { + if err := wsc.RunState.Start(ctx); err != nil { return err } wsc.writeChan = make(chan rpctypes.RPCResponse, wsc.writeChanCapacity) // Read subscriptions/unsubscriptions to events - go wsc.readRoutine() + go wsc.readRoutine(ctx) // Write responses, BLOCKING. - wsc.writeRoutine() + wsc.writeRoutine(ctx) return nil } @@ -259,8 +261,6 @@ func (wsc *wsConnection) GetRemoteAddr() string { // It implements WSRPCConnection. It is Goroutine-safe. func (wsc *wsConnection) WriteRPCResponse(ctx context.Context, resp rpctypes.RPCResponse) error { select { - case <-wsc.Quit(): - return errors.New("connection was stopped") case <-ctx.Done(): return ctx.Err() case wsc.writeChan <- resp: @@ -271,9 +271,9 @@ func (wsc *wsConnection) WriteRPCResponse(ctx context.Context, resp rpctypes.RPC // TryWriteRPCResponse attempts to push a response to the writeChan, but does // not block. // It implements WSRPCConnection. 
It is Goroutine-safe -func (wsc *wsConnection) TryWriteRPCResponse(resp rpctypes.RPCResponse) bool { +func (wsc *wsConnection) TryWriteRPCResponse(ctx context.Context, resp rpctypes.RPCResponse) bool { select { - case <-wsc.Quit(): + case <-ctx.Done(): return false case wsc.writeChan <- resp: return true @@ -293,7 +293,7 @@ func (wsc *wsConnection) Context() context.Context { } // Read from the socket and subscribe to or unsubscribe from events -func (wsc *wsConnection) readRoutine() { +func (wsc *wsConnection) readRoutine(ctx context.Context) { // readRoutine will block until response is written or WS connection is closed writeCtx := context.Background() @@ -307,7 +307,7 @@ func (wsc *wsConnection) readRoutine() { if err := wsc.WriteRPCResponse(writeCtx, rpctypes.RPCInternalError(rpctypes.JSONRPCIntID(-1), err)); err != nil { wsc.Logger.Error("Error writing RPC response", "err", err) } - go wsc.readRoutine() + go wsc.readRoutine(ctx) } }() @@ -317,7 +317,7 @@ func (wsc *wsConnection) readRoutine() { for { select { - case <-wsc.Quit(): + case <-ctx.Done(): return default: // reset deadline for every type of message (control or data) @@ -422,7 +422,7 @@ func (wsc *wsConnection) readRoutine() { } // receives on a write channel and writes out on the socket -func (wsc *wsConnection) writeRoutine() { +func (wsc *wsConnection) writeRoutine(ctx context.Context) { pingTicker := time.NewTicker(wsc.pingPeriod) defer pingTicker.Stop() @@ -438,7 +438,7 @@ func (wsc *wsConnection) writeRoutine() { for { select { - case <-wsc.Quit(): + case <-ctx.Done(): return case <-wsc.readRoutineQuit: // error in readRoutine return diff --git a/rpc/jsonrpc/test/main.go b/rpc/jsonrpc/test/main.go index 64f9de87a..8aadc3ec6 100644 --- a/rpc/jsonrpc/test/main.go +++ b/rpc/jsonrpc/test/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "net/http" "os" @@ -29,8 +30,11 @@ func main() { logger = log.MustNewDefaultLogger(log.LogFormatPlain, log.LogLevelInfo, false) ) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Stop upon receiving SIGTERM or CTRL-C. - tmos.TrapSignal(logger, func() {}) + tmos.TrapSignal(ctx, logger, func() {}) rpcserver.RegisterRPCFuncs(mux, routes, logger) config := rpcserver.DefaultConfig() @@ -40,7 +44,7 @@ func main() { os.Exit(1) } - if err = rpcserver.Serve(listener, mux, logger, config); err != nil { + if err = rpcserver.Serve(ctx, listener, mux, logger, config); err != nil { logger.Error("rpc serve", "err", err) os.Exit(1) } diff --git a/rpc/jsonrpc/types/types.go b/rpc/jsonrpc/types/types.go index 4435c8c5d..74a4cc52b 100644 --- a/rpc/jsonrpc/types/types.go +++ b/rpc/jsonrpc/types/types.go @@ -253,7 +253,7 @@ type WSRPCConnection interface { // WriteRPCResponse writes the response onto connection (BLOCKING). WriteRPCResponse(context.Context, RPCResponse) error // TryWriteRPCResponse tries to write the response onto connection (NON-BLOCKING). - TryWriteRPCResponse(RPCResponse) bool + TryWriteRPCResponse(context.Context, RPCResponse) bool // Context returns the connection's context. 
Context() context.Context } diff --git a/test/e2e/node/main.go b/test/e2e/node/main.go index 5d25b0195..2509c9767 100644 --- a/test/e2e/node/main.go +++ b/test/e2e/node/main.go @@ -246,6 +246,10 @@ func startSigner(ctx context.Context, cfg *Config) error { + go func() { + <-ctx.Done() + s.GracefulStop() + }() if err := s.Serve(lis); err != nil { panic(err) } }() return nil diff --git a/test/e2e/runner/main.go b/test/e2e/runner/main.go index fb6ce4a8c..1144656e5 100644 --- a/test/e2e/runner/main.go +++ b/test/e2e/runner/main.go @@ -319,8 +319,7 @@ Does not run any perbutations. lctx, loadCancel := context.WithCancel(ctx) defer loadCancel() go func() { - err := Load(lctx, r, cli.testnet) - chLoadResult <- err + chLoadResult <- Load(lctx, r, cli.testnet) }() if err := Start(ctx, cli.testnet); err != nil {
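As with the abci gRPC server earlier in the patch, the ctx.Done() watcher has to be running before the blocking Serve call, otherwise GracefulStop is never armed. A minimal sketch of that ordering; grpc.Server's Serve and GracefulStop are the real API, while the listener address and timing are illustrative:

```go
package main

import (
	"context"
	"log"
	"net"
	"time"

	"google.golang.org/grpc"
)

// serveGRPC starts s on lis and stops it gracefully once ctx is canceled.
// The watcher goroutine is installed before Serve, which blocks until the
// server has fully stopped.
func serveGRPC(ctx context.Context, s *grpc.Server, lis net.Listener) {
	go func() {
		<-ctx.Done()
		s.GracefulStop() // makes Serve return nil
	}()

	if err := s.Serve(lis); err != nil {
		log.Printf("gRPC serve: %v", err)
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}

	// Cancel shortly after startup so the example shuts itself down.
	go func() {
		time.Sleep(100 * time.Millisecond)
		cancel()
	}()

	serveGRPC(ctx, grpc.NewServer(), lis) // returns after GracefulStop completes
}
```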