From 824960c565c2896b208c269f2671bc19a043f964 Mon Sep 17 00:00:00 2001 From: Sam Kleinman Date: Mon, 14 Feb 2022 08:28:29 -0500 Subject: [PATCH] libs/service: regularize Stop semantics and concurrency primitives (#7809) --- cmd/tendermint/commands/rollback_test.go | 14 ++- cmd/tendermint/commands/run_node.go | 2 +- internal/blocksync/pool_test.go | 14 +-- internal/p2p/transport_mconn.go | 36 ++++-- internal/p2p/transport_mconn_test.go | 3 - libs/service/service.go | 133 ++++++++++++----------- libs/service/service_test.go | 130 ++++++++++++++++++---- node/node.go | 1 - node/node_test.go | 9 +- 9 files changed, 226 insertions(+), 116 deletions(-) diff --git a/cmd/tendermint/commands/rollback_test.go b/cmd/tendermint/commands/rollback_test.go index 167fbc1f3..43e25915f 100644 --- a/cmd/tendermint/commands/rollback_test.go +++ b/cmd/tendermint/commands/rollback_test.go @@ -22,7 +22,9 @@ func TestRollbackIntegration(t *testing.T) { cfg, err := rpctest.CreateConfig(t.Name()) require.NoError(t, err) cfg.BaseConfig.DBBackend = "goleveldb" + app, err := e2e.NewApplication(e2e.DefaultConfig(dir)) + require.NoError(t, err) t.Run("First run", func(t *testing.T) { ctx, cancel := context.WithCancel(ctx) @@ -30,27 +32,29 @@ func TestRollbackIntegration(t *testing.T) { require.NoError(t, err) node, _, err := rpctest.StartTendermint(ctx, cfg, app, rpctest.SuppressStdout) require.NoError(t, err) + require.True(t, node.IsRunning()) time.Sleep(3 * time.Second) cancel() node.Wait() + require.False(t, node.IsRunning()) }) - t.Run("Rollback", func(t *testing.T) { + time.Sleep(time.Second) require.NoError(t, app.Rollback()) height, _, err = commands.RollbackState(cfg) - require.NoError(t, err) - + require.NoError(t, err, "%d", height) }) - t.Run("Restart", func(t *testing.T) { + require.True(t, height > 0, "%d", height) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() node2, _, err2 := rpctest.StartTendermint(ctx, cfg, app, rpctest.SuppressStdout) require.NoError(t, err2) - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() client, err := local.New(logger, node2.(local.NodeService)) require.NoError(t, err) diff --git a/cmd/tendermint/commands/run_node.go b/cmd/tendermint/commands/run_node.go index afd3ae8f1..5f39fb21e 100644 --- a/cmd/tendermint/commands/run_node.go +++ b/cmd/tendermint/commands/run_node.go @@ -117,7 +117,7 @@ func NewRunNodeCmd(nodeProvider cfg.ServiceProvider, conf *cfg.Config, logger lo return fmt.Errorf("failed to start node: %w", err) } - logger.Info("started node", "node", n.String()) + logger.Info("started node", "chain", conf.ChainID()) <-ctx.Done() return nil diff --git a/internal/blocksync/pool_test.go b/internal/blocksync/pool_test.go index 0718fee16..0306a31c0 100644 --- a/internal/blocksync/pool_test.go +++ b/internal/blocksync/pool_test.go @@ -125,7 +125,6 @@ func TestBlockPoolBasic(t *testing.T) { case err := <-errorsCh: t.Error(err) case request := <-requestsCh: - t.Logf("Pulled new BlockRequest %v", request) if request.Height == 300 { return // Done! } @@ -139,21 +138,19 @@ func TestBlockPoolTimeout(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + logger := log.TestingLogger() + start := int64(42) peers := makePeers(10, start+1, 1000) errorsCh := make(chan peerError, 1000) requestsCh := make(chan BlockRequest, 1000) - pool := NewBlockPool(log.TestingLogger(), start, requestsCh, errorsCh) + pool := NewBlockPool(logger, start, requestsCh, errorsCh) err := pool.Start(ctx) if err != nil { t.Error(err) } t.Cleanup(func() { cancel(); pool.Wait() }) - for _, peer := range peers { - t.Logf("Peer %v", peer.id) - } - // Introduce each peer. go func() { for _, peer := range peers { @@ -182,7 +179,6 @@ func TestBlockPoolTimeout(t *testing.T) { for { select { case err := <-errorsCh: - t.Log(err) // consider error to be always timeout here if _, ok := timedOut[err.peerID]; !ok { counter++ @@ -191,7 +187,9 @@ func TestBlockPoolTimeout(t *testing.T) { } } case request := <-requestsCh: - t.Logf("Pulled new BlockRequest %+v", request) + logger.Debug("received request", + "counter", counter, + "request", request) } } } diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index 46227ff8f..222dbf79c 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -138,19 +138,35 @@ func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) { return nil, errors.New("transport is not listening") } - tcpConn, err := m.listener.Accept() - if err != nil { - select { - case <-ctx.Done(): - return nil, io.EOF - case <-m.doneCh: - return nil, io.EOF - default: - return nil, err + conCh := make(chan net.Conn) + errCh := make(chan error) + go func() { + tcpConn, err := m.listener.Accept() + if err != nil { + select { + case errCh <- err: + case <-ctx.Done(): + } } + select { + case conCh <- tcpConn: + case <-ctx.Done(): + } + }() + + select { + case <-ctx.Done(): + m.listener.Close() + return nil, io.EOF + case <-m.doneCh: + m.listener.Close() + return nil, io.EOF + case err := <-errCh: + return nil, err + case tcpConn := <-conCh: + return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil } - return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil } // Dial implements Transport. diff --git a/internal/p2p/transport_mconn_test.go b/internal/p2p/transport_mconn_test.go index 0851fe0e2..0f1c2e699 100644 --- a/internal/p2p/transport_mconn_test.go +++ b/internal/p2p/transport_mconn_test.go @@ -154,9 +154,6 @@ func TestMConnTransport_Listen(t *testing.T) { t.Run(tc.endpoint.String(), func(t *testing.T) { t.Cleanup(leaktest.Check(t)) - ctx, cancel = context.WithCancel(ctx) - defer cancel() - transport := p2p.NewMConnTransport( log.TestingLogger(), conn.DefaultMConnConfig(), diff --git a/libs/service/service.go b/libs/service/service.go index b36aa1087..daeead03e 100644 --- a/libs/service/service.go +++ b/libs/service/service.go @@ -3,7 +3,7 @@ package service import ( "context" "errors" - "sync/atomic" + "sync" "github.com/tendermint/tendermint/libs/log" ) @@ -30,9 +30,6 @@ type Service interface { // Return true if the service is running IsRunning() bool - // String representation of the service - String() string - // Wait blocks until the service is stopped. Wait() } @@ -40,8 +37,6 @@ type Service interface { // Implementation describes the implementation that the // BaseService implementation wraps. type Implementation interface { - Service - // Called by the Services Start Method OnStart(context.Context) error @@ -57,12 +52,7 @@ Users can override the OnStart/OnStop methods. In the absence of errors, these methods are guaranteed to be called at most once. If OnStart returns an error, service won't be marked as started, so the user can call Start again. -A call to Reset will panic, unless OnReset is overwritten, allowing -OnStart/OnStop to be called again. - -The caller must ensure that Start and Stop are not called concurrently. - -It is ok to call Stop without calling Start first. +It is safe, but an error, to call Stop without calling Start first. Typical usage: @@ -80,23 +70,21 @@ Typical usage: } func (fs *FooService) OnStart(ctx context.Context) error { - fs.BaseService.OnStart() // Always call the overridden method. // initialize private fields // start subroutines, etc. } func (fs *FooService) OnStop() error { - fs.BaseService.OnStop() // Always call the overridden method. // close/destroy private fields // stop subroutines, etc. } */ type BaseService struct { - logger log.Logger - name string - started uint32 // atomic - stopped uint32 // atomic - quit chan struct{} + logger log.Logger + name string + mtx sync.Mutex + quit <-chan (struct{}) + cancel context.CancelFunc // The "subclass" of BaseService impl Implementation @@ -107,7 +95,6 @@ func NewBaseService(logger log.Logger, name string, impl Implementation) *BaseSe return &BaseService{ logger: logger, name: name, - quit: make(chan struct{}), impl: impl, } } @@ -116,83 +103,101 @@ func NewBaseService(logger log.Logger, name string, impl Implementation) *BaseSe // returned if the service is already running or stopped. To restart a // stopped service, call Reset. func (bs *BaseService) Start(ctx context.Context) error { - if atomic.CompareAndSwapUint32(&bs.started, 0, 1) { - if atomic.LoadUint32(&bs.stopped) == 1 { - bs.logger.Error("not starting service; already stopped", "service", bs.name, "impl", bs.impl.String()) - atomic.StoreUint32(&bs.started, 0) - return ErrAlreadyStopped - } + bs.mtx.Lock() + defer bs.mtx.Unlock() - bs.logger.Info("starting service", "service", bs.name, "impl", bs.impl.String()) + if bs.quit != nil { + return ErrAlreadyStarted + } + select { + case <-bs.quit: + return ErrAlreadyStopped + default: + bs.logger.Info("starting service", "service", bs.name, "impl", bs.name) if err := bs.impl.OnStart(ctx); err != nil { - // revert flag - atomic.StoreUint32(&bs.started, 0) return err } + // we need a separate context to ensure that we start + // a thread that will get cleaned up and that the + // Stop/Wait functions work as expected. + srvCtx, cancel := context.WithCancel(context.Background()) + bs.cancel = cancel + bs.quit = srvCtx.Done() + go func(ctx context.Context) { select { - case <-bs.quit: - // someone else explicitly called stop - // and then we shouldn't. + case <-srvCtx.Done(): + // this means stop was called manually return case <-ctx.Done(): - // if nothing is running, no need to - // shut down again. - if !bs.impl.IsRunning() { - return - } - - // the context was cancel and we - // should stop. - if err := bs.Stop(); err != nil { - bs.logger.Error("stopped service", - "err", err.Error(), - "service", bs.name, - "impl", bs.impl.String()) - } - - bs.logger.Info("stopped service", - "service", bs.name, - "impl", bs.impl.String()) + _ = bs.Stop() } + + bs.logger.Info("stopped service", + "service", bs.name) }(ctx) return nil } - - return ErrAlreadyStarted } // Stop implements Service by calling OnStop (if defined) and closing quit // channel. An error will be returned if the service is already stopped. func (bs *BaseService) Stop() error { - if atomic.CompareAndSwapUint32(&bs.stopped, 0, 1) { - if atomic.LoadUint32(&bs.started) == 0 { - bs.logger.Error("not stopping service; not started yet", "service", bs.name, "impl", bs.impl.String()) - atomic.StoreUint32(&bs.stopped, 0) - return ErrNotStarted - } + bs.mtx.Lock() + defer bs.mtx.Unlock() - bs.logger.Info("stopping service", "service", bs.name, "impl", bs.impl.String()) + if bs.quit == nil { + return ErrNotStarted + } + + select { + case <-bs.quit: + return ErrAlreadyStopped + default: + bs.logger.Info("stopping service", "service", bs.name) bs.impl.OnStop() - close(bs.quit) + bs.cancel() return nil } - - return ErrAlreadyStopped } // IsRunning implements Service by returning true or false depending on the // service's state. func (bs *BaseService) IsRunning() bool { - return atomic.LoadUint32(&bs.started) == 1 && atomic.LoadUint32(&bs.stopped) == 0 + bs.mtx.Lock() + defer bs.mtx.Unlock() + + if bs.quit == nil { + return false + } + + select { + case <-bs.quit: + return false + default: + return true + } +} + +func (bs *BaseService) getWait() <-chan struct{} { + bs.mtx.Lock() + defer bs.mtx.Unlock() + + if bs.quit == nil { + out := make(chan struct{}) + close(out) + return out + } + + return bs.quit } // Wait blocks until the service is stopped. -func (bs *BaseService) Wait() { <-bs.quit } +func (bs *BaseService) Wait() { <-bs.getWait() } // String implements Service by returning a string representation of the service. func (bs *BaseService) String() string { return bs.name } diff --git a/libs/service/service_test.go b/libs/service/service_test.go index fcc727fcc..5b1e9798f 100644 --- a/libs/service/service_test.go +++ b/libs/service/service_test.go @@ -2,45 +2,135 @@ package service import ( "context" + "sync" "testing" "time" + "github.com/fortytw2/leaktest" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/libs/log" ) type testService struct { + started bool + stopped bool + multiStopped bool + mu sync.Mutex BaseService } -func (testService) OnStop() {} -func (testService) OnStart(context.Context) error { +func (t *testService) OnStop() { + t.mu.Lock() + defer t.mu.Unlock() + if t.stopped == true { + t.multiStopped = true + } + t.stopped = true +} +func (t *testService) OnStart(context.Context) error { + t.mu.Lock() + defer t.mu.Unlock() + + t.started = true return nil } -func TestBaseServiceWait(t *testing.T) { +func (t *testService) isStarted() bool { + t.mu.Lock() + defer t.mu.Unlock() + return t.started +} + +func (t *testService) isStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + return t.stopped +} + +func (t *testService) isMultiStopped() bool { + t.mu.Lock() + defer t.mu.Unlock() + return t.multiStopped +} + +func TestBaseService(t *testing.T) { + t.Cleanup(leaktest.Check(t)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() - ts := &testService{} - ts.BaseService = *NewBaseService(logger, "TestService", ts) - err := ts.Start(ctx) - require.NoError(t, err) + t.Run("Wait", func(t *testing.T) { + wctx, wcancel := context.WithCancel(ctx) + defer wcancel() + ts := &testService{} + ts.BaseService = *NewBaseService(logger, t.Name(), ts) + err := ts.Start(wctx) + require.NoError(t, err) + require.True(t, ts.isStarted()) - waitFinished := make(chan struct{}) - go func() { - ts.Wait() - waitFinished <- struct{}{} - }() + waitFinished := make(chan struct{}) + wcancel() + go func() { + ts.Wait() + close(waitFinished) + }() - go cancel() + select { + case <-waitFinished: + assert.True(t, ts.isStopped(), "failed to stop") + assert.False(t, ts.IsRunning(), "is not running") + + case <-time.After(100 * time.Millisecond): + t.Fatal("expected Wait() to finish within 100 ms.") + } + }) + t.Run("ManualStop", func(t *testing.T) { + ts := &testService{} + ts.BaseService = *NewBaseService(logger, t.Name(), ts) + require.False(t, ts.IsRunning()) + require.False(t, ts.isStarted()) + require.NoError(t, ts.Start(ctx)) + + require.True(t, ts.isStarted()) + + require.NoError(t, ts.Stop()) + require.True(t, ts.isStopped()) + require.False(t, ts.IsRunning()) + }) + t.Run("MultiStop", func(t *testing.T) { + t.Run("SingleThreaded", func(t *testing.T) { + ts := &testService{} + ts.BaseService = *NewBaseService(logger, t.Name(), ts) + + require.NoError(t, ts.Start(ctx)) + require.True(t, ts.isStarted()) + require.NoError(t, ts.Stop()) + require.True(t, ts.isStopped()) + require.False(t, ts.isMultiStopped()) + require.Error(t, ts.Stop()) + require.False(t, ts.isMultiStopped()) + }) + t.Run("MultiThreaded", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ts := &testService{} + ts.BaseService = *NewBaseService(logger, t.Name(), ts) + + require.NoError(t, ts.Start(ctx)) + require.True(t, ts.isStarted()) + + go func() { _ = ts.Stop() }() + go cancel() + + ts.Wait() + + require.True(t, ts.isStopped()) + require.False(t, ts.isMultiStopped()) + }) + + }) - select { - case <-waitFinished: - // all good - case <-time.After(100 * time.Millisecond): - t.Fatal("expected Wait() to finish within 100 ms.") - } } diff --git a/node/node.go b/node/node.go index 84929a58c..d89238d37 100644 --- a/node/node.go +++ b/node/node.go @@ -550,7 +550,6 @@ func (n *nodeImpl) OnStart(ctx context.Context) error { // OnStop stops the Node. It implements service.Service. func (n *nodeImpl) OnStop() { n.logger.Info("Stopping Node") - for _, es := range n.eventSinks { if err := es.Stop(); err != nil { n.logger.Error("failed to stop event sink", "err", err) diff --git a/node/node_test.go b/node/node_test.go index eafc5ebdd..116319294 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -55,11 +55,10 @@ func TestNodeStartStop(t *testing.T) { n, ok := ns.(*nodeImpl) require.True(t, ok) t.Cleanup(func() { - if n.IsRunning() { - bcancel() - n.Wait() - } + bcancel() + n.Wait() }) + t.Cleanup(leaktest.CheckTimeout(t, time.Second)) require.NoError(t, n.Start(ctx)) // wait for the node to produce a block @@ -98,6 +97,7 @@ func getTestNode(ctx context.Context, t *testing.T, conf *config.Config, logger ns.Wait() } }) + t.Cleanup(leaktest.CheckTimeout(t, time.Second)) return n @@ -568,6 +568,7 @@ func TestNodeNewSeedNode(t *testing.T) { logger, ) t.Cleanup(ns.Wait) + t.Cleanup(leaktest.CheckTimeout(t, time.Second)) require.NoError(t, err) n, ok := ns.(*seedNodeImpl)