diff --git a/internal/statesync/reactor.go b/internal/statesync/reactor.go index 931e2e7fc..58f4033d8 100644 --- a/internal/statesync/reactor.go +++ b/internal/statesync/reactor.go @@ -269,7 +269,10 @@ func (r *Reactor) OnStop() { func (r *Reactor) Sync(ctx context.Context) (sm.State, error) { // We need at least two peers (for cross-referencing of light blocks) before we can // begin state sync - r.waitForEnoughPeers(ctx, 2) + if err := r.waitForEnoughPeers(ctx, 2); err != nil { + return sm.State{}, err + } + r.mtx.Lock() if r.syncer != nil { r.mtx.Unlock() @@ -288,6 +291,7 @@ func (r *Reactor) Sync(ctx context.Context) (sm.State, error) { r.stateProvider, r.snapshotCh.Out, r.chunkCh.Out, + r.snapshotCh.Done(), r.tempDir, r.metrics, ) @@ -302,10 +306,16 @@ func (r *Reactor) Sync(ctx context.Context) (sm.State, error) { requestSnapshotsHook := func() { // request snapshots from all currently connected peers - r.snapshotCh.Out <- p2p.Envelope{ + msg := p2p.Envelope{ Broadcast: true, Message: &ssproto.SnapshotsRequest{}, } + + select { + case <-ctx.Done(): + case <-r.closeCh: + case r.snapshotCh.Out <- msg: + } } state, commit, err := r.syncer.SyncAny(ctx, r.cfg.DiscoveryTime, requestSnapshotsHook) @@ -992,19 +1002,21 @@ func (r *Reactor) fetchLightBlock(height uint64) (*types.LightBlock, error) { }, nil } -func (r *Reactor) waitForEnoughPeers(ctx context.Context, numPeers int) { +func (r *Reactor) waitForEnoughPeers(ctx context.Context, numPeers int) error { + startAt := time.Now() t := time.NewTicker(200 * time.Millisecond) defer t.Stop() - for { + for r.peers.Len() < numPeers { select { case <-ctx.Done(): - return + return fmt.Errorf("operation canceled while waiting for peers after %s", time.Since(startAt)) + case <-r.closeCh: + return fmt.Errorf("shutdown while waiting for peers after %s", time.Since(startAt)) case <-t.C: - if r.peers.Len() >= numPeers { - return - } + continue } } + return nil } func (r *Reactor) initStateProvider(ctx context.Context, chainID string, initialHeight int64) error { @@ -1019,6 +1031,10 @@ func (r *Reactor) initStateProvider(ctx context.Context, chainID string, initial "trustHeight", to.Height, "useP2P", r.cfg.UseP2P) if r.cfg.UseP2P { + if err := r.waitForEnoughPeers(ctx, 2); err != nil { + return err + } + peers := r.peers.All() providers := make([]provider.Provider, len(peers)) for idx, p := range peers { diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index 98ee9d26e..41dcf3d2d 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -171,6 +171,7 @@ func setup( stateProvider, rts.snapshotOutCh, rts.chunkOutCh, + rts.snapshotChannel.Done(), "", rts.reactor.metrics, ) @@ -524,7 +525,9 @@ func TestReactor_StateProviderP2P(t *testing.T) { rts.reactor.cfg.UseP2P = true rts.reactor.cfg.TrustHeight = 1 rts.reactor.cfg.TrustHash = fmt.Sprintf("%X", chain[1].Hash()) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + rts.reactor.mtx.Lock() err := rts.reactor.initStateProvider(ctx, factory.DefaultTestChainID, 1) rts.reactor.mtx.Unlock() diff --git a/internal/statesync/syncer.go b/internal/statesync/syncer.go index 43fbf99e4..68bec6880 100644 --- a/internal/statesync/syncer.go +++ b/internal/statesync/syncer.go @@ -70,6 +70,7 @@ type syncer struct { avgChunkTime int64 lastSyncedSnapshotHeight int64 processingSnapshot *snapshot + closeCh <-chan struct{} } // newSyncer creates a new syncer. @@ -79,7 +80,9 @@ func newSyncer( conn proxy.AppConnSnapshot, connQuery proxy.AppConnQuery, stateProvider StateProvider, - snapshotCh, chunkCh chan<- p2p.Envelope, + snapshotCh chan<- p2p.Envelope, + chunkCh chan<- p2p.Envelope, + closeCh <-chan struct{}, tempDir string, metrics *Metrics, ) *syncer { @@ -95,6 +98,7 @@ func newSyncer( fetchers: cfg.Fetchers, retryTimeout: cfg.ChunkRequestTimeout, metrics: metrics, + closeCh: closeCh, } } @@ -139,10 +143,16 @@ func (s *syncer) AddSnapshot(peerID types.NodeID, snapshot *snapshot) (bool, err // single request to discover snapshots, later we may want to do retries and stuff. func (s *syncer) AddPeer(peerID types.NodeID) { s.logger.Debug("Requesting snapshots from peer", "peer", peerID) - s.snapshotCh <- p2p.Envelope{ + + msg := p2p.Envelope{ To: peerID, Message: &ssproto.SnapshotsRequest{}, } + + select { + case <-s.closeCh: + case s.snapshotCh <- msg: + } } // RemovePeer removes a peer from the pool. @@ -473,6 +483,8 @@ func (s *syncer) fetchChunks(ctx context.Context, snapshot *snapshot, chunks *ch select { case <-ctx.Done(): return + case <-s.closeCh: + return case <-time.After(2 * time.Second): continue } @@ -499,6 +511,8 @@ func (s *syncer) fetchChunks(ctx context.Context, snapshot *snapshot, chunks *ch case <-ctx.Done(): return + case <-s.closeCh: + return } ticker.Stop() @@ -522,7 +536,7 @@ func (s *syncer) requestChunk(snapshot *snapshot, chunk uint32) { "peer", peer, ) - s.chunkCh <- p2p.Envelope{ + msg := p2p.Envelope{ To: peer, Message: &ssproto.ChunkRequest{ Height: snapshot.Height, @@ -530,6 +544,11 @@ func (s *syncer) requestChunk(snapshot *snapshot, chunk uint32) { Index: chunk, }, } + + select { + case s.chunkCh <- msg: + case <-s.closeCh: + } } // verifyApp verifies the sync, checking the app hash and last block height. It returns the