diff --git a/abci/client/grpc_client.go b/abci/client/grpc_client.go
index f4cd5f3e9..ef88736ab 100644
--- a/abci/client/grpc_client.go
+++ b/abci/client/grpc_client.go
@@ -63,7 +63,7 @@ func dialerFunc(ctx context.Context, addr string) (net.Conn, error) {
 	return tmnet.Connect(addr)
 }
 
-func (cli *grpcClient) OnStart() error {
+func (cli *grpcClient) OnStart(ctx context.Context) error {
 	// This processes asynchronous request/response messages and dispatches
 	// them to callbacks.
 	go func() {
diff --git a/abci/client/mocks/client.go b/abci/client/mocks/client.go
index 664646e61..f0d82a50e 100644
--- a/abci/client/mocks/client.go
+++ b/abci/client/mocks/client.go
@@ -7,8 +7,6 @@ import (
 	abciclient "github.com/tendermint/tendermint/abci/client"
 
-	log "github.com/tendermint/tendermint/libs/log"
-
 	mock "github.com/stretchr/testify/mock"
 
 	types "github.com/tendermint/tendermint/abci/types"
@@ -636,39 +634,6 @@ func (_m *Client) OfferSnapshotSync(_a0 context.Context, _a1 types.RequestOfferS
 	return r0, r1
 }
 
-// OnReset provides a mock function with given fields:
-func (_m *Client) OnReset() error {
-	ret := _m.Called()
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func() error); ok {
-		r0 = rf()
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
-}
-
-// OnStart provides a mock function with given fields:
-func (_m *Client) OnStart() error {
-	ret := _m.Called()
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func() error); ok {
-		r0 = rf()
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
-}
-
-// OnStop provides a mock function with given fields:
-func (_m *Client) OnStop() {
-	_m.Called()
-}
-
 // QueryAsync provides a mock function with given fields: _a0, _a1
 func (_m *Client) QueryAsync(_a0 context.Context, _a1 types.RequestQuery) (*abciclient.ReqRes, error) {
 	ret := _m.Called(_a0, _a1)
@@ -731,51 +696,18 @@ func (_m *Client) Quit() <-chan struct{} {
 	return r0
 }
 
-// Reset provides a mock function with given fields:
-func (_m *Client) Reset() error {
-	ret := _m.Called()
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func() error); ok {
-		r0 = rf()
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
-}
-
-// SetLogger provides a mock function with given fields: _a0
-func (_m *Client) SetLogger(_a0 log.Logger) {
-	_m.Called(_a0)
-}
-
 // SetResponseCallback provides a mock function with given fields: _a0
 func (_m *Client) SetResponseCallback(_a0 abciclient.Callback) {
 	_m.Called(_a0)
 }
 
-// Start provides a mock function with given fields:
-func (_m *Client) Start() error {
-	ret := _m.Called()
+// Start provides a mock function with given fields: _a0
+func (_m *Client) Start(_a0 context.Context) error {
+	ret := _m.Called(_a0)
 
 	var r0 error
-	if rf, ok := ret.Get(0).(func() error); ok {
-		r0 = rf()
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
-}
-
-// Stop provides a mock function with given fields:
-func (_m *Client) Stop() error {
-	ret := _m.Called()
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func() error); ok {
-		r0 = rf()
+	if rf, ok := ret.Get(0).(func(context.Context) error); ok {
+		r0 = rf(_a0)
 	} else {
 		r0 = ret.Error(0)
 	}
diff --git a/abci/client/socket_client.go b/abci/client/socket_client.go
index 00e981123..8dfee0c8d 100644
--- a/abci/client/socket_client.go
+++ b/abci/client/socket_client.go
@@ -66,7 +66,7 @@ func NewSocketClient(logger log.Logger, addr string, mustConnect bool) Client {
 
 // OnStart implements Service by connecting to the server and spawning reading
 // and writing goroutines.
-func (cli *socketClient) OnStart() error {
+func (cli *socketClient) OnStart(ctx context.Context) error {
 	var (
 		err  error
 		conn net.Conn
@@ -85,8 +85,8 @@ func (cli *socketClient) OnStart() error {
 	}
 	cli.conn = conn
 
-	go cli.sendRequestsRoutine(conn)
-	go cli.recvResponseRoutine(conn)
+	go cli.sendRequestsRoutine(ctx, conn)
+	go cli.recvResponseRoutine(ctx, conn)
 
 	return nil
 }
@@ -114,17 +114,25 @@ func (cli *socketClient) Error() error {
 // NOTE: callback may get internally generated flush responses.
 func (cli *socketClient) SetResponseCallback(resCb Callback) {
 	cli.mtx.Lock()
+	defer cli.mtx.Unlock()
 	cli.resCb = resCb
-	cli.mtx.Unlock()
 }
 
 //----------------------------------------
 
-func (cli *socketClient) sendRequestsRoutine(conn io.Writer) {
+func (cli *socketClient) sendRequestsRoutine(ctx context.Context, conn io.Writer) {
 	bw := bufio.NewWriter(conn)
 	for {
 		select {
+		case <-ctx.Done():
+			return
+		case <-cli.Quit():
+			return
 		case reqres := <-cli.reqQueue:
+			if ctx.Err() != nil {
+				return
+			}
+
 			if reqres.C.Err() != nil {
 				cli.Logger.Debug("Request's context is done", "req", reqres.R, "err", reqres.C.Err())
 				continue
@@ -139,16 +147,16 @@ func (cli *socketClient) sendRequestsRoutine(conn io.Writer) {
 				cli.stopForError(fmt.Errorf("flush buffer: %w", err))
 				return
 			}
-
-		case <-cli.Quit():
-			return
 		}
 	}
 }
 
-func (cli *socketClient) recvResponseRoutine(conn io.Reader) {
+func (cli *socketClient) recvResponseRoutine(ctx context.Context, conn io.Reader) {
 	r := bufio.NewReader(conn)
 	for {
+		if ctx.Err() != nil {
+			return
+		}
 		var res = &types.Response{}
 		err := types.ReadMessage(r, res)
 		if err != nil {
diff --git a/abci/client/socket_client_test.go b/abci/client/socket_client_test.go
index 0e2fdffa3..a3469ddd1 100644
--- a/abci/client/socket_client_test.go
+++ b/abci/client/socket_client_test.go
@@ -18,30 +18,21 @@ import (
 	"github.com/tendermint/tendermint/libs/service"
 )
 
-var ctx = context.Background()
-
 func TestProperSyncCalls(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	app := slowApp{}
 	logger := log.TestingLogger()
 
-	s, c := setupClientServer(t, logger, app)
-	t.Cleanup(func() {
-		if err := s.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
-	t.Cleanup(func() {
-		if err := c.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
+	_, c := setupClientServer(ctx, t, logger, app)
 
 	resp := make(chan error, 1)
 	go func() {
 		// This is BeginBlockSync unrolled....
 		reqres, err := c.BeginBlockAsync(ctx, types.RequestBeginBlock{})
 		assert.NoError(t, err)
-		err = c.FlushSync(context.Background())
+		err = c.FlushSync(ctx)
 		assert.NoError(t, err)
 		res := reqres.Response.GetBeginBlock()
 		assert.NotNil(t, res)
@@ -57,52 +48,8 @@ func TestProperSyncCalls(t *testing.T) {
 	}
 }
 
-func TestHangingSyncCalls(t *testing.T) {
-	app := slowApp{}
-	logger := log.TestingLogger()
-
-	s, c := setupClientServer(t, logger, app)
-	t.Cleanup(func() {
-		if err := s.Stop(); err != nil {
-			t.Log(err)
-		}
-	})
-	t.Cleanup(func() {
-		if err := c.Stop(); err != nil {
-			t.Log(err)
-		}
-	})
-
-	resp := make(chan error, 1)
-	go func() {
-		// Start BeginBlock and flush it
-		reqres, err := c.BeginBlockAsync(ctx, types.RequestBeginBlock{})
-		assert.NoError(t, err)
-		flush, err := c.FlushAsync(ctx)
-		assert.NoError(t, err)
-		// wait 20 ms for all events to travel socket, but
-		// no response yet from server
-		time.Sleep(20 * time.Millisecond)
-		// kill the server, so the connections break
-		err = s.Stop()
-		assert.NoError(t, err)
-
-		// wait for the response from BeginBlock
-		reqres.Wait()
-		flush.Wait()
-		resp <- c.Error()
-	}()
-
-	select {
-	case <-time.After(time.Second):
-		require.Fail(t, "No response arrived")
-	case err, ok := <-resp:
-		require.True(t, ok, "Must not close channel")
-		assert.Error(t, err, "We should get EOF error")
-	}
-}
-
 func setupClientServer(
+	ctx context.Context,
 	t *testing.T,
 	logger log.Logger,
 	app types.Application,
@@ -115,12 +62,15 @@ func setupClientServer(
 	s, err := server.NewServer(logger, addr, "socket", app)
 	require.NoError(t, err)
-	err = s.Start()
-	require.NoError(t, err)
+	require.NoError(t, s.Start(ctx))
+	t.Cleanup(s.Wait)
 
 	c := abciclient.NewSocketClient(logger, addr, true)
-	err = c.Start()
-	require.NoError(t, err)
+	require.NoError(t, c.Start(ctx))
+	t.Cleanup(c.Wait)
+
+	require.True(t, s.IsRunning())
+	require.True(t, c.IsRunning())
 
 	return s, c
 }
diff --git a/abci/cmd/abci-cli/abci-cli.go b/abci/cmd/abci-cli/abci-cli.go
index 3ffc7dbfa..783c41dbb 100644
--- a/abci/cmd/abci-cli/abci-cli.go
+++ b/abci/cmd/abci-cli/abci-cli.go
@@ -2,18 +2,18 @@ package main
 
 import (
 	"bufio"
-	"context"
 	"encoding/hex"
 	"errors"
 	"fmt"
 	"io"
 	"os"
+	"os/signal"
 	"strings"
+	"syscall"
 
 	"github.com/spf13/cobra"
 
 	"github.com/tendermint/tendermint/libs/log"
-	tmos "github.com/tendermint/tendermint/libs/os"
 
 	abciclient "github.com/tendermint/tendermint/abci/client"
 	"github.com/tendermint/tendermint/abci/example/code"
@@ -29,8 +29,6 @@ import (
 var (
 	client abciclient.Client
 	logger log.Logger
-
-	ctx = context.Background()
 )
 
 // flags
@@ -71,7 +69,8 @@ var RootCmd = &cobra.Command{
 			if err != nil {
 				return err
 			}
-			if err := client.Start(); err != nil {
+
+			if err := client.Start(cmd.Context()); err != nil {
 				return err
 			}
 		}
@@ -291,23 +290,24 @@ func compose(fs []func() error) error {
 }
 
 func cmdTest(cmd *cobra.Command, args []string) error {
+	ctx := cmd.Context()
 	return compose(
 		[]func() error{
-			func() error { return servertest.InitChain(client) },
-			func() error { return servertest.Commit(client, nil) },
-			func() error { return servertest.DeliverTx(client, []byte("abc"), code.CodeTypeBadNonce, nil) },
-			func() error { return servertest.Commit(client, nil) },
-			func() error { return servertest.DeliverTx(client, []byte{0x00}, code.CodeTypeOK, nil) },
-			func() error { return servertest.Commit(client, []byte{0, 0, 0, 0, 0, 0, 0, 1}) },
-			func() error { return servertest.DeliverTx(client, []byte{0x00}, code.CodeTypeBadNonce, nil) },
-			func() error { return servertest.DeliverTx(client, []byte{0x01}, code.CodeTypeOK, nil) },
-			func() error { return servertest.DeliverTx(client, []byte{0x00, 0x02}, code.CodeTypeOK, nil) },
-			func() error { return servertest.DeliverTx(client, []byte{0x00, 0x03}, code.CodeTypeOK, nil) },
-			func() error { return servertest.DeliverTx(client, []byte{0x00, 0x00, 0x04}, code.CodeTypeOK, nil) },
+			func() error { return servertest.InitChain(ctx, client) },
+			func() error { return servertest.Commit(ctx, client, nil) },
+			func() error { return servertest.DeliverTx(ctx, client, []byte("abc"), code.CodeTypeBadNonce, nil) },
+			func() error { return servertest.Commit(ctx, client, nil) },
+			func() error { return servertest.DeliverTx(ctx, client, []byte{0x00}, code.CodeTypeOK, nil) },
+			func() error { return servertest.Commit(ctx, client, []byte{0, 0, 0, 0, 0, 0, 0, 1}) },
+			func() error { return servertest.DeliverTx(ctx, client, []byte{0x00}, code.CodeTypeBadNonce, nil) },
+			func() error { return servertest.DeliverTx(ctx, client, []byte{0x01}, code.CodeTypeOK, nil) },
+			func() error { return servertest.DeliverTx(ctx, client, []byte{0x00, 0x02}, code.CodeTypeOK, nil) },
+			func() error { return servertest.DeliverTx(ctx, client, []byte{0x00, 0x03}, code.CodeTypeOK, nil) },
+			func() error { return servertest.DeliverTx(ctx, client, []byte{0x00, 0x00, 0x04}, code.CodeTypeOK, nil) },
 			func() error {
-				return servertest.DeliverTx(client, []byte{0x00, 0x00, 0x06}, code.CodeTypeBadNonce, nil)
+				return servertest.DeliverTx(ctx, client, []byte{0x00, 0x00, 0x06}, code.CodeTypeBadNonce, nil)
 			},
-			func() error { return servertest.Commit(client, []byte{0, 0, 0, 0, 0, 0, 0, 5}) },
+			func() error { return servertest.Commit(ctx, client, []byte{0, 0, 0, 0, 0, 0, 0, 5}) },
 		})
 }
 
@@ -442,13 +442,15 @@ func cmdEcho(cmd *cobra.Command, args []string) error {
 	if len(args) > 0 {
 		msg = args[0]
 	}
-	res, err := client.EchoSync(ctx, msg)
+	res, err := client.EchoSync(cmd.Context(), msg)
 	if err != nil {
 		return err
 	}
+
 	printResponse(cmd, args, response{
 		Data: []byte(res.Message),
 	})
+
 	return nil
 }
 
@@ -458,7 +460,7 @@ func cmdInfo(cmd *cobra.Command, args []string) error {
 	if len(args) == 1 {
 		version = args[0]
 	}
-	res, err := client.InfoSync(ctx, types.RequestInfo{Version: version})
+	res, err := client.InfoSync(cmd.Context(), types.RequestInfo{Version: version})
 	if err != nil {
 		return err
 	}
@@ -483,7 +485,7 @@ func cmdDeliverTx(cmd *cobra.Command, args []string) error {
 	if err != nil {
 		return err
 	}
-	res, err := client.DeliverTxSync(ctx, types.RequestDeliverTx{Tx: txBytes})
+	res, err := client.DeliverTxSync(cmd.Context(), types.RequestDeliverTx{Tx: txBytes})
 	if err != nil {
 		return err
 	}
@@ -509,7 +511,7 @@ func cmdCheckTx(cmd *cobra.Command, args []string) error {
 	if err != nil {
 		return err
 	}
-	res, err := client.CheckTxSync(ctx, types.RequestCheckTx{Tx: txBytes})
+	res, err := client.CheckTxSync(cmd.Context(), types.RequestCheckTx{Tx: txBytes})
 	if err != nil {
 		return err
 	}
@@ -524,7 +526,7 @@ func cmdCheckTx(cmd *cobra.Command, args []string) error {
 
 // Get application Merkle root hash
 func cmdCommit(cmd *cobra.Command, args []string) error {
-	res, err := client.CommitSync(ctx)
+	res, err := client.CommitSync(cmd.Context())
 	if err != nil {
 		return err
 	}
@@ -549,7 +551,7 @@ func cmdQuery(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	resQuery, err := client.QuerySync(ctx, types.RequestQuery{
+	resQuery, err := client.QuerySync(cmd.Context(), types.RequestQuery{
 		Data:   queryBytes,
 		Path:   flagPath,
 		Height: int64(flagHeight),
@@ -590,20 +592,16 @@ func cmdKVStore(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	if err := srv.Start(); err != nil {
+	ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGTERM)
+	defer cancel()
+
+	if err := srv.Start(ctx); err != nil {
 		return err
 	}
 
-	// Stop upon receiving SIGTERM or CTRL-C.
-	tmos.TrapSignal(logger, func() {
-		// Cleanup
-		if err := srv.Stop(); err != nil {
-			logger.Error("Error while stopping server", "err", err)
-		}
-	})
-
-	// Run forever.
-	select {}
+	<-ctx.Done()
+	return nil
 }
 
 //--------------------------------------------------------------------------------
diff --git a/abci/example/example_test.go b/abci/example/example_test.go
index c984fa2fb..80d5a3130 100644
--- a/abci/example/example_test.go
+++ b/abci/example/example_test.go
@@ -29,21 +29,29 @@ func init() {
 }
 
 func TestKVStore(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	fmt.Println("### Testing KVStore")
-	testStream(t, kvstore.NewApplication())
+	testStream(ctx, t, kvstore.NewApplication())
 }
 
 func TestBaseApp(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	fmt.Println("### Testing BaseApp")
-	testStream(t, types.NewBaseApplication())
+	testStream(ctx, t, types.NewBaseApplication())
 }
 
 func TestGRPC(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	fmt.Println("### Testing GRPC")
-	testGRPCSync(t, types.NewGRPCApplication(types.NewBaseApplication()))
+	testGRPCSync(ctx, t, types.NewGRPCApplication(types.NewBaseApplication()))
 }
 
-func testStream(t *testing.T, app types.Application) {
+func testStream(ctx context.Context, t *testing.T, app types.Application) {
 	t.Helper()
 
 	const numDeliverTxs = 20000
@@ -53,25 +61,16 @@ func testStream(t *testing.T, app types.Application) {
 	logger := log.TestingLogger()
 	// Start the listener
 	server := abciserver.NewSocketServer(logger.With("module", "abci-server"), socket, app)
-
-	err := server.Start()
+	t.Cleanup(server.Wait)
+	err := server.Start(ctx)
 	require.NoError(t, err)
-	t.Cleanup(func() {
-		if err := server.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
 
 	// Connect to the socket
 	client := abciclient.NewSocketClient(log.TestingLogger().With("module", "abci-client"), socket, false)
+	t.Cleanup(client.Wait)
 
-	err = client.Start()
+	err = client.Start(ctx)
 	require.NoError(t, err)
-	t.Cleanup(func() {
-		if err := client.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
 
 	done := make(chan struct{})
 	counter := 0
@@ -100,8 +99,6 @@ func testStream(t *testing.T, app types.Application) {
 		}
 	})
 
-	ctx := context.Background()
-
 	// Write requests
 	for counter := 0; counter < numDeliverTxs; counter++ {
 		// Send request
@@ -129,7 +126,7 @@ func dialerFunc(ctx context.Context, addr string) (net.Conn, error) {
 	return tmnet.Connect(addr)
 }
 
-func testGRPCSync(t *testing.T, app types.ABCIApplicationServer) {
+func testGRPCSync(ctx context.Context, t *testing.T, app types.ABCIApplicationServer) {
 	numDeliverTxs := 2000
 	socketFile := fmt.Sprintf("/tmp/test-%08x.sock", rand.Int31n(1<<30))
 	defer os.Remove(socketFile)
@@ -138,15 +135,11 @@ func testGRPCSync(t *testing.T, app types.ABCIApplicationServer) {
 
 	// Start the listener
 	server := abciserver.NewGRPCServer(logger.With("module", "abci-server"), socket, app)
-	if err := server.Start(); err != nil {
+	if err := server.Start(ctx); err != nil {
 		t.Fatalf("Error starting GRPC server: %v", err.Error())
 	}
 
-	t.Cleanup(func() {
-		if err := server.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
+	t.Cleanup(func() { server.Wait() })
 
 	// Connect to the socket
 	conn, err := grpc.Dial(socket, grpc.WithInsecure(), grpc.WithContextDialer(dialerFunc))
diff --git a/abci/example/kvstore/kvstore_test.go b/abci/example/kvstore/kvstore_test.go
index e64e0ed9e..664e628b0 100644
--- a/abci/example/kvstore/kvstore_test.go
+++ b/abci/example/kvstore/kvstore_test.go
@@ -24,8 +24,6 @@ const (
 	testValue = "def"
 )
 
-var ctx = context.Background()
-
 func testKVStore(t *testing.T, app types.Application, tx []byte, key, value string) {
 	req := types.RequestDeliverTx{Tx: tx}
 	ar := app.DeliverTx(req)
@@ -229,101 +227,103 @@ func valsEqual(t *testing.T, vals1, vals2 []types.ValidatorUpdate) {
 	}
 }
 
-func makeSocketClientServer(app types.Application, name string) (abciclient.Client, service.Service, error) {
+func makeSocketClientServer(
+	ctx context.Context,
+	t *testing.T,
+	logger log.Logger,
+	app types.Application,
+	name string,
+) (abciclient.Client, service.Service, error) {
+
+	ctx, cancel := context.WithCancel(ctx)
+	t.Cleanup(cancel)
+
 	// Start the listener
 	socket := fmt.Sprintf("unix://%s.sock", name)
-	logger := log.TestingLogger()
 
 	server := abciserver.NewSocketServer(logger.With("module", "abci-server"), socket, app)
-	if err := server.Start(); err != nil {
+	if err := server.Start(ctx); err != nil {
+		cancel()
 		return nil, nil, err
 	}
 
 	// Connect to the socket
 	client := abciclient.NewSocketClient(logger.With("module", "abci-client"), socket, false)
-	if err := client.Start(); err != nil {
-		if err = server.Stop(); err != nil {
-			return nil, nil, err
-		}
+	if err := client.Start(ctx); err != nil {
+		cancel()
 		return nil, nil, err
 	}
 
 	return client, server, nil
 }
 
-func makeGRPCClientServer(app types.Application, name string) (abciclient.Client, service.Service, error) {
+func makeGRPCClientServer(
+	ctx context.Context,
+	t *testing.T,
+	logger log.Logger,
+	app types.Application,
+	name string,
+) (abciclient.Client, service.Service, error) {
+	ctx, cancel := context.WithCancel(ctx)
+	t.Cleanup(cancel)
+
 	// Start the listener
 	socket := fmt.Sprintf("unix://%s.sock", name)
-	logger := log.TestingLogger()
 
 	gapp := types.NewGRPCApplication(app)
 	server := abciserver.NewGRPCServer(logger.With("module", "abci-server"), socket, gapp)
-	if err := server.Start(); err != nil {
+	if err := server.Start(ctx); err != nil {
+		cancel()
 		return nil, nil, err
 	}
 
 	client := abciclient.NewGRPCClient(logger.With("module", "abci-client"), socket, true)
-	if err := client.Start(); err != nil {
-		if err := server.Stop(); err != nil {
-			return nil, nil, err
-		}
+	if err := client.Start(ctx); err != nil {
+		cancel()
 		return nil, nil, err
 	}
 	return client, server, nil
 }
 
 func TestClientServer(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	logger := log.TestingLogger()
+
 	// set up socket app
 	kvstore := NewApplication()
-	client, server, err := makeSocketClientServer(kvstore, "kvstore-socket")
+	client, server, err := makeSocketClientServer(ctx, t, logger, kvstore, "kvstore-socket")
 	require.NoError(t, err)
-	t.Cleanup(func() {
-		if err := server.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
-	t.Cleanup(func() {
-		if err := client.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
+	t.Cleanup(func() { cancel(); server.Wait() })
+	t.Cleanup(func() { cancel(); client.Wait() })
 
-	runClientTests(t, client)
+	runClientTests(ctx, t, client)
 
 	// set up grpc app
 	kvstore = NewApplication()
-	gclient, gserver, err := makeGRPCClientServer(kvstore, "/tmp/kvstore-grpc")
+	gclient, gserver, err := makeGRPCClientServer(ctx, t, logger, kvstore, "/tmp/kvstore-grpc")
 	require.NoError(t, err)
-	t.Cleanup(func() {
-		if err := gserver.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
-	t.Cleanup(func() {
-		if err := gclient.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
+	t.Cleanup(func() { cancel(); gserver.Wait() })
+	t.Cleanup(func() { cancel(); gclient.Wait() })
 
-	runClientTests(t, gclient)
+	runClientTests(ctx, t, gclient)
 }
 
-func runClientTests(t *testing.T, client abciclient.Client) {
+func runClientTests(ctx context.Context, t *testing.T, client abciclient.Client) {
 	// run some tests....
 	key := testKey
 	value := key
 	tx := []byte(key)
-	testClient(t, client, tx, key, value)
+	testClient(ctx, t, client, tx, key, value)
 
 	value = testValue
 	tx = []byte(key + "=" + value)
-	testClient(t, client, tx, key, value)
+	testClient(ctx, t, client, tx, key, value)
 }
 
-func testClient(t *testing.T, app abciclient.Client, tx []byte, key, value string) {
+func testClient(ctx context.Context, t *testing.T, app abciclient.Client, tx []byte, key, value string) {
 	ar, err := app.DeliverTxSync(ctx, types.RequestDeliverTx{Tx: tx})
 	require.NoError(t, err)
 	require.False(t, ar.IsErr(), ar)
diff --git a/abci/server/grpc_server.go b/abci/server/grpc_server.go
index 6d22f43c1..78da22cdb 100644
--- a/abci/server/grpc_server.go
+++ b/abci/server/grpc_server.go
@@ -1,6 +1,7 @@
 package server
 
 import (
+	"context"
 	"net"
 
 	"google.golang.org/grpc"
@@ -36,7 +37,7 @@ func NewGRPCServer(logger log.Logger, protoAddr string, app types.ABCIApplicatio
 }
 
 // OnStart starts the gRPC service.
-func (s *GRPCServer) OnStart() error {
+func (s *GRPCServer) OnStart(ctx context.Context) error {
 
 	ln, err := net.Listen(s.proto, s.addr)
 	if err != nil {
diff --git a/abci/server/socket_server.go b/abci/server/socket_server.go
index b24b38e38..29d912671 100644
--- a/abci/server/socket_server.go
+++ b/abci/server/socket_server.go
@@ -2,6 +2,7 @@ package server
 
 import (
 	"bufio"
+	"context"
 	"fmt"
 	"io"
 	"net"
@@ -44,14 +45,14 @@ func NewSocketServer(logger tmlog.Logger, protoAddr string, app types.Applicatio
 	return s
 }
 
-func (s *SocketServer) OnStart() error {
+func (s *SocketServer) OnStart(ctx context.Context) error {
 	ln, err := net.Listen(s.proto, s.addr)
 	if err != nil {
 		return err
 	}
 
 	s.listener = ln
-	go s.acceptConnectionsRoutine()
+	go s.acceptConnectionsRoutine(ctx)
 
 	return nil
 }
@@ -63,6 +64,7 @@ func (s *SocketServer) OnStop() {
 
 	s.connsMtx.Lock()
 	defer s.connsMtx.Unlock()
+
 	for id, conn := range s.conns {
 		delete(s.conns, id)
 		if err := conn.Close(); err != nil {
@@ -96,8 +98,13 @@ func (s *SocketServer) rmConn(connID int) error {
 	return conn.Close()
 }
 
-func (s *SocketServer) acceptConnectionsRoutine() {
+func (s *SocketServer) acceptConnectionsRoutine(ctx context.Context) {
 	for {
+		if ctx.Err() != nil {
+			return
+
+		}
+
 		// Accept a connection
 		s.Logger.Info("Waiting for new connection...")
 		conn, err := s.listener.Accept()
@@ -117,35 +124,46 @@ func (s *SocketServer) acceptConnectionsRoutine() {
 		responses := make(chan *types.Response, 1000) // A channel to buffer responses
 
 		// Read requests from conn and deal with them
-		go s.handleRequests(closeConn, conn, responses)
+		go s.handleRequests(ctx, closeConn, conn, responses)
 		// Pull responses from 'responses' and write them to conn.
-		go s.handleResponses(closeConn, conn, responses)
+		go s.handleResponses(ctx, closeConn, conn, responses)
 
 		// Wait until signal to close connection
-		go s.waitForClose(closeConn, connID)
+		go s.waitForClose(ctx, closeConn, connID)
 	}
 }
 
-func (s *SocketServer) waitForClose(closeConn chan error, connID int) {
-	err := <-closeConn
-	switch {
-	case err == io.EOF:
-		s.Logger.Error("Connection was closed by client")
-	case err != nil:
-		s.Logger.Error("Connection error", "err", err)
-	default:
-		// never happens
-		s.Logger.Error("Connection was closed")
-	}
+func (s *SocketServer) waitForClose(ctx context.Context, closeConn chan error, connID int) {
+	defer func() {
+		// Close the connection
+		if err := s.rmConn(connID); err != nil {
+			s.Logger.Error("Error closing connection", "err", err)
+		}
+	}()
 
-	// Close the connection
-	if err := s.rmConn(connID); err != nil {
-		s.Logger.Error("Error closing connection", "err", err)
+	select {
+	case <-ctx.Done():
+		return
+	case err := <-closeConn:
+		switch {
+		case err == io.EOF:
+			s.Logger.Error("Connection was closed by client")
+		case err != nil:
+			s.Logger.Error("Connection error", "err", err)
+		default:
+			// never happens
+			s.Logger.Error("Connection was closed")
+		}
	}
 }
 
 // Read requests from conn and deal with them
-func (s *SocketServer) handleRequests(closeConn chan error, conn io.Reader, responses chan<- *types.Response) {
+func (s *SocketServer) handleRequests(
+	ctx context.Context,
+	closeConn chan error,
+	conn io.Reader,
+	responses chan<- *types.Response,
+) {
 	var count int
 	var bufReader = bufio.NewReader(conn)
 
@@ -163,6 +181,9 @@ func (s *SocketServer) handleRequests(closeConn chan error, conn io.Reader, resp
 	}()
 
 	for {
+		if ctx.Err() != nil {
+			return
+		}
 		var req = &types.Request{}
 		err := types.ReadMessage(bufReader, req)
 
@@ -229,7 +250,12 @@ func (s *SocketServer) handleRequest(req *types.Request, responses chan<- *types
 }
 
 // Pull responses from 'responses' and write them to conn.
-func (s *SocketServer) handleResponses(closeConn chan error, conn io.Writer, responses <-chan *types.Response) {
+func (s *SocketServer) handleResponses(
+	ctx context.Context,
+	closeConn chan error,
+	conn io.Writer,
+	responses <-chan *types.Response,
+) {
 	bw := bufio.NewWriter(conn)
 	for res := range responses {
 		if err := types.WriteMessage(res, bw); err != nil {
diff --git a/abci/tests/client_server_test.go b/abci/tests/client_server_test.go
index 6b4750b33..2dfa68c63 100644
--- a/abci/tests/client_server_test.go
+++ b/abci/tests/client_server_test.go
@@ -1,6 +1,7 @@
 package tests
 
 import (
+	"context"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -12,19 +13,23 @@ import (
 )
 
 func TestClientServerNoAddrPrefix(t *testing.T) {
-	addr := "localhost:26658"
-	transport := "socket"
-	app := kvstore.NewApplication()
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
 
+	const (
+		addr      = "localhost:26658"
+		transport = "socket"
+	)
+	app := kvstore.NewApplication()
 	logger := log.TestingLogger()
 
 	server, err := abciserver.NewServer(logger, addr, transport, app)
 	assert.NoError(t, err, "expected no error on NewServer")
-	err = server.Start()
+	err = server.Start(ctx)
 	assert.NoError(t, err, "expected no error on server.Start")
 
 	client, err := abciclientent.NewClient(logger, addr, transport, true)
 	assert.NoError(t, err, "expected no error on NewClient")
-	err = client.Start()
+	err = client.Start(ctx)
 	assert.NoError(t, err, "expected no error on client.Start")
 }
diff --git a/abci/tests/server/client.go b/abci/tests/server/client.go
index 23adbe80d..5062083f0 100644
--- a/abci/tests/server/client.go
+++ b/abci/tests/server/client.go
@@ -12,9 +12,7 @@ import (
 	tmrand "github.com/tendermint/tendermint/libs/rand"
 )
 
-var ctx = context.Background()
-
-func InitChain(client abciclient.Client) error {
+func InitChain(ctx context.Context, client abciclient.Client) error {
 	total := 10
 	vals := make([]types.ValidatorUpdate, total)
 	for i := 0; i < total; i++ {
@@ -34,7 +32,7 @@ func InitChain(client abciclient.Client) error {
 	return nil
 }
 
-func Commit(client abciclient.Client, hashExp []byte) error {
+func Commit(ctx context.Context, client abciclient.Client, hashExp []byte) error {
 	res, err := client.CommitSync(ctx)
 	data := res.Data
 	if err != nil {
@@ -51,7 +49,7 @@ func Commit(client abciclient.Client, hashExp []byte) error {
 	return nil
 }
 
-func DeliverTx(client abciclient.Client, txBytes []byte, codeExp uint32, dataExp []byte) error {
+func DeliverTx(ctx context.Context, client abciclient.Client, txBytes []byte, codeExp uint32, dataExp []byte) error {
 	res, _ := client.DeliverTxSync(ctx, types.RequestDeliverTx{Tx: txBytes})
 	code, data, log := res.Code, res.Data, res.Log
 	if code != codeExp {
@@ -70,7 +68,7 @@ func DeliverTx(client abciclient.Client, txBytes []byte, codeExp uint32, dataExp
 	return nil
 }
 
-func CheckTx(client abciclient.Client, txBytes []byte, codeExp uint32, dataExp []byte) error {
+func CheckTx(ctx context.Context, client abciclient.Client, txBytes []byte, codeExp uint32, dataExp []byte) error {
 	res, _ := client.CheckTxSync(ctx, types.RequestCheckTx{Tx: txBytes})
 	code, data, log := res.Code, res.Data, res.Log
 	if code != codeExp {
diff --git a/cmd/tendermint/commands/light.go b/cmd/tendermint/commands/light.go
index f4c9a21da..0e1894ccf 100644
--- a/cmd/tendermint/commands/light.go
+++ b/cmd/tendermint/commands/light.go
@@ -6,8 +6,10 @@ import (
 	"fmt"
 	"net/http"
 	"os"
+	"os/signal"
 	"path/filepath"
 	"strings"
+	"syscall"
 	"time"
 
 	"github.com/spf13/cobra"
@@ -191,8 +193,12 @@ func runProxy(cmd *cobra.Command, args []string) error {
 		p.Listener.Close()
 	})
 
+	// this might be redundant to the above, eventually.
+	ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGTERM)
+	defer cancel()
+
 	logger.Info("Starting proxy...", "laddr", listenAddr)
-	if err := p.ListenAndServe(); err != http.ErrServerClosed {
+	if err := p.ListenAndServe(ctx); err != http.ErrServerClosed {
 		// Error starting or closing listener:
 		logger.Error("proxy ListenAndServe", "err", err)
 	}
diff --git a/cmd/tendermint/commands/replay.go b/cmd/tendermint/commands/replay.go
index 558208ab3..2cd4c966a 100644
--- a/cmd/tendermint/commands/replay.go
+++ b/cmd/tendermint/commands/replay.go
@@ -10,8 +10,7 @@ var ReplayCmd = &cobra.Command{
 	Use:   "replay",
 	Short: "Replay messages from WAL",
 	RunE: func(cmd *cobra.Command, args []string) error {
-		return consensus.RunReplayFile(logger, config.BaseConfig, config.Consensus, false)
-
+		return consensus.RunReplayFile(cmd.Context(), logger, config.BaseConfig, config.Consensus, false)
 	},
 }
 
@@ -21,6 +20,6 @@ var ReplayConsoleCmd = &cobra.Command{
 	Use:   "replay-console",
 	Short: "Replay messages from WAL in a console",
 	RunE: func(cmd *cobra.Command, args []string) error {
-		return consensus.RunReplayFile(logger, config.BaseConfig, config.Consensus, true)
+		return consensus.RunReplayFile(cmd.Context(), logger, config.BaseConfig, config.Consensus, true)
 	},
 }
diff --git a/cmd/tendermint/commands/run_node.go b/cmd/tendermint/commands/run_node.go
index d5a3de04e..feffbc2d0 100644
--- a/cmd/tendermint/commands/run_node.go
+++ b/cmd/tendermint/commands/run_node.go
@@ -6,11 +6,12 @@ import (
 	"fmt"
 	"io"
 	"os"
+	"os/signal"
+	"syscall"
 
 	"github.com/spf13/cobra"
 
 	cfg "github.com/tendermint/tendermint/config"
-	tmos "github.com/tendermint/tendermint/libs/os"
 )
 
 var (
@@ -103,28 +104,22 @@ func NewRunNodeCmd(nodeProvider cfg.ServiceProvider) *cobra.Command {
 				return err
 			}
 
-			n, err := nodeProvider(config, logger)
+			ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGTERM)
+			defer cancel()
+
+			n, err := nodeProvider(ctx, config, logger)
 			if err != nil {
 				return fmt.Errorf("failed to create node: %w", err)
 			}
 
-			if err := n.Start(); err != nil {
+			if err := n.Start(ctx); err != nil {
 				return fmt.Errorf("failed to start node: %w", err)
 			}
 
 			logger.Info("started node", "node", n.String())
 
-			// Stop upon receiving SIGTERM or CTRL-C.
-			tmos.TrapSignal(logger, func() {
-				if n.IsRunning() {
-					if err := n.Stop(); err != nil {
-						logger.Error("unable to stop the node", "error", err)
-					}
-				}
-			})
-
-			// Run forever.
-			select {}
+			<-ctx.Done()
+			return nil
 		},
 	}
 
diff --git a/config/db.go b/config/db.go
index 8f489a87a..f508354e0 100644
--- a/config/db.go
+++ b/config/db.go
@@ -1,6 +1,8 @@
 package config
 
 import (
+	"context"
+
 	dbm "github.com/tendermint/tm-db"
 
 	"github.com/tendermint/tendermint/libs/log"
@@ -8,7 +10,7 @@ import (
 )
 
 // ServiceProvider takes a config and a logger and returns a ready to go Node.
-type ServiceProvider func(*Config, log.Logger) (service.Service, error)
+type ServiceProvider func(context.Context, *Config, log.Logger) (service.Service, error)
 
 // DBContext specifies config information for loading a new DB.
 type DBContext struct {
diff --git a/docs/architecture/adr-006-trust-metric.md b/docs/architecture/adr-006-trust-metric.md
index 6fa77a609..608978207 100644
--- a/docs/architecture/adr-006-trust-metric.md
+++ b/docs/architecture/adr-006-trust-metric.md
@@ -178,7 +178,7 @@ type TrustMetricStore struct {
 }
 
 // OnStart implements Service
-func (tms *TrustMetricStore) OnStart() error {}
+func (tms *TrustMetricStore) OnStart(context.Context) error { return nil }
 
 // OnStop implements Service
 func (tms *TrustMetricStore) OnStop() {}
diff --git a/internal/blocksync/pool.go b/internal/blocksync/pool.go
index 66ed24a79..6f06c9883 100644
--- a/internal/blocksync/pool.go
+++ b/internal/blocksync/pool.go
@@ -1,6 +1,7 @@
 package blocksync
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"math"
@@ -116,15 +117,15 @@ func NewBlockPool(
 
 // OnStart implements service.Service by spawning requesters routine and recording
 // pool's start time.
-func (pool *BlockPool) OnStart() error {
+func (pool *BlockPool) OnStart(ctx context.Context) error {
 	pool.lastAdvance = time.Now()
 	pool.lastHundredBlockTimeStamp = pool.lastAdvance
-	go pool.makeRequestersRoutine()
+	go pool.makeRequestersRoutine(ctx)
 
 	return nil
 }
 
 // spawns requesters as needed
-func (pool *BlockPool) makeRequestersRoutine() {
+func (pool *BlockPool) makeRequestersRoutine(ctx context.Context) {
 	for {
 		if !pool.IsRunning() {
 			break
@@ -144,7 +145,7 @@ func (pool *BlockPool) makeRequestersRoutine() {
 			pool.removeTimedoutPeers()
 		default:
 			// request for more blocks.
-			pool.makeNextRequester()
+			pool.makeNextRequester(ctx)
 		}
 	}
 }
@@ -397,7 +398,7 @@ func (pool *BlockPool) pickIncrAvailablePeer(height int64) *bpPeer {
 	return nil
 }
 
-func (pool *BlockPool) makeNextRequester() {
+func (pool *BlockPool) makeNextRequester(ctx context.Context) {
 	pool.mtx.Lock()
 	defer pool.mtx.Unlock()
 
@@ -411,7 +412,7 @@ func (pool *BlockPool) makeNextRequester() {
 	pool.requesters[nextHeight] = request
 	atomic.AddInt32(&pool.numPending, 1)
 
-	err := request.Start()
+	err := request.Start(ctx)
 	if err != nil {
 		request.Logger.Error("Error starting request", "err", err)
 	}
@@ -570,7 +571,7 @@ func newBPRequester(pool *BlockPool, height int64) *bpRequester {
 	return bpr
 }
 
-func (bpr *bpRequester) OnStart() error {
+func (bpr *bpRequester) OnStart(ctx context.Context) error {
 	go bpr.requestRoutine()
 	return nil
 }
diff --git a/internal/blocksync/pool_test.go b/internal/blocksync/pool_test.go
index b53699f97..0718fee16 100644
--- a/internal/blocksync/pool_test.go
+++ b/internal/blocksync/pool_test.go
@@ -1,6 +1,7 @@
 package blocksync
 
 import (
+	"context"
 	"fmt"
 	mrand "math/rand"
 	"testing"
@@ -78,22 +79,20 @@ func makePeers(numPeers int, minHeight, maxHeight int64) testPeers {
 }
 
 func TestBlockPoolBasic(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	start := int64(42)
 	peers := makePeers(10, start+1, 1000)
 	errorsCh := make(chan peerError, 1000)
 	requestsCh := make(chan BlockRequest, 1000)
 	pool := NewBlockPool(log.TestingLogger(), start, requestsCh, errorsCh)
 
-	err := pool.Start()
-	if err != nil {
+	if err := pool.Start(ctx); err != nil {
 		t.Error(err)
 	}
 
-	t.Cleanup(func() {
-		if err := pool.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
+	t.Cleanup(func() { cancel(); pool.Wait() })
 
 	peers.start()
 	defer peers.stop()
@@ -137,20 +136,19 @@ func TestBlockPoolBasic(t *testing.T) {
 }
 
 func TestBlockPoolTimeout(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	start := int64(42)
 	peers := makePeers(10, start+1, 1000)
 	errorsCh := make(chan peerError, 1000)
	requestsCh := make(chan BlockRequest, 1000)
 	pool := NewBlockPool(log.TestingLogger(), start, requestsCh, errorsCh)
-	err := pool.Start()
+	err := pool.Start(ctx)
 	if err != nil {
 		t.Error(err)
 	}
 
-	t.Cleanup(func() {
-		if err := pool.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
+	t.Cleanup(func() { cancel(); pool.Wait() })
 
 	for _, peer := range peers {
 		t.Logf("Peer %v", peer.id)
@@ -199,6 +197,9 @@ func TestBlockPoolTimeout(t *testing.T) {
 }
 
 func TestBlockPoolRemovePeer(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	peers := make(testPeers, 10)
 	for i := 0; i < 10; i++ {
 		peerID := types.NodeID(fmt.Sprintf("%d", i+1))
@@ -209,13 +210,9 @@ func TestBlockPoolRemovePeer(t *testing.T) {
 	errorsCh := make(chan peerError)
 
 	pool := NewBlockPool(log.TestingLogger(), 1, requestsCh, errorsCh)
-	err := pool.Start()
+	err := pool.Start(ctx)
 	require.NoError(t, err)
-	t.Cleanup(func() {
-		if err := pool.Stop(); err != nil {
-			t.Error(err)
-		}
-	})
+	t.Cleanup(func() { cancel(); pool.Wait() })
 
 	// add peers
 	for peerID, peer := range peers {
diff --git a/internal/blocksync/reactor.go b/internal/blocksync/reactor.go
index f18ed86b7..a6845b719 100644
--- a/internal/blocksync/reactor.go
+++ b/internal/blocksync/reactor.go
@@ -1,6 +1,7 @@
 package blocksync
 
 import (
+	"context"
 	"fmt"
 	"runtime/debug"
 	"sync"
@@ -49,7 +50,7 @@ func GetChannelDescriptor() *p2p.ChannelDescriptor {
 type consensusReactor interface {
 	// For when we switch from block sync reactor to the consensus
 	// machine.
-	SwitchToConsensus(state sm.State, skipWAL bool)
+	SwitchToConsensus(ctx context.Context, state sm.State, skipWAL bool)
 }
 
 type peerError struct {
@@ -151,9 +152,9 @@ func NewReactor(
 //
 // If blockSync is enabled, we also start the pool and the pool processing
 // goroutine. If the pool fails to start, an error is returned.
-func (r *Reactor) OnStart() error {
+func (r *Reactor) OnStart(ctx context.Context) error {
 	if r.blockSync.IsSet() {
-		if err := r.pool.Start(); err != nil {
+		if err := r.pool.Start(ctx); err != nil {
 			return err
 		}
 		r.poolWG.Add(1)
@@ -362,12 +363,12 @@ func (r *Reactor) processPeerUpdates() {
 
 // SwitchToBlockSync is called by the state sync reactor when switching to fast
 // sync.
-func (r *Reactor) SwitchToBlockSync(state sm.State) error {
+func (r *Reactor) SwitchToBlockSync(ctx context.Context, state sm.State) error {
 	r.blockSync.Set()
 	r.initialState = state
 	r.pool.height = state.LastBlockHeight + 1
 
-	if err := r.pool.Start(); err != nil {
+	if err := r.pool.Start(ctx); err != nil {
 		return err
 	}
 
@@ -423,6 +424,17 @@ func (r *Reactor) requestRoutine() {
 	}
 }
 
+func (r *Reactor) stopCtx() context.Context {
+	ctx, cancel := context.WithCancel(context.Background())
+
+	go func() {
+		<-r.closeCh
+		cancel()
+	}()
+
+	return ctx
+}
+
 // poolRoutine handles messages from the poolReactor telling the reactor what to
 // do.
 //
@@ -441,6 +453,7 @@ func (r *Reactor) poolRoutine(stateSynced bool) {
 		lastRate     = 0.0
 
 		didProcessCh = make(chan struct{}, 1)
+		ctx          = r.stopCtx()
 	)
 
 	defer trySyncTicker.Stop()
@@ -488,7 +501,7 @@ FOR_LOOP:
 			r.blockSync.UnSet()
 
 			if r.consReactor != nil {
-				r.consReactor.SwitchToConsensus(state, blocksSynced > 0 || stateSynced)
+				r.consReactor.SwitchToConsensus(ctx, state, blocksSynced > 0 || stateSynced)
 			}
 
 			break FOR_LOOP
diff --git a/internal/blocksync/reactor_test.go b/internal/blocksync/reactor_test.go
index 2b567aae7..5792d9e78 100644
--- a/internal/blocksync/reactor_test.go
+++ b/internal/blocksync/reactor_test.go
@@ -1,6 +1,7 @@
 package blocksync
 
 import (
+	"context"
 	"os"
 	"testing"
 	"time"
@@ -41,6 +42,7 @@ type reactorTestSuite struct {
 }
 
 func setup(
+	ctx context.Context,
 	t *testing.T,
 	genDoc *types.GenesisDoc,
 	privVal types.PrivValidator,
@@ -49,13 +51,16 @@ func setup(
 ) *reactorTestSuite {
 	t.Helper()
 
+	var cancel context.CancelFunc
+	ctx, cancel = context.WithCancel(ctx)
+
 	numNodes := len(maxBlockHeights)
 	require.True(t, numNodes >= 1, "must specify at least one block height (nodes)")
 
 	rts := &reactorTestSuite{
 		logger:      log.TestingLogger().With("module", "block_sync", "testCase", t.Name()),
-		network:     p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: numNodes}),
+		network:     p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: numNodes}),
 		nodes:       make([]types.NodeID, 0, numNodes),
 		reactors:    make(map[types.NodeID]*Reactor, numNodes),
 		app:         make(map[types.NodeID]proxy.AppConns, numNodes),
@@ -70,17 +75,19 @@ func setup(
 
 	i := 0
 	for nodeID := range rts.network.Nodes {
-		rts.addNode(t, nodeID, genDoc, privVal, maxBlockHeights[i])
+		rts.addNode(ctx, t, nodeID, genDoc, privVal, maxBlockHeights[i])
 		i++
 	}
 
 	t.Cleanup(func() {
+		cancel()
 		for _, nodeID := range rts.nodes {
 			rts.peerUpdates[nodeID].Close()
 
 			if rts.reactors[nodeID].IsRunning() {
-				require.NoError(t, rts.reactors[nodeID].Stop())
-				require.NoError(t, rts.app[nodeID].Stop())
+				rts.reactors[nodeID].Wait()
+				rts.app[nodeID].Wait()
+
 				require.False(t, rts.reactors[nodeID].IsRunning())
 			}
 		}
@@ -89,7 +96,9 @@ func setup(
 	return rts
 }
 
-func (rts *reactorTestSuite) addNode(t *testing.T,
+func (rts *reactorTestSuite) addNode(
+	ctx context.Context,
+	t *testing.T,
 	nodeID types.NodeID,
 	genDoc *types.GenesisDoc,
 	privVal types.PrivValidator,
@@ -101,7 +110,7 @@ func (rts *reactorTestSuite) addNode(t *testing.T,
 	rts.nodes = append(rts.nodes, nodeID)
 
 	rts.app[nodeID] = proxy.NewAppConns(abciclient.NewLocalCreator(&abci.BaseApplication{}), logger, proxy.NopMetrics())
-	require.NoError(t, rts.app[nodeID].Start())
+	require.NoError(t, rts.app[nodeID].Start(ctx))
 
 	blockDB := dbm.NewMemDB()
 	stateDB := dbm.NewMemDB()
@@ -170,7 +179,7 @@ func (rts *reactorTestSuite) addNode(t *testing.T,
 		consensus.NopMetrics())
 	require.NoError(t, err)
 
-	require.NoError(t, rts.reactors[nodeID].Start())
+	require.NoError(t, rts.reactors[nodeID].Start(ctx))
 	require.True(t, rts.reactors[nodeID].IsRunning())
 }
 
@@ -184,6 +193,9 @@ func (rts *reactorTestSuite) start(t *testing.T) {
 }
 
 func TestReactor_AbruptDisconnect(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	cfg, err := config.ResetTestRoot("block_sync_reactor_test")
 	require.NoError(t, err)
 	defer os.RemoveAll(cfg.RootDir)
@@ -191,7 +203,7 @@ func TestReactor_AbruptDisconnect(t *testing.T) {
 	genDoc, privVals := factory.RandGenesisDoc(cfg, 1, false, 30)
 	maxBlockHeight := int64(64)
 
-	rts := setup(t, genDoc, privVals[0], []int64{maxBlockHeight, 0}, 0)
+	rts := setup(ctx, t, genDoc, privVals[0], []int64{maxBlockHeight, 0}, 0)
 
 	require.Equal(t, maxBlockHeight, rts.reactors[rts.nodes[0]].store.Height())
 
@@ -220,6 +232,9 @@ func TestReactor_AbruptDisconnect(t *testing.T) {
 }
 
 func TestReactor_SyncTime(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	cfg, err := config.ResetTestRoot("block_sync_reactor_test")
 	require.NoError(t, err)
 	defer os.RemoveAll(cfg.RootDir)
@@ -227,7 +242,7 @@ func TestReactor_SyncTime(t *testing.T) {
 	genDoc, privVals := factory.RandGenesisDoc(cfg, 1, false, 30)
 	maxBlockHeight := int64(101)
 
-	rts := setup(t, genDoc, privVals[0], []int64{maxBlockHeight, 0}, 0)
+	rts := setup(ctx, t, genDoc, privVals[0], []int64{maxBlockHeight, 0}, 0)
 	require.Equal(t, maxBlockHeight, rts.reactors[rts.nodes[0]].store.Height())
 	rts.start(t)
 
@@ -244,6 +259,9 @@ func TestReactor_SyncTime(t *testing.T) {
 }
 
 func TestReactor_NoBlockResponse(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	cfg, err := config.ResetTestRoot("block_sync_reactor_test")
 	require.NoError(t, err)
 
@@ -252,7 +270,7 @@ func TestReactor_NoBlockResponse(t *testing.T) {
 	genDoc, privVals := factory.RandGenesisDoc(cfg, 1, false, 30)
 	maxBlockHeight := int64(65)
 
-	rts := setup(t, genDoc, privVals[0], []int64{maxBlockHeight, 0}, 0)
+	rts := setup(ctx, t, genDoc, privVals[0], []int64{maxBlockHeight, 0}, 0)
 
 	require.Equal(t, maxBlockHeight, rts.reactors[rts.nodes[0]].store.Height())
 
@@ -293,6 +311,9 @@ func TestReactor_BadBlockStopsPeer(t *testing.T) {
 	// See: https://github.com/tendermint/tendermint/issues/6005
 	t.SkipNow()
 
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	cfg, err := config.ResetTestRoot("block_sync_reactor_test")
 	require.NoError(t, err)
 	defer os.RemoveAll(cfg.RootDir)
@@ -300,7 +321,7 @@ func TestReactor_BadBlockStopsPeer(t *testing.T) {
 	maxBlockHeight := int64(48)
 	genDoc, privVals := factory.RandGenesisDoc(cfg, 1, false, 30)
 
-	rts := setup(t, genDoc, privVals[0], []int64{maxBlockHeight, 0, 0, 0, 0}, 1000)
+	rts := setup(ctx, t, genDoc, privVals[0], []int64{maxBlockHeight, 0, 0, 0, 0}, 1000)
 
 	require.Equal(t, maxBlockHeight, rts.reactors[rts.nodes[0]].store.Height())
 
@@ -333,11 +354,11 @@ func TestReactor_BadBlockStopsPeer(t *testing.T) {
 	// XXX: This causes a potential race condition.
 	// See: https://github.com/tendermint/tendermint/issues/6005
 	otherGenDoc, otherPrivVals := factory.RandGenesisDoc(cfg, 1, false, 30)
-	newNode := rts.network.MakeNode(t, p2ptest.NodeOptions{
+	newNode := rts.network.MakeNode(ctx, t, p2ptest.NodeOptions{
 		MaxPeers:     uint16(len(rts.nodes) + 1),
 		MaxConnected: uint16(len(rts.nodes) + 1),
 	})
-	rts.addNode(t, newNode.NodeID, otherGenDoc, otherPrivVals[0], maxBlockHeight)
+	rts.addNode(ctx, t, newNode.NodeID, otherGenDoc, otherPrivVals[0], maxBlockHeight)
 
 	// add a fake peer just so we do not wait for the consensus ticker to timeout
 	rts.reactors[newNode.NodeID].pool.SetPeerRange("00ff", 10, 10)
diff --git a/internal/consensus/byzantine_test.go b/internal/consensus/byzantine_test.go
index df11067b9..d9e5e46b8 100644
--- a/internal/consensus/byzantine_test.go
+++ b/internal/consensus/byzantine_test.go
@@ -31,6 +31,9 @@ import (
 // Byzantine node sends two different prevotes (nil and blockID) to the same
 // validator.
 func TestByzantinePrevoteEquivocation(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	config := configSetup(t)
 
 	nValidators := 4
@@ -93,7 +96,7 @@ func TestByzantinePrevoteEquivocation(t *testing.T) {
 		cs.SetPrivValidator(pv)
 
 		eventBus := eventbus.NewDefault(log.TestingLogger().With("module", "events"))
-		err = eventBus.Start()
+		err = eventBus.Start(ctx)
 		require.NoError(t, err)
 		cs.SetEventBus(eventBus)
 
@@ -103,7 +106,7 @@ func TestByzantinePrevoteEquivocation(t *testing.T) {
 		}()
 	}
 
-	rts := setup(t, nValidators, states, 100) // buffer must be large enough to not deadlock
+	rts := setup(ctx, t, nValidators, states, 100) // buffer must be large enough to not deadlock
 
 	var bzNodeID types.NodeID
 
@@ -211,7 +214,7 @@ func TestByzantinePrevoteEquivocation(t *testing.T) {
 		propBlockID := types.BlockID{Hash: block.Hash(), PartSetHeader: blockParts.Header()}
 		proposal := types.NewProposal(height, round, lazyNodeState.ValidRound, propBlockID)
 		p := proposal.ToProto()
-		if err := lazyNodeState.privValidator.SignProposal(context.Background(), lazyNodeState.state.ChainID, p); err == nil {
+		if err := lazyNodeState.privValidator.SignProposal(ctx, lazyNodeState.state.ChainID, p); err == nil {
 			proposal.Signature = p.Signature
 
 			// send proposal and block parts on internal msg queue
@@ -229,15 +232,13 @@ func TestByzantinePrevoteEquivocation(t *testing.T) {
 
 	for _, reactor := range rts.reactors {
 		state := reactor.state.GetState()
-		reactor.SwitchToConsensus(state, false)
+		reactor.SwitchToConsensus(ctx, state, false)
 	}
 
 	// Evidence should be submitted and committed at the third height but
 	// we will check the first six just in case
 	evidenceFromEachValidator := make([]types.Evidence, nValidators)
 
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
 	var wg sync.WaitGroup
 	i := 0
 	for _, sub := range rts.subs {
@@ -246,6 +247,10 @@ func TestByzantinePrevoteEquivocation(t *testing.T) {
 		go func(j int, s eventbus.Subscription) {
 			defer wg.Done()
 			for {
+				if ctx.Err() != nil {
+					return
+				}
+
 				msg, err := s.Next(ctx)
 				if !assert.NoError(t, err) {
 					cancel()
@@ -265,7 +270,7 @@ func TestByzantinePrevoteEquivocation(t *testing.T) {
 
 	wg.Wait()
 
-	pubkey, err := bzNodeState.privValidator.GetPubKey(context.Background())
+	pubkey, err := bzNodeState.privValidator.GetPubKey(ctx)
 	require.NoError(t, err)
 
 	for idx, ev := range evidenceFromEachValidator {
@@ -311,7 +316,7 @@ func TestByzantineConflictingProposalsWithPartition(t *testing.T) {
 	// 	eventBus.SetLogger(logger.With("module", "events", "validator", i))
 
 	// 	var err error
-	// 	blocksSubs[i], err = eventBus.Subscribe(context.Background(), testSubscriber, types.EventQueryNewBlock)
+	// 	blocksSubs[i], err = eventBus.Subscribe(ctx, testSubscriber, types.EventQueryNewBlock)
 	// 	require.NoError(t, err)
 
 	// 	conR := NewReactor(states[i], true) // so we don't start the consensus states
diff --git a/internal/consensus/common_test.go b/internal/consensus/common_test.go
index 780ec8804..bcbbe7c88 100644
--- a/internal/consensus/common_test.go
+++ b/internal/consensus/common_test.go
@@ -106,12 +106,14 @@ func newValidatorStub(privValidator types.PrivValidator, valIndex int32) *valida
 }
 
 func (vs *validatorStub) signVote(
+	ctx context.Context,
 	cfg *config.Config,
 	voteType tmproto.SignedMsgType,
 	hash []byte,
-	header types.PartSetHeader) (*types.Vote, error) {
+	header types.PartSetHeader,
+) (*types.Vote, error) {
 
-	pubKey, err := vs.PrivValidator.GetPubKey(context.Background())
+	pubKey, err := vs.PrivValidator.GetPubKey(ctx)
 	if err != nil {
 		return nil, fmt.Errorf("can't get pubkey: %w", err)
 	}
@@ -126,7 +128,7 @@ func (vs *validatorStub) signVote(
 		BlockID:          types.BlockID{Hash: hash, PartSetHeader: header},
 	}
 	v := vote.ToProto()
-	if err := vs.PrivValidator.SignVote(context.Background(), cfg.ChainID(), v); err != nil {
+	if err := vs.PrivValidator.SignVote(ctx, cfg.ChainID(), v); err != nil {
 		return nil, fmt.Errorf("sign vote failed: %w", err)
 	}
 
@@ -144,13 +146,15 @@ func (vs *validatorStub) signVote(
 
 // Sign vote for type/hash/header
 func signVote(
+	ctx context.Context,
 	vs *validatorStub,
 	cfg *config.Config,
 	voteType tmproto.SignedMsgType,
 	hash []byte,
-	header types.PartSetHeader) *types.Vote {
+	header types.PartSetHeader,
+) *types.Vote {
 
-	v, err := vs.signVote(cfg, voteType, hash, header)
+	v, err := vs.signVote(ctx, cfg, voteType, hash, header)
 	if err != nil {
 		panic(fmt.Errorf("failed to sign vote: %v", err))
 	}
@@ -161,6 +165,7 @@ func signVote(
 }
 
 func signVotes(
+	ctx context.Context,
 	cfg *config.Config,
 	voteType tmproto.SignedMsgType,
 	hash []byte,
@@ -168,7 +173,7 @@ func signVotes(
 	vss ...*validatorStub) []*types.Vote {
 	votes := make([]*types.Vote, len(vss))
 	for i, vs := range vss {
-		votes[i] = signVote(vs, cfg, voteType, hash, header)
+		votes[i] = signVote(ctx, vs, cfg, voteType, hash, header)
 	}
 	return votes
 }
@@ -192,11 +197,11 @@ func (vss ValidatorStubsByPower) Len() int {
 }
 
 func (vss ValidatorStubsByPower) Less(i, j int) bool {
-	vssi, err := vss[i].GetPubKey(context.Background())
+	vssi, err := vss[i].GetPubKey(context.TODO())
 	if err != nil {
 		panic(err)
 	}
-	vssj, err := vss[j].GetPubKey(context.Background())
+	vssj, err := vss[j].GetPubKey(context.TODO())
 	if err != nil {
 		panic(err)
 	}
@@ -218,13 +223,14 @@ func (vss ValidatorStubsByPower) Swap(i, j int) {
 //-------------------------------------------------------------------------------
 // Functions for transitioning the consensus state
 
-func startTestRound(cs *State, height int64, round int32) {
+func startTestRound(ctx context.Context, cs *State, height int64, round int32) {
 	cs.enterNewRound(height, round)
-	cs.startRoutines(0)
+	cs.startRoutines(ctx, 0)
 }
 
 // Create proposal block from cs1 but sign it with vs.
 func decideProposal(
+	ctx context.Context,
 	cs1 *State,
 	vs *validatorStub,
 	height int64,
@@ -243,7 +249,7 @@ func decideProposal(
 	polRound, propBlockID := validRound, types.BlockID{Hash: block.Hash(), PartSetHeader: blockParts.Header()}
 	proposal = types.NewProposal(height, round, polRound, propBlockID)
 	p := proposal.ToProto()
-	if err := vs.SignProposal(context.Background(), chainID, p); err != nil {
+	if err := vs.SignProposal(ctx, chainID, p); err != nil {
 		panic(err)
 	}
 
@@ -259,6 +265,7 @@ func addVotes(to *State, votes ...*types.Vote) {
 }
 
 func signAddVotes(
+	ctx context.Context,
 	cfg *config.Config,
 	to *State,
 	voteType tmproto.SignedMsgType,
@@ -266,13 +273,19 @@ func signAddVotes(
 	header types.PartSetHeader,
 	vss ...*validatorStub,
 ) {
-	votes := signVotes(cfg, voteType, hash, header, vss...)
-	addVotes(to, votes...)
+	addVotes(to, signVotes(ctx, cfg, voteType, hash, header, vss...)...)
 }
 
-func validatePrevote(t *testing.T, cs *State, round int32, privVal *validatorStub, blockHash []byte) {
+func validatePrevote(
+	ctx context.Context,
+	t *testing.T,
+	cs *State,
+	round int32,
+	privVal *validatorStub,
+	blockHash []byte,
+) {
 	prevotes := cs.Votes.Prevotes(round)
-	pubKey, err := privVal.GetPubKey(context.Background())
+	pubKey, err := privVal.GetPubKey(ctx)
 	require.NoError(t, err)
 	address := pubKey.Address()
 	var vote *types.Vote
@@ -290,9 +303,9 @@ func validatePrevote(t *testing.T, cs *State, round int32, privVal *validatorStu
 	}
 }
 
-func validateLastPrecommit(t *testing.T, cs *State, privVal *validatorStub, blockHash []byte) {
+func validateLastPrecommit(ctx context.Context, t *testing.T, cs *State, privVal *validatorStub, blockHash []byte) {
 	votes := cs.LastCommit
-	pv, err := privVal.GetPubKey(context.Background())
+	pv, err := privVal.GetPubKey(ctx)
 	require.NoError(t, err)
 	address := pv.Address()
 	var vote *types.Vote
@@ -305,6 +318,7 @@ func validateLastPrecommit(t *testing.T, cs *State, privVal *validatorStub, bloc
 }
 
 func validatePrecommit(
+	ctx context.Context,
 	t *testing.T,
 	cs *State,
 	thisRound,
@@ -314,7 +328,7 @@ func validatePrecommit(
 	lockedBlockHash []byte,
 ) {
 	precommits := cs.Votes.Precommits(thisRound)
-	pv, err := privVal.GetPubKey(context.Background())
+	pv, err := privVal.GetPubKey(ctx)
 	require.NoError(t, err)
 	address := pv.Address()
 	var vote *types.Vote
@@ -353,6 +367,7 @@ func validatePrecommit(
 }
 
 func validatePrevoteAndPrecommit(
+	ctx context.Context,
 	t *testing.T,
 	cs *State,
 	thisRound,
@@ -362,18 +377,18 @@ func validatePrevoteAndPrecommit(
 	lockedBlockHash []byte,
 ) {
 	// verify the prevote
-	validatePrevote(t, cs, thisRound, privVal, votedBlockHash)
+	validatePrevote(ctx, t, cs, thisRound, privVal, votedBlockHash)
 	// verify precommit
 	cs.mtx.Lock()
-	validatePrecommit(t, cs, thisRound, lockRound, privVal, votedBlockHash, lockedBlockHash)
-	cs.mtx.Unlock()
+	defer cs.mtx.Unlock()
+	validatePrecommit(ctx, t, cs, thisRound, lockRound, privVal, votedBlockHash, lockedBlockHash)
 }
 
-func subscribeToVoter(t *testing.T, cs *State, addr []byte) <-chan tmpubsub.Message {
+func subscribeToVoter(ctx context.Context, t *testing.T, cs *State, addr []byte) <-chan tmpubsub.Message {
 	t.Helper()
 
 	ch := make(chan tmpubsub.Message, 1)
-	if err := cs.eventBus.Observe(context.Background(), func(msg tmpubsub.Message) error {
+	if err := cs.eventBus.Observe(ctx, func(msg tmpubsub.Message) error {
 		vote := msg.Data().(types.EventDataVote)
 		// we only fire for our own votes
 		if bytes.Equal(addr, vote.Vote.ValidatorAddress) {
@@ -389,27 +404,34 @@ func subscribeToVoter(t *testing.T, cs *State, addr []byte) <-chan tmpubsub.Mess
 //-------------------------------------------------------------------------------
 // consensus states
 
-func newState(logger log.Logger, state sm.State, pv types.PrivValidator, app abci.Application) (*State, error) {
+func newState(
+	ctx context.Context,
+	logger log.Logger,
+	state sm.State,
+	pv types.PrivValidator,
+	app abci.Application,
+) (*State, error) {
 	cfg, err := config.ResetTestRoot("consensus_state_test")
 	if err != nil {
 		return nil, err
 	}
 
-	return newStateWithConfig(logger, cfg, state, pv, app), nil
+	return newStateWithConfig(ctx, logger, cfg, state, pv, app), nil
 }
 
 func newStateWithConfig(
+	ctx context.Context,
 	logger log.Logger,
 	thisConfig *config.Config,
 	state sm.State,
 	pv types.PrivValidator,
 	app abci.Application,
 ) *State {
-	blockStore := store.NewBlockStore(dbm.NewMemDB())
-	return newStateWithConfigAndBlockStore(logger, thisConfig, state, pv, app, blockStore)
+	return newStateWithConfigAndBlockStore(ctx, logger, thisConfig, state, pv, app, store.NewBlockStore(dbm.NewMemDB()))
 }
 
 func newStateWithConfigAndBlockStore(
+	ctx context.Context,
 	logger log.Logger,
 	thisConfig *config.Config,
 	state sm.State,
@@ -449,7 +471,7 @@ func newStateWithConfigAndBlockStore(
 	cs.SetPrivValidator(pv)
 
 	eventBus := eventbus.NewDefault(logger.With("module", "events"))
-	err := eventBus.Start()
+	err := eventBus.Start(ctx)
 	if err != nil {
 		panic(err)
 	}
@@ -469,13 +491,18 @@ func loadPrivValidator(cfg *config.Config) *privval.FilePV {
 	return privValidator
 }
 
-func randState(cfg *config.Config, logger log.Logger, nValidators int) (*State, []*validatorStub, error) {
+func randState(
+	ctx context.Context,
+	cfg *config.Config,
+	logger log.Logger,
+	nValidators int,
+) (*State, []*validatorStub, error) {
 	// Get State
 	state, privVals := randGenesisState(cfg, nValidators, false, 10)
 
 	vss := make([]*validatorStub, nValidators)
 
-	cs, err := newState(logger, state, privVals[0], kvstore.NewApplication())
+	cs, err := newState(ctx, logger, state, privVals[0], kvstore.NewApplication())
 	if err != nil {
 		return nil, nil, err
 	}
@@ -719,6 +746,7 @@ func consensusLogger() log.Logger {
 }
 
 func randConsensusState(
+	ctx context.Context,
 	t *testing.T,
 	cfg *config.Config,
 	nValidators int,
@@ -761,7 +789,7 @@ func randConsensusState(
 		app.InitChain(abci.RequestInitChain{Validators: vals})
 
 		l := logger.With("validator", i, "module", "consensus")
-		css[i] = newStateWithConfigAndBlockStore(l, thisConfig, state, privVals[i], app, blockStore)
+		css[i] = newStateWithConfigAndBlockStore(ctx, l, thisConfig, state, privVals[i], app, blockStore)
 		css[i].SetTimeoutTicker(tickerFunc())
 	}
 
@@ -777,6 +805,7 @@ func randConsensusState(
 
 // nPeers = nValidators + nNotValidator
 func randConsensusNetWithPeers(
+	ctx context.Context,
 	cfg *config.Config,
 	nValidators,
 	nPeers int,
@@ -830,7 +859,7 @@ func randConsensusNetWithPeers(
 		app.InitChain(abci.RequestInitChain{Validators: vals})
 		// sm.SaveState(stateDB,state)	//height 1's validatorsInfo already saved in LoadStateFromDBOrGenesisDoc above
 
-		css[i] = newStateWithConfig(logger.With("validator", i, "module", "consensus"), thisConfig, state, privVal, app)
+		css[i] = newStateWithConfig(ctx, logger.With("validator", i, "module", "consensus"), thisConfig, state, privVal, app)
 		css[i].SetTimeoutTicker(tickerFunc())
 	}
 	return css, genDoc, peer0Config, func() {
@@ -870,7 +899,7 @@ type mockTicker struct {
 	fired bool
 }
 
-func (m *mockTicker) Start() error {
+func (m *mockTicker) Start(context.Context) error {
 	return nil
 }
 
diff --git a/internal/consensus/invalid_test.go b/internal/consensus/invalid_test.go
index fae89f227..fc872a9fa 100644
--- a/internal/consensus/invalid_test.go
+++ b/internal/consensus/invalid_test.go
@@ -17,10 +17,13 @@ import (
 )
 
 func TestReactorInvalidPrecommit(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	config := configSetup(t)
 
 	n := 4
-	states, cleanup := randConsensusState(t,
+	states, cleanup := randConsensusState(ctx, t,
 		config, n, "consensus_reactor_test",
 		newMockTickerFunc(true), newKVStore)
 	t.Cleanup(cleanup)
@@ -30,11 +33,11 @@ func TestReactorInvalidPrecommit(t *testing.T) {
 		states[i].SetTimeoutTicker(ticker)
 	}
 
-	rts := setup(t, n, states, 100) // buffer must be large enough to not deadlock
+	rts := setup(ctx, t, n, states, 100) // buffer must be large enough to not deadlock
 
 	for _, reactor := range rts.reactors {
 		state := reactor.state.GetState()
-		reactor.SwitchToConsensus(state, false)
+		reactor.SwitchToConsensus(ctx, state, false)
 	}
 
 	// this val sends a random precommit at each height
@@ -48,7 +51,7 @@ func TestReactorInvalidPrecommit(t *testing.T) {
 	byzState.mtx.Lock()
 	privVal := byzState.privValidator
 	byzState.doPrevote = func(height int64, round int32) {
-		invalidDoPrevoteFunc(t, height, round, byzState, byzReactor, privVal)
+		invalidDoPrevoteFunc(ctx, t, height, round, byzState, byzReactor, privVal)
 	}
 	byzState.mtx.Unlock()
 
@@ -56,8 +59,7 @@ func TestReactorInvalidPrecommit(t *testing.T) {
 	//
 	// TODO: Make this tighter by ensuring the halt happens by block 2.
 	var wg sync.WaitGroup
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
+
 	for i := 0; i < 10; i++ {
 		for _, sub := range rts.subs {
 			wg.Add(1)
@@ -75,7 +77,15 @@ func TestReactorInvalidPrecommit(t *testing.T) {
 	wg.Wait()
 }
 
-func invalidDoPrevoteFunc(t *testing.T, height int64, round int32, cs *State, r *Reactor, pv types.PrivValidator) {
+func invalidDoPrevoteFunc(
+	ctx context.Context,
+	t *testing.T,
+	height int64,
+	round int32,
+	cs *State,
+	r *Reactor,
+	pv types.PrivValidator,
+) {
 	// routine to:
 	// - precommit for a random block
 	// - send precommit to all peers
@@ -84,7 +94,7 @@ func invalidDoPrevoteFunc(t *testing.T, height int64, round int32, cs *State, r
 		cs.mtx.Lock()
 		cs.privValidator = pv
 
-		pubKey, err := cs.privValidator.GetPubKey(context.Background())
+		pubKey, err := cs.privValidator.GetPubKey(ctx)
 		require.NoError(t, err)
 
 		addr := pubKey.Address()
@@ -105,7 +115,7 @@ func invalidDoPrevoteFunc(t *testing.T, height int64, round int32, cs *State, r
 		}
 
 		p := precommit.ToProto()
-		err = cs.privValidator.SignVote(context.Background(), cs.state.ChainID, p)
+		err = cs.privValidator.SignVote(ctx, cs.state.ChainID, p)
 		require.NoError(t, err)
 
 		precommit.Signature = p.Signature
diff --git a/internal/consensus/mempool_test.go b/internal/consensus/mempool_test.go
index 78f42f993..e24cb1160 100644
--- a/internal/consensus/mempool_test.go
+++ b/internal/consensus/mempool_test.go
@@ -27,6 +27,9 @@ func assertMempool(txn txNotifier) mempool.Mempool {
 }
 
 func TestMempoolNoProgressUntilTxsAvailable(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	baseConfig := configSetup(t)
 
 	config, err := ResetConfig("consensus_mempool_txs_available_test")
@@ -35,15 +38,15 @@ func TestMempoolNoProgressUntilTxsAvailable(t *testing.T) {
 	config.Consensus.CreateEmptyBlocks = false
 	state, privVals := randGenesisState(baseConfig, 1, false, 10)
-	cs := newStateWithConfig(log.TestingLogger(), config, state, privVals[0], NewCounterApplication())
+	cs := newStateWithConfig(ctx, log.TestingLogger(), config, state, privVals[0], NewCounterApplication())
 
 	assertMempool(cs.txNotifier).EnableTxsAvailable()
 	height, round := cs.Height, cs.Round
-	newBlockCh := subscribe(t, cs.eventBus, types.EventQueryNewBlock)
-	startTestRound(cs, height, round)
+	newBlockCh := subscribe(ctx, t, cs.eventBus, types.EventQueryNewBlock)
+	startTestRound(ctx, cs, height, round)
 
 	ensureNewEventOnChannel(newBlockCh) // first block gets committed
 	ensureNoNewEventOnChannel(newBlockCh)
-	deliverTxsRange(cs, 0, 1)
+	deliverTxsRange(ctx, cs, 0, 1)
 	ensureNewEventOnChannel(newBlockCh) // commit txs
 	ensureNewEventOnChannel(newBlockCh) // commit updated app hash
 	ensureNoNewEventOnChannel(newBlockCh)
@@ -51,6 +54,8 @@ func TestMempoolNoProgressUntilTxsAvailable(t *testing.T) {
 
 func TestMempoolProgressAfterCreateEmptyBlocksInterval(t *testing.T) {
 	baseConfig := configSetup(t)
+	ctx,
cancel := context.WithCancel(context.Background()) + defer cancel() config, err := ResetConfig("consensus_mempool_txs_available_test") require.NoError(t, err) @@ -58,12 +63,12 @@ func TestMempoolProgressAfterCreateEmptyBlocksInterval(t *testing.T) { config.Consensus.CreateEmptyBlocksInterval = ensureTimeout state, privVals := randGenesisState(baseConfig, 1, false, 10) - cs := newStateWithConfig(log.TestingLogger(), config, state, privVals[0], NewCounterApplication()) + cs := newStateWithConfig(ctx, log.TestingLogger(), config, state, privVals[0], NewCounterApplication()) assertMempool(cs.txNotifier).EnableTxsAvailable() - newBlockCh := subscribe(t, cs.eventBus, types.EventQueryNewBlock) - startTestRound(cs, cs.Height, cs.Round) + newBlockCh := subscribe(ctx, t, cs.eventBus, types.EventQueryNewBlock) + startTestRound(ctx, cs, cs.Height, cs.Round) ensureNewEventOnChannel(newBlockCh) // first block gets committed ensureNoNewEventOnChannel(newBlockCh) // then we dont make a block ... @@ -72,6 +77,8 @@ func TestMempoolProgressAfterCreateEmptyBlocksInterval(t *testing.T) { func TestMempoolProgressInHigherRound(t *testing.T) { baseConfig := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() config, err := ResetConfig("consensus_mempool_txs_available_test") require.NoError(t, err) @@ -79,12 +86,12 @@ func TestMempoolProgressInHigherRound(t *testing.T) { config.Consensus.CreateEmptyBlocks = false state, privVals := randGenesisState(baseConfig, 1, false, 10) - cs := newStateWithConfig(log.TestingLogger(), config, state, privVals[0], NewCounterApplication()) + cs := newStateWithConfig(ctx, log.TestingLogger(), config, state, privVals[0], NewCounterApplication()) assertMempool(cs.txNotifier).EnableTxsAvailable() height, round := cs.Height, cs.Round - newBlockCh := subscribe(t, cs.eventBus, types.EventQueryNewBlock) - newRoundCh := subscribe(t, cs.eventBus, types.EventQueryNewRound) - timeoutCh := subscribe(t, cs.eventBus, types.EventQueryTimeoutPropose) + newBlockCh := subscribe(ctx, t, cs.eventBus, types.EventQueryNewBlock) + newRoundCh := subscribe(ctx, t, cs.eventBus, types.EventQueryNewRound) + timeoutCh := subscribe(ctx, t, cs.eventBus, types.EventQueryTimeoutPropose) cs.setProposal = func(proposal *types.Proposal) error { if cs.Height == 2 && cs.Round == 0 { // dont set the proposal in round 0 so we timeout and @@ -94,7 +101,7 @@ func TestMempoolProgressInHigherRound(t *testing.T) { } return cs.defaultSetProposal(proposal) } - startTestRound(cs, height, round) + startTestRound(ctx, cs, height, round) ensureNewRound(newRoundCh, height, round) // first round at first height ensureNewEventOnChannel(newBlockCh) // first block gets committed @@ -103,7 +110,7 @@ func TestMempoolProgressInHigherRound(t *testing.T) { round = 0 ensureNewRound(newRoundCh, height, round) // first round at next height - deliverTxsRange(cs, 0, 1) // we deliver txs, but dont set a proposal so we get the next round + deliverTxsRange(ctx, cs, 0, 1) // we deliver txs, but dont set a proposal so we get the next round ensureNewTimeout(timeoutCh, height, round, cs.config.TimeoutPropose.Nanoseconds()) round++ // moving to the next round @@ -111,12 +118,12 @@ func TestMempoolProgressInHigherRound(t *testing.T) { ensureNewEventOnChannel(newBlockCh) // now we can commit the block } -func deliverTxsRange(cs *State, start, end int) { +func deliverTxsRange(ctx context.Context, cs *State, start, end int) { // Deliver some txs. 
for i := start; i < end; i++ { txBytes := make([]byte, 8) binary.BigEndian.PutUint64(txBytes, uint64(i)) - err := assertMempool(cs.txNotifier).CheckTx(context.Background(), txBytes, nil, mempool.TxInfo{}) + err := assertMempool(cs.txNotifier).CheckTx(ctx, txBytes, nil, mempool.TxInfo{}) if err != nil { panic(fmt.Sprintf("Error after CheckTx: %v", err)) } @@ -124,6 +131,9 @@ func deliverTxsRange(cs *State, start, end int) { } func TestMempoolTxConcurrentWithCommit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + config := configSetup(t) logger := log.TestingLogger() state, privVals := randGenesisState(config, 1, false, 10) @@ -131,16 +141,17 @@ func TestMempoolTxConcurrentWithCommit(t *testing.T) { blockStore := store.NewBlockStore(dbm.NewMemDB()) cs := newStateWithConfigAndBlockStore( + ctx, logger, config, state, privVals[0], NewCounterApplication(), blockStore) err := stateStore.Save(state) require.NoError(t, err) - newBlockHeaderCh := subscribe(t, cs.eventBus, types.EventQueryNewBlockHeader) + newBlockHeaderCh := subscribe(ctx, t, cs.eventBus, types.EventQueryNewBlockHeader) const numTxs int64 = 3000 - go deliverTxsRange(cs, 0, int(numTxs)) + go deliverTxsRange(ctx, cs, 0, int(numTxs)) - startTestRound(cs, cs.Height, cs.Round) + startTestRound(ctx, cs, cs.Height, cs.Round) for n := int64(0); n < numTxs; { select { case msg := <-newBlockHeaderCh: @@ -154,12 +165,14 @@ func TestMempoolTxConcurrentWithCommit(t *testing.T) { func TestMempoolRmBadTx(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() state, privVals := randGenesisState(config, 1, false, 10) app := NewCounterApplication() stateStore := sm.NewStore(dbm.NewMemDB()) blockStore := store.NewBlockStore(dbm.NewMemDB()) - cs := newStateWithConfigAndBlockStore(log.TestingLogger(), config, state, privVals[0], app, blockStore) + cs := newStateWithConfigAndBlockStore(ctx, log.TestingLogger(), config, state, privVals[0], app, blockStore) err := stateStore.Save(state) require.NoError(t, err) @@ -179,7 +192,7 @@ func TestMempoolRmBadTx(t *testing.T) { // Try to send the tx through the mempool. // CheckTx should not err, but the app should return a bad abci code // and the tx should get removed from the pool - err := assertMempool(cs.txNotifier).CheckTx(context.Background(), txBytes, func(r *abci.Response) { + err := assertMempool(cs.txNotifier).CheckTx(ctx, txBytes, func(r *abci.Response) { if r.GetCheckTx().Code != code.CodeTypeBadNonce { t.Errorf("expected checktx to return bad nonce, got %v", r) return diff --git a/internal/consensus/reactor.go b/internal/consensus/reactor.go index a28f54bf9..03865d13d 100644 --- a/internal/consensus/reactor.go +++ b/internal/consensus/reactor.go @@ -1,6 +1,7 @@ package consensus import ( + "context" "fmt" "runtime/debug" "time" @@ -85,7 +86,7 @@ type ReactorOption func(*Reactor) // NOTE: Temporary interface for switching to block sync, we should get rid of v0. // See: https://github.com/tendermint/tendermint/issues/4595 type BlockSyncReactor interface { - SwitchToBlockSync(sm.State) error + SwitchToBlockSync(context.Context, sm.State) error GetMaxPeerBlockHeight() int64 @@ -174,7 +175,7 @@ func NewReactor( // envelopes on each. In addition, it also listens for peer updates and handles // messages on that p2p channel accordingly. The caller must be sure to execute // OnStop to ensure the outbound p2p Channels are closed. 
-func (r *Reactor) OnStart() error { +func (r *Reactor) OnStart(ctx context.Context) error { r.Logger.Debug("consensus wait sync", "wait_sync", r.WaitSync()) // start routine that computes peer statistics for evaluating peer quality @@ -186,7 +187,7 @@ func (r *Reactor) OnStart() error { r.subscribeToBroadcastEvents() if !r.WaitSync() { - if err := r.state.Start(); err != nil { + if err := r.state.Start(ctx); err != nil { return err } } @@ -264,7 +265,7 @@ func ReactorMetrics(metrics *Metrics) ReactorOption { // SwitchToConsensus switches from block-sync mode to consensus mode. It resets // the state, turns off block-sync, and starts the consensus state-machine. -func (r *Reactor) SwitchToConsensus(state sm.State, skipWAL bool) { +func (r *Reactor) SwitchToConsensus(ctx context.Context, state sm.State, skipWAL bool) { r.Logger.Info("switching to consensus") // we have no votes, so reconstruct LastCommit from SeenCommit @@ -287,7 +288,7 @@ func (r *Reactor) SwitchToConsensus(state sm.State, skipWAL bool) { r.state.doWALCatchup = false } - if err := r.state.Start(); err != nil { + if err := r.state.Start(ctx); err != nil { panic(fmt.Sprintf(`failed to start consensus state: %v conS: diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index 8fc562b06..de6465f23 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -59,11 +59,17 @@ func chDesc(chID p2p.ChannelID, size int) *p2p.ChannelDescriptor { } } -func setup(t *testing.T, numNodes int, states []*State, size int) *reactorTestSuite { +func setup( + ctx context.Context, + t *testing.T, + numNodes int, + states []*State, + size int, +) *reactorTestSuite { t.Helper() rts := &reactorTestSuite{ - network: p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: numNodes}), + network: p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: numNodes}), states: make(map[types.NodeID]*State), reactors: make(map[types.NodeID]*Reactor, numNodes), subs: make(map[types.NodeID]eventbus.Subscription, numNodes), @@ -75,7 +81,7 @@ func setup(t *testing.T, numNodes int, states []*State, size int) *reactorTestSu rts.voteChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(VoteChannel, size)) rts.voteSetBitsChannels = rts.network.MakeChannelsNoCleanup(t, chDesc(VoteSetBitsChannel, size)) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) // Canceled during cleanup (see below). 
i := 0 @@ -89,7 +95,7 @@ func setup(t *testing.T, numNodes int, states []*State, size int) *reactorTestSu rts.dataChannels[nodeID], rts.voteChannels[nodeID], rts.voteSetBitsChannels[nodeID], - node.MakePeerUpdates(t), + node.MakePeerUpdates(ctx, t), true, ) @@ -119,7 +125,7 @@ func setup(t *testing.T, numNodes int, states []*State, size int) *reactorTestSu require.NoError(t, state.blockExec.Store().Save(state.state)) } - require.NoError(t, reactor.Start()) + require.NoError(t, reactor.Start(ctx)) require.True(t, reactor.IsRunning()) i++ @@ -131,14 +137,8 @@ func setup(t *testing.T, numNodes int, states []*State, size int) *reactorTestSu rts.network.Start(t) t.Cleanup(func() { - for nodeID, r := range rts.reactors { - require.NoError(t, rts.states[nodeID].eventBus.Stop()) - require.NoError(t, r.Stop()) - require.False(t, r.IsRunning()) - } - - leaktest.Check(t) cancel() + leaktest.Check(t) }) return rts @@ -162,6 +162,7 @@ func validateBlock(block *types.Block, activeVals map[string]struct{}) error { } func waitForAndValidateBlock( + bctx context.Context, t *testing.T, n int, activeVals map[string]struct{}, @@ -169,8 +170,9 @@ func waitForAndValidateBlock( states []*State, txs ...[]byte, ) { + t.Helper() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(bctx) defer cancel() fn := func(j int) { msg, err := blocksSubs[j].Next(ctx) @@ -183,7 +185,7 @@ func waitForAndValidateBlock( require.NoError(t, validateBlock(newBlock, activeVals)) for _, tx := range txs { - require.NoError(t, assertMempool(states[j].txNotifier).CheckTx(context.Background(), tx, nil, mempool.TxInfo{})) + require.NoError(t, assertMempool(states[j].txNotifier).CheckTx(ctx, tx, nil, mempool.TxInfo{})) } } @@ -200,6 +202,7 @@ func waitForAndValidateBlock( } func waitForAndValidateBlockWithTx( + bctx context.Context, t *testing.T, n int, activeVals map[string]struct{}, @@ -207,8 +210,9 @@ func waitForAndValidateBlockWithTx( states []*State, txs ...[]byte, ) { + t.Helper() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(bctx) defer cancel() fn := func(j int) { ntxs := 0 @@ -249,15 +253,17 @@ func waitForAndValidateBlockWithTx( } func waitForBlockWithUpdatedValsAndValidateIt( + bctx context.Context, t *testing.T, n int, updatedVals map[string]struct{}, blocksSubs []eventbus.Subscription, css []*State, ) { - - ctx, cancel := context.WithCancel(context.Background()) + t.Helper() + ctx, cancel := context.WithCancel(bctx) defer cancel() + fn := func(j int) { var newBlock *types.Block @@ -299,23 +305,24 @@ func ensureBlockSyncStatus(t *testing.T, msg tmpubsub.Message, complete bool, he } func TestReactorBasic(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg := configSetup(t) n := 4 - states, cleanup := randConsensusState(t, + states, cleanup := randConsensusState(ctx, t, cfg, n, "consensus_reactor_test", newMockTickerFunc(true), newKVStore) t.Cleanup(cleanup) - rts := setup(t, n, states, 100) // buffer must be large enough to not deadlock + rts := setup(ctx, t, n, states, 100) // buffer must be large enough to not deadlock for _, reactor := range rts.reactors { state := reactor.state.GetState() - reactor.SwitchToConsensus(state, false) + reactor.SwitchToConsensus(ctx, state, false) } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() var wg sync.WaitGroup for _, sub := range rts.subs { wg.Add(1) @@ -351,6 +358,9 @@ func TestReactorBasic(t *testing.T) { } func 
TestReactorWithEvidence(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg := configSetup(t) n := 4 @@ -416,7 +426,7 @@ func TestReactorWithEvidence(t *testing.T) { cs.SetPrivValidator(pv) eventBus := eventbus.NewDefault(log.TestingLogger().With("module", "events")) - require.NoError(t, eventBus.Start()) + require.NoError(t, eventBus.Start(ctx)) cs.SetEventBus(eventBus) cs.SetTimeoutTicker(tickerFunc()) @@ -424,15 +434,13 @@ func TestReactorWithEvidence(t *testing.T) { states[i] = cs } - rts := setup(t, n, states, 100) // buffer must be large enough to not deadlock + rts := setup(ctx, t, n, states, 100) // buffer must be large enough to not deadlock for _, reactor := range rts.reactors { state := reactor.state.GetState() - reactor.SwitchToConsensus(state, false) + reactor.SwitchToConsensus(ctx, state, false) } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() var wg sync.WaitGroup for _, sub := range rts.subs { wg.Add(1) @@ -456,10 +464,14 @@ func TestReactorWithEvidence(t *testing.T) { } func TestReactorCreatesBlockWhenEmptyBlocksFalse(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg := configSetup(t) n := 4 states, cleanup := randConsensusState( + ctx, t, cfg, n, @@ -473,26 +485,24 @@ func TestReactorCreatesBlockWhenEmptyBlocksFalse(t *testing.T) { t.Cleanup(cleanup) - rts := setup(t, n, states, 100) // buffer must be large enough to not deadlock + rts := setup(ctx, t, n, states, 100) // buffer must be large enough to not deadlock for _, reactor := range rts.reactors { state := reactor.state.GetState() - reactor.SwitchToConsensus(state, false) + reactor.SwitchToConsensus(ctx, state, false) } // send a tx require.NoError( t, assertMempool(states[3].txNotifier).CheckTx( - context.Background(), + ctx, []byte{1, 2, 3}, nil, mempool.TxInfo{}, ), ) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() var wg sync.WaitGroup for _, sub := range rts.subs { wg.Add(1) @@ -511,23 +521,24 @@ func TestReactorCreatesBlockWhenEmptyBlocksFalse(t *testing.T) { } func TestReactorRecordsVotesAndBlockParts(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg := configSetup(t) n := 4 - states, cleanup := randConsensusState(t, + states, cleanup := randConsensusState(ctx, t, cfg, n, "consensus_reactor_test", newMockTickerFunc(true), newKVStore) t.Cleanup(cleanup) - rts := setup(t, n, states, 100) // buffer must be large enough to not deadlock + rts := setup(ctx, t, n, states, 100) // buffer must be large enough to not deadlock for _, reactor := range rts.reactors { state := reactor.state.GetState() - reactor.SwitchToConsensus(state, false) + reactor.SwitchToConsensus(ctx, state, false) } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() var wg sync.WaitGroup for _, sub := range rts.subs { wg.Add(1) @@ -575,10 +586,14 @@ func TestReactorRecordsVotesAndBlockParts(t *testing.T) { } func TestReactorVotingPowerChange(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg := configSetup(t) n := 4 states, cleanup := randConsensusState( + ctx, t, cfg, n, @@ -589,25 +604,23 @@ func TestReactorVotingPowerChange(t *testing.T) { t.Cleanup(cleanup) - rts := setup(t, n, states, 100) // buffer must be large enough to not deadlock + rts := setup(ctx, t, n, states, 100) // buffer must be large enough to not deadlock for _, reactor := range rts.reactors { state := 
reactor.state.GetState() - reactor.SwitchToConsensus(state, false) + reactor.SwitchToConsensus(ctx, state, false) } // map of active validators activeVals := make(map[string]struct{}) for i := 0; i < n; i++ { - pubKey, err := states[i].privValidator.GetPubKey(context.Background()) + pubKey, err := states[i].privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pubKey.Address() activeVals[string(addr)] = struct{}{} } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() var wg sync.WaitGroup for _, sub := range rts.subs { wg.Add(1) @@ -629,7 +642,7 @@ func TestReactorVotingPowerChange(t *testing.T) { blocksSubs = append(blocksSubs, sub) } - val1PubKey, err := states[0].privValidator.GetPubKey(context.Background()) + val1PubKey, err := states[0].privValidator.GetPubKey(ctx) require.NoError(t, err) val1PubKeyABCI, err := encoding.PubKeyToProto(val1PubKey) @@ -638,10 +651,10 @@ func TestReactorVotingPowerChange(t *testing.T) { updateValidatorTx := kvstore.MakeValSetChangeTx(val1PubKeyABCI, 25) previousTotalVotingPower := states[0].GetRoundState().LastValidators.TotalVotingPower() - waitForAndValidateBlock(t, n, activeVals, blocksSubs, states, updateValidatorTx) - waitForAndValidateBlockWithTx(t, n, activeVals, blocksSubs, states, updateValidatorTx) - waitForAndValidateBlock(t, n, activeVals, blocksSubs, states) - waitForAndValidateBlock(t, n, activeVals, blocksSubs, states) + waitForAndValidateBlock(ctx, t, n, activeVals, blocksSubs, states, updateValidatorTx) + waitForAndValidateBlockWithTx(ctx, t, n, activeVals, blocksSubs, states, updateValidatorTx) + waitForAndValidateBlock(ctx, t, n, activeVals, blocksSubs, states) + waitForAndValidateBlock(ctx, t, n, activeVals, blocksSubs, states) require.NotEqualf( t, previousTotalVotingPower, states[0].GetRoundState().LastValidators.TotalVotingPower(), @@ -653,10 +666,10 @@ func TestReactorVotingPowerChange(t *testing.T) { updateValidatorTx = kvstore.MakeValSetChangeTx(val1PubKeyABCI, 2) previousTotalVotingPower = states[0].GetRoundState().LastValidators.TotalVotingPower() - waitForAndValidateBlock(t, n, activeVals, blocksSubs, states, updateValidatorTx) - waitForAndValidateBlockWithTx(t, n, activeVals, blocksSubs, states, updateValidatorTx) - waitForAndValidateBlock(t, n, activeVals, blocksSubs, states) - waitForAndValidateBlock(t, n, activeVals, blocksSubs, states) + waitForAndValidateBlock(ctx, t, n, activeVals, blocksSubs, states, updateValidatorTx) + waitForAndValidateBlockWithTx(ctx, t, n, activeVals, blocksSubs, states, updateValidatorTx) + waitForAndValidateBlock(ctx, t, n, activeVals, blocksSubs, states) + waitForAndValidateBlock(ctx, t, n, activeVals, blocksSubs, states) require.NotEqualf( t, states[0].GetRoundState().LastValidators.TotalVotingPower(), previousTotalVotingPower, @@ -667,10 +680,10 @@ func TestReactorVotingPowerChange(t *testing.T) { updateValidatorTx = kvstore.MakeValSetChangeTx(val1PubKeyABCI, 26) previousTotalVotingPower = states[0].GetRoundState().LastValidators.TotalVotingPower() - waitForAndValidateBlock(t, n, activeVals, blocksSubs, states, updateValidatorTx) - waitForAndValidateBlockWithTx(t, n, activeVals, blocksSubs, states, updateValidatorTx) - waitForAndValidateBlock(t, n, activeVals, blocksSubs, states) - waitForAndValidateBlock(t, n, activeVals, blocksSubs, states) + waitForAndValidateBlock(ctx, t, n, activeVals, blocksSubs, states, updateValidatorTx) + waitForAndValidateBlockWithTx(ctx, t, n, activeVals, blocksSubs, states, updateValidatorTx) + waitForAndValidateBlock(ctx, t, n, 
activeVals, blocksSubs, states) + waitForAndValidateBlock(ctx, t, n, activeVals, blocksSubs, states) require.NotEqualf( t, previousTotalVotingPower, states[0].GetRoundState().LastValidators.TotalVotingPower(), @@ -681,11 +694,15 @@ func TestReactorVotingPowerChange(t *testing.T) { } func TestReactorValidatorSetChanges(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg := configSetup(t) nPeers := 7 nVals := 4 states, _, _, cleanup := randConsensusNetWithPeers( + ctx, cfg, nVals, nPeers, @@ -695,11 +712,11 @@ func TestReactorValidatorSetChanges(t *testing.T) { ) t.Cleanup(cleanup) - rts := setup(t, nPeers, states, 100) // buffer must be large enough to not deadlock + rts := setup(ctx, t, nPeers, states, 100) // buffer must be large enough to not deadlock for _, reactor := range rts.reactors { state := reactor.state.GetState() - reactor.SwitchToConsensus(state, false) + reactor.SwitchToConsensus(ctx, state, false) } // map of active validators @@ -711,8 +728,6 @@ func TestReactorValidatorSetChanges(t *testing.T) { activeVals[string(pubKey.Address())] = struct{}{} } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() var wg sync.WaitGroup for _, sub := range rts.subs { wg.Add(1) @@ -729,7 +744,7 @@ func TestReactorValidatorSetChanges(t *testing.T) { wg.Wait() - newValidatorPubKey1, err := states[nVals].privValidator.GetPubKey(context.Background()) + newValidatorPubKey1, err := states[nVals].privValidator.GetPubKey(ctx) require.NoError(t, err) valPubKey1ABCI, err := encoding.PubKeyToProto(newValidatorPubKey1) @@ -745,24 +760,24 @@ func TestReactorValidatorSetChanges(t *testing.T) { // wait till everyone makes block 2 // ensure the commit includes all validators // send newValTx to change vals in block 3 - waitForAndValidateBlock(t, nPeers, activeVals, blocksSubs, states, newValidatorTx1) + waitForAndValidateBlock(ctx, t, nPeers, activeVals, blocksSubs, states, newValidatorTx1) // wait till everyone makes block 3. // it includes the commit for block 2, which is by the original validator set - waitForAndValidateBlockWithTx(t, nPeers, activeVals, blocksSubs, states, newValidatorTx1) + waitForAndValidateBlockWithTx(ctx, t, nPeers, activeVals, blocksSubs, states, newValidatorTx1) // wait till everyone makes block 4. 
// it includes the commit for block 3, which is by the original validator set - waitForAndValidateBlock(t, nPeers, activeVals, blocksSubs, states) + waitForAndValidateBlock(ctx, t, nPeers, activeVals, blocksSubs, states) // the commits for block 4 should be with the updated validator set activeVals[string(newValidatorPubKey1.Address())] = struct{}{} // wait till everyone makes block 5 // it includes the commit for block 4, which should have the updated validator set - waitForBlockWithUpdatedValsAndValidateIt(t, nPeers, activeVals, blocksSubs, states) + waitForBlockWithUpdatedValsAndValidateIt(ctx, t, nPeers, activeVals, blocksSubs, states) - updateValidatorPubKey1, err := states[nVals].privValidator.GetPubKey(context.Background()) + updateValidatorPubKey1, err := states[nVals].privValidator.GetPubKey(ctx) require.NoError(t, err) updatePubKey1ABCI, err := encoding.PubKeyToProto(updateValidatorPubKey1) @@ -771,10 +786,10 @@ func TestReactorValidatorSetChanges(t *testing.T) { updateValidatorTx1 := kvstore.MakeValSetChangeTx(updatePubKey1ABCI, 25) previousTotalVotingPower := states[nVals].GetRoundState().LastValidators.TotalVotingPower() - waitForAndValidateBlock(t, nPeers, activeVals, blocksSubs, states, updateValidatorTx1) - waitForAndValidateBlockWithTx(t, nPeers, activeVals, blocksSubs, states, updateValidatorTx1) - waitForAndValidateBlock(t, nPeers, activeVals, blocksSubs, states) - waitForBlockWithUpdatedValsAndValidateIt(t, nPeers, activeVals, blocksSubs, states) + waitForAndValidateBlock(ctx, t, nPeers, activeVals, blocksSubs, states, updateValidatorTx1) + waitForAndValidateBlockWithTx(ctx, t, nPeers, activeVals, blocksSubs, states, updateValidatorTx1) + waitForAndValidateBlock(ctx, t, nPeers, activeVals, blocksSubs, states) + waitForBlockWithUpdatedValsAndValidateIt(ctx, t, nPeers, activeVals, blocksSubs, states) require.NotEqualf( t, states[nVals].GetRoundState().LastValidators.TotalVotingPower(), previousTotalVotingPower, @@ -782,7 +797,7 @@ func TestReactorValidatorSetChanges(t *testing.T) { previousTotalVotingPower, states[nVals].GetRoundState().LastValidators.TotalVotingPower(), ) - newValidatorPubKey2, err := states[nVals+1].privValidator.GetPubKey(context.Background()) + newValidatorPubKey2, err := states[nVals+1].privValidator.GetPubKey(ctx) require.NoError(t, err) newVal2ABCI, err := encoding.PubKeyToProto(newValidatorPubKey2) @@ -790,7 +805,7 @@ func TestReactorValidatorSetChanges(t *testing.T) { newValidatorTx2 := kvstore.MakeValSetChangeTx(newVal2ABCI, testMinPower) - newValidatorPubKey3, err := states[nVals+2].privValidator.GetPubKey(context.Background()) + newValidatorPubKey3, err := states[nVals+2].privValidator.GetPubKey(ctx) require.NoError(t, err) newVal3ABCI, err := encoding.PubKeyToProto(newValidatorPubKey3) @@ -798,24 +813,24 @@ func TestReactorValidatorSetChanges(t *testing.T) { newValidatorTx3 := kvstore.MakeValSetChangeTx(newVal3ABCI, testMinPower) - waitForAndValidateBlock(t, nPeers, activeVals, blocksSubs, states, newValidatorTx2, newValidatorTx3) - waitForAndValidateBlockWithTx(t, nPeers, activeVals, blocksSubs, states, newValidatorTx2, newValidatorTx3) - waitForAndValidateBlock(t, nPeers, activeVals, blocksSubs, states) + waitForAndValidateBlock(ctx, t, nPeers, activeVals, blocksSubs, states, newValidatorTx2, newValidatorTx3) + waitForAndValidateBlockWithTx(ctx, t, nPeers, activeVals, blocksSubs, states, newValidatorTx2, newValidatorTx3) + waitForAndValidateBlock(ctx, t, nPeers, activeVals, blocksSubs, states) 
activeVals[string(newValidatorPubKey2.Address())] = struct{}{} activeVals[string(newValidatorPubKey3.Address())] = struct{}{} - waitForBlockWithUpdatedValsAndValidateIt(t, nPeers, activeVals, blocksSubs, states) + waitForBlockWithUpdatedValsAndValidateIt(ctx, t, nPeers, activeVals, blocksSubs, states) removeValidatorTx2 := kvstore.MakeValSetChangeTx(newVal2ABCI, 0) removeValidatorTx3 := kvstore.MakeValSetChangeTx(newVal3ABCI, 0) - waitForAndValidateBlock(t, nPeers, activeVals, blocksSubs, states, removeValidatorTx2, removeValidatorTx3) - waitForAndValidateBlockWithTx(t, nPeers, activeVals, blocksSubs, states, removeValidatorTx2, removeValidatorTx3) - waitForAndValidateBlock(t, nPeers, activeVals, blocksSubs, states) + waitForAndValidateBlock(ctx, t, nPeers, activeVals, blocksSubs, states, removeValidatorTx2, removeValidatorTx3) + waitForAndValidateBlockWithTx(ctx, t, nPeers, activeVals, blocksSubs, states, removeValidatorTx2, removeValidatorTx3) + waitForAndValidateBlock(ctx, t, nPeers, activeVals, blocksSubs, states) delete(activeVals, string(newValidatorPubKey2.Address())) delete(activeVals, string(newValidatorPubKey3.Address())) - waitForBlockWithUpdatedValsAndValidateIt(t, nPeers, activeVals, blocksSubs, states) + waitForBlockWithUpdatedValsAndValidateIt(ctx, t, nPeers, activeVals, blocksSubs, states) } diff --git a/internal/consensus/replay.go b/internal/consensus/replay.go index 563e5ca64..f40389f2b 100644 --- a/internal/consensus/replay.go +++ b/internal/consensus/replay.go @@ -237,10 +237,10 @@ func (h *Handshaker) NBlocks() int { } // TODO: retry the handshake/replay if it fails ? -func (h *Handshaker) Handshake(proxyApp proxy.AppConns) error { +func (h *Handshaker) Handshake(ctx context.Context, proxyApp proxy.AppConns) error { // Handshake is done via ABCI Info on the query conn. - res, err := proxyApp.Query().InfoSync(context.Background(), proxy.RequestInfo) + res, err := proxyApp.Query().InfoSync(ctx, proxy.RequestInfo) if err != nil { return fmt.Errorf("error calling Info: %v", err) } @@ -264,7 +264,7 @@ func (h *Handshaker) Handshake(proxyApp proxy.AppConns) error { } // Replay blocks up to the latest in the blockstore. - _, err = h.ReplayBlocks(h.initialState, appHash, blockHeight, proxyApp) + _, err = h.ReplayBlocks(ctx, h.initialState, appHash, blockHeight, proxyApp) if err != nil { return fmt.Errorf("error on replay: %v", err) } @@ -281,6 +281,7 @@ func (h *Handshaker) Handshake(proxyApp proxy.AppConns) error { // matches the current state. // Returns the final AppHash or an error. 
func (h *Handshaker) ReplayBlocks( + ctx context.Context, state sm.State, appHash []byte, appBlockHeight int64, @@ -315,7 +316,7 @@ func (h *Handshaker) ReplayBlocks( Validators: nextVals, AppStateBytes: h.genDoc.AppState, } - res, err := proxyApp.Consensus().InitChainSync(context.Background(), req) + res, err := proxyApp.Consensus().InitChainSync(ctx, req) if err != nil { return nil, err } @@ -421,7 +422,7 @@ func (h *Handshaker) ReplayBlocks( if err != nil { return nil, err } - mockApp := newMockProxyApp(h.logger, appHash, abciResponses) + mockApp := newMockProxyApp(ctx, h.logger, appHash, abciResponses) h.logger.Info("Replay last block using mock app") state, err = h.replayBlock(state, storeBlockHeight, mockApp) return state.AppHash, err diff --git a/internal/consensus/replay_file.go b/internal/consensus/replay_file.go index 6f1e64b2a..1de0ffa0e 100644 --- a/internal/consensus/replay_file.go +++ b/internal/consensus/replay_file.go @@ -32,17 +32,18 @@ const ( // replay the wal file func RunReplayFile( + ctx context.Context, logger log.Logger, cfg config.BaseConfig, csConfig *config.ConsensusConfig, console bool, ) error { - consensusState, err := newConsensusStateForReplay(cfg, logger, csConfig) + consensusState, err := newConsensusStateForReplay(ctx, cfg, logger, csConfig) if err != nil { return err } - if err := consensusState.ReplayFile(csConfig.WalFile(), console); err != nil { + if err := consensusState.ReplayFile(ctx, csConfig.WalFile(), console); err != nil { return fmt.Errorf("consensus replay: %w", err) } @@ -50,7 +51,7 @@ func RunReplayFile( } // Replay msgs in file or start the console -func (cs *State) ReplayFile(file string, console bool) error { +func (cs *State) ReplayFile(ctx context.Context, file string, console bool) error { if cs.IsRunning() { return errors.New("cs is already running, cannot replay") @@ -63,7 +64,6 @@ func (cs *State) ReplayFile(file string, console bool) error { // ensure all new step events are regenerated as expected - ctx := context.Background() newStepSub, err := cs.eventBus.SubscribeWithArgs(ctx, tmpubsub.SubscribeArgs{ ClientID: subscriber, Query: types.EventQueryNewRoundStep, @@ -307,6 +307,7 @@ func (pb *playback) replayConsoleLoop() (int, error) { // convenience for replay mode func newConsensusStateForReplay( + ctx context.Context, cfg config.BaseConfig, logger log.Logger, csConfig *config.ConsensusConfig, @@ -339,19 +340,19 @@ func newConsensusStateForReplay( // Create proxyAppConn connection (consensus, mempool, query) clientCreator, _ := proxy.DefaultClientCreator(logger, cfg.ProxyApp, cfg.ABCI, cfg.DBDir()) proxyApp := proxy.NewAppConns(clientCreator, logger, proxy.NopMetrics()) - err = proxyApp.Start() + err = proxyApp.Start(ctx) if err != nil { return nil, fmt.Errorf("starting proxy app conns: %w", err) } eventBus := eventbus.NewDefault(logger) - if err := eventBus.Start(); err != nil { + if err := eventBus.Start(ctx); err != nil { return nil, fmt.Errorf("failed to start event bus: %w", err) } handshaker := NewHandshaker(logger, stateStore, state, blockStore, eventBus, gdoc) - if err = handshaker.Handshake(proxyApp); err != nil { + if err = handshaker.Handshake(ctx, proxyApp); err != nil { return nil, err } diff --git a/internal/consensus/replay_stubs.go b/internal/consensus/replay_stubs.go index 679aba611..8672f8e1e 100644 --- a/internal/consensus/replay_stubs.go +++ b/internal/consensus/replay_stubs.go @@ -55,13 +55,18 @@ func (emptyMempool) CloseWAL() {} // Useful because we don't want to call Commit() twice for the same block on // 
the real app. -func newMockProxyApp(logger log.Logger, appHash []byte, abciResponses *tmstate.ABCIResponses) proxy.AppConnConsensus { +func newMockProxyApp( + ctx context.Context, + logger log.Logger, + appHash []byte, + abciResponses *tmstate.ABCIResponses, +) proxy.AppConnConsensus { clientCreator := abciclient.NewLocalCreator(&mockProxyApp{ appHash: appHash, abciResponses: abciResponses, }) cli, _ := clientCreator(logger) - err := cli.Start() + err := cli.Start(ctx) if err != nil { panic(err) } diff --git a/internal/consensus/replay_test.go b/internal/consensus/replay_test.go index f5f5d1633..bf1927742 100644 --- a/internal/consensus/replay_test.go +++ b/internal/consensus/replay_test.go @@ -56,7 +56,7 @@ import ( // and which ones we need the wal for - then we'd also be able to only flush the // wal writer when we need to, instead of with every message. -func startNewStateAndWaitForBlock(t *testing.T, consensusReplayConfig *config.Config, +func startNewStateAndWaitForBlock(ctx context.Context, t *testing.T, consensusReplayConfig *config.Config, lastBlockHeight int64, blockDB dbm.DB, stateStore sm.Store) { logger := log.TestingLogger() state, err := sm.MakeGenesisStateFromFile(consensusReplayConfig.GenesisFile()) @@ -64,6 +64,7 @@ func startNewStateAndWaitForBlock(t *testing.T, consensusReplayConfig *config.Co privValidator := loadPrivValidator(consensusReplayConfig) blockStore := store.NewBlockStore(dbm.NewMemDB()) cs := newStateWithConfigAndBlockStore( + ctx, logger, consensusReplayConfig, state, @@ -75,7 +76,7 @@ func startNewStateAndWaitForBlock(t *testing.T, consensusReplayConfig *config.Co bytes, _ := os.ReadFile(cs.config.WalFile()) t.Logf("====== WAL: \n\r%X\n", bytes) - err = cs.Start() + err = cs.Start(ctx) require.NoError(t, err) defer func() { if err := cs.Stop(); err != nil { @@ -87,12 +88,12 @@ func startNewStateAndWaitForBlock(t *testing.T, consensusReplayConfig *config.Co // in the WAL itself. Assuming the consensus state is running, replay of any // WAL, including the empty one, should eventually be followed by a new // block, or else something is wrong. - newBlockSub, err := cs.eventBus.SubscribeWithArgs(context.Background(), pubsub.SubscribeArgs{ + newBlockSub, err := cs.eventBus.SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ ClientID: testSubscriber, Query: types.EventQueryNewBlock, }) require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + ctx, cancel := context.WithTimeout(ctx, 120*time.Second) defer cancel() _, err = newBlockSub.Next(ctx) if errors.Is(err, context.DeadlineExceeded) { @@ -109,7 +110,7 @@ func sendTxs(ctx context.Context, cs *State) { return default: tx := []byte{byte(i)} - if err := assertMempool(cs.txNotifier).CheckTx(context.Background(), tx, nil, mempool.TxInfo{}); err != nil { + if err := assertMempool(cs.txNotifier).CheckTx(ctx, tx, nil, mempool.TxInfo{}); err != nil { panic(err) } i++ @@ -119,6 +120,9 @@ func sendTxs(ctx context.Context, cs *State) { // TestWALCrash uses crashing WAL to test we can recover from any WAL failure. 
func TestWALCrash(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + testCases := []struct { name string initFn func(dbm.DB, *State, context.Context) @@ -139,12 +143,12 @@ func TestWALCrash(t *testing.T) { t.Run(tc.name, func(t *testing.T) { consensusReplayConfig, err := ResetConfig(tc.name) require.NoError(t, err) - crashWALandCheckLiveness(t, consensusReplayConfig, tc.initFn, tc.heightToStop) + crashWALandCheckLiveness(ctx, t, consensusReplayConfig, tc.initFn, tc.heightToStop) }) } } -func crashWALandCheckLiveness(t *testing.T, consensusReplayConfig *config.Config, +func crashWALandCheckLiveness(ctx context.Context, t *testing.T, consensusReplayConfig *config.Config, initFn func(dbm.DB, *State, context.Context), heightToStop int64) { walPanicked := make(chan error) crashingWal := &crashingWAL{panicCh: walPanicked, heightToStop: heightToStop} @@ -164,6 +168,7 @@ LOOP: require.NoError(t, err) privValidator := loadPrivValidator(consensusReplayConfig) cs := newStateWithConfigAndBlockStore( + ctx, logger, consensusReplayConfig, state, @@ -173,7 +178,7 @@ LOOP: ) // start sending transactions - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) initFn(stateDB, cs, ctx) // clean up WAL file from the previous iteration @@ -181,7 +186,7 @@ LOOP: os.Remove(walFile) // set crashing WAL - csWal, err := cs.OpenWAL(walFile) + csWal, err := cs.OpenWAL(ctx, walFile) require.NoError(t, err) crashingWal.next = csWal @@ -190,7 +195,7 @@ LOOP: cs.wal = crashingWal // start consensus state - err = cs.Start() + err = cs.Start(ctx) require.NoError(t, err) i++ @@ -200,7 +205,7 @@ LOOP: t.Logf("WAL panicked: %v", err) // make sure we can make blocks after a crash - startNewStateAndWaitForBlock(t, consensusReplayConfig, cs.Height, blockDB, stateStore) + startNewStateAndWaitForBlock(ctx, t, consensusReplayConfig, cs.Height, blockDB, stateStore) // stop consensus state and transactions sender (initFn) cs.Stop() //nolint:errcheck // Logging this error causes failure @@ -286,9 +291,9 @@ func (w *crashingWAL) SearchForEndHeight( return w.next.SearchForEndHeight(height, options) } -func (w *crashingWAL) Start() error { return w.next.Start() } -func (w *crashingWAL) Stop() error { return w.next.Stop() } -func (w *crashingWAL) Wait() { w.next.Wait() } +func (w *crashingWAL) Start(ctx context.Context) error { return w.next.Start(ctx) } +func (w *crashingWAL) Stop() error { return w.next.Stop() } +func (w *crashingWAL) Wait() { w.next.Wait() } //------------------------------------------------------------------------------------------ type simulatorTestSuite struct { @@ -316,7 +321,7 @@ const ( var modes = []uint{0, 1, 2, 3} // This is actually not a test, it's for storing validator change tx data for testHandshakeReplay -func setupSimulator(t *testing.T) *simulatorTestSuite { +func setupSimulator(ctx context.Context, t *testing.T) *simulatorTestSuite { t.Helper() cfg := configSetup(t) @@ -329,6 +334,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { nVals := 4 css, genDoc, cfg, cleanup := randConsensusNetWithPeers( + ctx, cfg, nVals, nPeers, @@ -341,8 +347,8 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { partSize := types.BlockPartSizeBytes - newRoundCh := subscribe(t, css[0].eventBus, types.EventQueryNewRound) - proposalCh := subscribe(t, css[0].eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(ctx, t, css[0].eventBus, types.EventQueryNewRound) + proposalCh := subscribe(ctx, t, 
css[0].eventBus, types.EventQueryCompleteProposal) vss := make([]*validatorStub, nPeers) for i := 0; i < nPeers; i++ { @@ -351,13 +357,13 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { height, round := css[0].Height, css[0].Round // start the machine - startTestRound(css[0], height, round) + startTestRound(ctx, css[0], height, round) incrementHeight(vss...) ensureNewRound(newRoundCh, height, 0) ensureNewProposal(proposalCh, height, round) rs := css[0].GetRoundState() - signAddVotes(sim.Config, css[0], tmproto.PrecommitType, + signAddVotes(ctx, sim.Config, css[0], tmproto.PrecommitType, rs.ProposalBlock.Hash(), rs.ProposalBlockParts.Header(), vss[1:nVals]...) @@ -366,12 +372,12 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { // HEIGHT 2 height++ incrementHeight(vss...) - newValidatorPubKey1, err := css[nVals].privValidator.GetPubKey(context.Background()) + newValidatorPubKey1, err := css[nVals].privValidator.GetPubKey(ctx) require.NoError(t, err) valPubKey1ABCI, err := encoding.PubKeyToProto(newValidatorPubKey1) require.NoError(t, err) newValidatorTx1 := kvstore.MakeValSetChangeTx(valPubKey1ABCI, testMinPower) - err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), newValidatorTx1, nil, mempool.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(ctx, newValidatorTx1, nil, mempool.TxInfo{}) assert.Nil(t, err) propBlock, _ := css[0].createProposalBlock() // changeProposer(t, cs1, vs2) propBlockParts := propBlock.MakePartSet(partSize) @@ -379,7 +385,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { proposal := types.NewProposal(vss[1].Height, round, -1, blockID) p := proposal.ToProto() - if err := vss[1].SignProposal(context.Background(), cfg.ChainID(), p); err != nil { + if err := vss[1].SignProposal(ctx, cfg.ChainID(), p); err != nil { t.Fatal("failed to sign bad proposal", err) } proposal.Signature = p.Signature @@ -390,7 +396,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { } ensureNewProposal(proposalCh, height, round) rs = css[0].GetRoundState() - signAddVotes(sim.Config, css[0], tmproto.PrecommitType, + signAddVotes(ctx, sim.Config, css[0], tmproto.PrecommitType, rs.ProposalBlock.Hash(), rs.ProposalBlockParts.Header(), vss[1:nVals]...) ensureNewRound(newRoundCh, height+1, 0) @@ -398,12 +404,12 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { // HEIGHT 3 height++ incrementHeight(vss...) 
- updateValidatorPubKey1, err := css[nVals].privValidator.GetPubKey(context.Background()) + updateValidatorPubKey1, err := css[nVals].privValidator.GetPubKey(ctx) require.NoError(t, err) updatePubKey1ABCI, err := encoding.PubKeyToProto(updateValidatorPubKey1) require.NoError(t, err) updateValidatorTx1 := kvstore.MakeValSetChangeTx(updatePubKey1ABCI, 25) - err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), updateValidatorTx1, nil, mempool.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(ctx, updateValidatorTx1, nil, mempool.TxInfo{}) assert.Nil(t, err) propBlock, _ = css[0].createProposalBlock() // changeProposer(t, cs1, vs2) propBlockParts = propBlock.MakePartSet(partSize) @@ -411,7 +417,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { proposal = types.NewProposal(vss[2].Height, round, -1, blockID) p = proposal.ToProto() - if err := vss[2].SignProposal(context.Background(), cfg.ChainID(), p); err != nil { + if err := vss[2].SignProposal(ctx, cfg.ChainID(), p); err != nil { t.Fatal("failed to sign bad proposal", err) } proposal.Signature = p.Signature @@ -422,7 +428,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { } ensureNewProposal(proposalCh, height, round) rs = css[0].GetRoundState() - signAddVotes(sim.Config, css[0], tmproto.PrecommitType, + signAddVotes(ctx, sim.Config, css[0], tmproto.PrecommitType, rs.ProposalBlock.Hash(), rs.ProposalBlockParts.Header(), vss[1:nVals]...) ensureNewRound(newRoundCh, height+1, 0) @@ -430,19 +436,19 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { // HEIGHT 4 height++ incrementHeight(vss...) - newValidatorPubKey2, err := css[nVals+1].privValidator.GetPubKey(context.Background()) + newValidatorPubKey2, err := css[nVals+1].privValidator.GetPubKey(ctx) require.NoError(t, err) newVal2ABCI, err := encoding.PubKeyToProto(newValidatorPubKey2) require.NoError(t, err) newValidatorTx2 := kvstore.MakeValSetChangeTx(newVal2ABCI, testMinPower) - err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), newValidatorTx2, nil, mempool.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(ctx, newValidatorTx2, nil, mempool.TxInfo{}) assert.Nil(t, err) - newValidatorPubKey3, err := css[nVals+2].privValidator.GetPubKey(context.Background()) + newValidatorPubKey3, err := css[nVals+2].privValidator.GetPubKey(ctx) require.NoError(t, err) newVal3ABCI, err := encoding.PubKeyToProto(newValidatorPubKey3) require.NoError(t, err) newValidatorTx3 := kvstore.MakeValSetChangeTx(newVal3ABCI, testMinPower) - err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), newValidatorTx3, nil, mempool.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(ctx, newValidatorTx3, nil, mempool.TxInfo{}) assert.Nil(t, err) propBlock, _ = css[0].createProposalBlock() // changeProposer(t, cs1, vs2) propBlockParts = propBlock.MakePartSet(partSize) @@ -453,10 +459,10 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { valIndexFn := func(cssIdx int) int { for i, vs := range newVss { - vsPubKey, err := vs.GetPubKey(context.Background()) + vsPubKey, err := vs.GetPubKey(ctx) require.NoError(t, err) - cssPubKey, err := css[cssIdx].privValidator.GetPubKey(context.Background()) + cssPubKey, err := css[cssIdx].privValidator.GetPubKey(ctx) require.NoError(t, err) if vsPubKey.Equals(cssPubKey) { @@ -470,7 +476,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { proposal = types.NewProposal(vss[3].Height, round, -1, blockID) p = proposal.ToProto() - if err := vss[3].SignProposal(context.Background(), 
cfg.ChainID(), p); err != nil { + if err := vss[3].SignProposal(ctx, cfg.ChainID(), p); err != nil { t.Fatal("failed to sign bad proposal", err) } proposal.Signature = p.Signature @@ -482,7 +488,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { ensureNewProposal(proposalCh, height, round) removeValidatorTx2 := kvstore.MakeValSetChangeTx(newVal2ABCI, 0) - err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), removeValidatorTx2, nil, mempool.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(ctx, removeValidatorTx2, nil, mempool.TxInfo{}) assert.Nil(t, err) rs = css[0].GetRoundState() @@ -490,7 +496,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { if i == selfIndex { continue } - signAddVotes(sim.Config, css[0], + signAddVotes(ctx, sim.Config, css[0], tmproto.PrecommitType, rs.ProposalBlock.Hash(), rs.ProposalBlockParts.Header(), newVss[i]) } @@ -511,7 +517,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { if i == selfIndex { continue } - signAddVotes(sim.Config, css[0], + signAddVotes(ctx, sim.Config, css[0], tmproto.PrecommitType, rs.ProposalBlock.Hash(), rs.ProposalBlockParts.Header(), newVss[i]) } @@ -521,7 +527,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { height++ incrementHeight(vss...) removeValidatorTx3 := kvstore.MakeValSetChangeTx(newVal3ABCI, 0) - err = assertMempool(css[0].txNotifier).CheckTx(context.Background(), removeValidatorTx3, nil, mempool.TxInfo{}) + err = assertMempool(css[0].txNotifier).CheckTx(ctx, removeValidatorTx3, nil, mempool.TxInfo{}) assert.Nil(t, err) propBlock, _ = css[0].createProposalBlock() // changeProposer(t, cs1, vs2) propBlockParts = propBlock.MakePartSet(partSize) @@ -533,7 +539,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { selfIndex = valIndexFn(0) proposal = types.NewProposal(vss[1].Height, round, -1, blockID) p = proposal.ToProto() - if err := vss[1].SignProposal(context.Background(), cfg.ChainID(), p); err != nil { + if err := vss[1].SignProposal(ctx, cfg.ChainID(), p); err != nil { t.Fatal("failed to sign bad proposal", err) } proposal.Signature = p.Signature @@ -548,7 +554,7 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { if i == selfIndex { continue } - signAddVotes(sim.Config, css[0], + signAddVotes(ctx, sim.Config, css[0], tmproto.PrecommitType, rs.ProposalBlock.Hash(), rs.ProposalBlockParts.Header(), newVss[i]) } @@ -569,55 +575,70 @@ func setupSimulator(t *testing.T) *simulatorTestSuite { // Sync from scratch func TestHandshakeReplayAll(t *testing.T) { - sim := setupSimulator(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sim := setupSimulator(ctx, t) for _, m := range modes { - testHandshakeReplay(t, sim, 0, m, false) + testHandshakeReplay(ctx, t, sim, 0, m, false) } for _, m := range modes { - testHandshakeReplay(t, sim, 0, m, true) + testHandshakeReplay(ctx, t, sim, 0, m, true) } } // Sync many, not from scratch func TestHandshakeReplaySome(t *testing.T) { - sim := setupSimulator(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sim := setupSimulator(ctx, t) for _, m := range modes { - testHandshakeReplay(t, sim, 2, m, false) + testHandshakeReplay(ctx, t, sim, 2, m, false) } for _, m := range modes { - testHandshakeReplay(t, sim, 2, m, true) + testHandshakeReplay(ctx, t, sim, 2, m, true) } } // Sync from lagging by one func TestHandshakeReplayOne(t *testing.T) { - sim := setupSimulator(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sim := 
setupSimulator(ctx, t) for _, m := range modes { - testHandshakeReplay(t, sim, numBlocks-1, m, false) + testHandshakeReplay(ctx, t, sim, numBlocks-1, m, false) } for _, m := range modes { - testHandshakeReplay(t, sim, numBlocks-1, m, true) + testHandshakeReplay(ctx, t, sim, numBlocks-1, m, true) } } // Sync from caught up func TestHandshakeReplayNone(t *testing.T) { - sim := setupSimulator(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sim := setupSimulator(ctx, t) for _, m := range modes { - testHandshakeReplay(t, sim, numBlocks, m, false) + testHandshakeReplay(ctx, t, sim, numBlocks, m, false) } for _, m := range modes { - testHandshakeReplay(t, sim, numBlocks, m, true) + testHandshakeReplay(ctx, t, sim, numBlocks, m, true) } } // Test mockProxyApp should not panic when app return ABCIResponses with some empty ResponseDeliverTx func TestMockProxyApp(t *testing.T) { - sim := setupSimulator(t) // setup config and simulator + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sim := setupSimulator(ctx, t) // setup config and simulator cfg := sim.Config assert.NotNil(t, cfg) @@ -639,7 +660,7 @@ func TestMockProxyApp(t *testing.T) { err = proto.Unmarshal(bytes, loadedAbciRes) require.NoError(t, err) - mock := newMockProxyApp(logger, []byte("mock_hash"), loadedAbciRes) + mock := newMockProxyApp(ctx, logger, []byte("mock_hash"), loadedAbciRes) abciRes := new(tmstate.ABCIResponses) abciRes.DeliverTxs = make([]*abci.ResponseDeliverTx, len(loadedAbciRes.DeliverTxs)) @@ -663,7 +684,7 @@ func TestMockProxyApp(t *testing.T) { mock.SetResponseCallback(proxyCb) someTx := []byte("tx") - _, err = mock.DeliverTxAsync(context.Background(), abci.RequestDeliverTx{Tx: someTx}) + _, err = mock.DeliverTxAsync(ctx, abci.RequestDeliverTx{Tx: someTx}) assert.NoError(t, err) }) assert.True(t, validTxs == 1) @@ -687,12 +708,23 @@ func tempWALWithData(data []byte) string { // Make some blocks. Start a fresh app and apply nBlocks blocks. 
// Then restart the app and sync it up with the remaining blocks -func testHandshakeReplay(t *testing.T, sim *simulatorTestSuite, nBlocks int, mode uint, testValidatorsChange bool) { +func testHandshakeReplay( + ctx context.Context, + t *testing.T, + sim *simulatorTestSuite, + nBlocks int, + mode uint, + testValidatorsChange bool, +) { var chain []*types.Block var commits []*types.Commit var store *mockBlockStore var stateDB dbm.DB var genesisState sm.State + var cancel context.CancelFunc + + ctx, cancel = context.WithCancel(ctx) + t.Cleanup(cancel) cfg := sim.Config @@ -712,7 +744,7 @@ func testHandshakeReplay(t *testing.T, sim *simulatorTestSuite, nBlocks int, mod testConfig, err := ResetConfig(fmt.Sprintf("%s_%v_s", t.Name(), mode)) require.NoError(t, err) defer func() { _ = os.RemoveAll(testConfig.RootDir) }() - walBody, err := WALWithNBlocks(t, numBlocks) + walBody, err := WALWithNBlocks(ctx, t, numBlocks) require.NoError(t, err) walFile := tempWALWithData(walBody) cfg.Consensus.SetWalFile(walFile) @@ -722,16 +754,12 @@ func testHandshakeReplay(t *testing.T, sim *simulatorTestSuite, nBlocks int, mod wal, err := NewWAL(logger, walFile) require.NoError(t, err) - err = wal.Start() + err = wal.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := wal.Stop(); err != nil { - t.Error(err) - } - }) + t.Cleanup(func() { cancel(); wal.Wait() }) chain, commits, err = makeBlockchainFromWAL(wal) require.NoError(t, err) - pubKey, err := privVal.GetPubKey(context.Background()) + pubKey, err := privVal.GetPubKey(ctx) require.NoError(t, err) stateDB, genesisState, store = stateAndStore(cfg, pubKey, kvstore.ProtocolVersion) @@ -742,7 +770,19 @@ func testHandshakeReplay(t *testing.T, sim *simulatorTestSuite, nBlocks int, mod state := genesisState.Copy() // run the chain through state.ApplyBlock to build up the tendermint state - state = buildTMStateFromChain(cfg, logger, sim.Mempool, sim.Evpool, stateStore, state, chain, nBlocks, mode, store) + state = buildTMStateFromChain( + ctx, + cfg, + logger, + sim.Mempool, + sim.Evpool, + stateStore, + state, + chain, + nBlocks, + mode, + store, + ) latestAppHash := state.AppHash // make a new client creator @@ -759,7 +799,7 @@ func testHandshakeReplay(t *testing.T, sim *simulatorTestSuite, nBlocks int, mod stateStore := sm.NewStore(stateDB1) err := stateStore.Save(genesisState) require.NoError(t, err) - buildAppStateFromChain(proxyApp, stateStore, sim.Mempool, sim.Evpool, genesisState, chain, nBlocks, mode, store) + buildAppStateFromChain(ctx, proxyApp, stateStore, sim.Mempool, sim.Evpool, genesisState, chain, nBlocks, mode, store) } // Prune block store if requested @@ -775,17 +815,13 @@ func testHandshakeReplay(t *testing.T, sim *simulatorTestSuite, nBlocks int, mod genDoc, _ := sm.MakeGenesisDocFromFile(cfg.GenesisFile()) handshaker := NewHandshaker(logger, stateStore, state, store, eventbus.NopEventBus{}, genDoc) proxyApp := proxy.NewAppConns(clientCreator2, logger, proxy.NopMetrics()) - if err := proxyApp.Start(); err != nil { + if err := proxyApp.Start(ctx); err != nil { t.Fatalf("Error starting proxy app connections: %v", err) } - t.Cleanup(func() { - if err := proxyApp.Stop(); err != nil { - t.Error(err) - } - }) + t.Cleanup(func() { cancel(); proxyApp.Wait() }) - err := handshaker.Handshake(proxyApp) + err := handshaker.Handshake(ctx, proxyApp) if expectError { require.Error(t, err) return @@ -794,7 +830,7 @@ func testHandshakeReplay(t *testing.T, sim *simulatorTestSuite, nBlocks int, mod } // get the latest app hash from the app - res, 
err := proxyApp.Query().InfoSync(context.Background(), abci.RequestInfo{Version: ""}) + res, err := proxyApp.Query().InfoSync(ctx, abci.RequestInfo{Version: ""}) if err != nil { t.Fatal(err) } @@ -838,6 +874,7 @@ func applyBlock(stateStore sm.Store, } func buildAppStateFromChain( + ctx context.Context, proxyApp proxy.AppConns, stateStore sm.Store, mempool mempool.Mempool, @@ -846,16 +883,16 @@ func buildAppStateFromChain( chain []*types.Block, nBlocks int, mode uint, - blockStore *mockBlockStore) { + blockStore *mockBlockStore, +) { // start a new app without handshake, play nBlocks blocks - if err := proxyApp.Start(); err != nil { + if err := proxyApp.Start(ctx); err != nil { panic(err) } - defer proxyApp.Stop() //nolint:errcheck // ignore state.Version.Consensus.App = kvstore.ProtocolVersion // simulate handshake, receive app version validators := types.TM2PB.ValidatorUpdates(state.Validators) - if _, err := proxyApp.Consensus().InitChainSync(context.Background(), abci.RequestInitChain{ + if _, err := proxyApp.Consensus().InitChainSync(ctx, abci.RequestInitChain{ Validators: validators, }); err != nil { panic(err) @@ -887,6 +924,7 @@ func buildAppStateFromChain( } func buildTMStateFromChain( + ctx context.Context, cfg *config.Config, logger log.Logger, mempool mempool.Mempool, @@ -905,14 +943,13 @@ func buildTMStateFromChain( clientCreator := abciclient.NewLocalCreator(kvstoreApp) proxyApp := proxy.NewAppConns(clientCreator, logger, proxy.NopMetrics()) - if err := proxyApp.Start(); err != nil { + if err := proxyApp.Start(ctx); err != nil { panic(err) } - defer proxyApp.Stop() //nolint:errcheck state.Version.Consensus.App = kvstore.ProtocolVersion // simulate handshake, receive app version validators := types.TM2PB.ValidatorUpdates(state.Validators) - if _, err := proxyApp.Consensus().InitChainSync(context.Background(), abci.RequestInitChain{ + if _, err := proxyApp.Consensus().InitChainSync(ctx, abci.RequestInitChain{ Validators: validators, }); err != nil { panic(err) @@ -949,13 +986,17 @@ func TestHandshakePanicsIfAppReturnsWrongAppHash(t *testing.T) { // - 0x01 // - 0x02 // - 0x03 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg, err := ResetConfig("handshake_test_") require.NoError(t, err) t.Cleanup(func() { os.RemoveAll(cfg.RootDir) }) privVal, err := privval.LoadFilePV(cfg.PrivValidator.KeyFile(), cfg.PrivValidator.StateFile()) require.NoError(t, err) const appVersion = 0x0 - pubKey, err := privVal.GetPubKey(context.Background()) + pubKey, err := privVal.GetPubKey(ctx) require.NoError(t, err) stateDB, state, store := stateAndStore(cfg, pubKey, appVersion) stateStore := sm.NewStore(stateDB) @@ -975,17 +1016,13 @@ func TestHandshakePanicsIfAppReturnsWrongAppHash(t *testing.T) { app := &badApp{numBlocks: 3, allHashesAreWrong: true} clientCreator := abciclient.NewLocalCreator(app) proxyApp := proxy.NewAppConns(clientCreator, logger, proxy.NopMetrics()) - err := proxyApp.Start() + err := proxyApp.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := proxyApp.Stop(); err != nil { - t.Error(err) - } - }) + t.Cleanup(func() { cancel(); proxyApp.Wait() }) assert.Panics(t, func() { h := NewHandshaker(logger, stateStore, state, store, eventbus.NopEventBus{}, genDoc) - if err = h.Handshake(proxyApp); err != nil { + if err = h.Handshake(ctx, proxyApp); err != nil { t.Log(err) } }) @@ -999,17 +1036,13 @@ func TestHandshakePanicsIfAppReturnsWrongAppHash(t *testing.T) { app := &badApp{numBlocks: 3, onlyLastHashIsWrong: true} clientCreator := 
abciclient.NewLocalCreator(app) proxyApp := proxy.NewAppConns(clientCreator, logger, proxy.NopMetrics()) - err := proxyApp.Start() + err := proxyApp.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := proxyApp.Stop(); err != nil { - t.Error(err) - } - }) + t.Cleanup(func() { cancel(); proxyApp.Wait() }) assert.Panics(t, func() { h := NewHandshaker(logger, stateStore, state, store, eventbus.NopEventBus{}, genDoc) - if err = h.Handshake(proxyApp); err != nil { + if err = h.Handshake(ctx, proxyApp); err != nil { t.Log(err) } }) @@ -1237,6 +1270,9 @@ func (bs *mockBlockStore) PruneBlocks(height int64) (uint64, error) { // Test handshake/init chain func TestHandshakeUpdatesValidators(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + val, _ := factory.RandValidator(true, 10) vals := types.NewValidatorSet([]*types.Validator{val}) app := &initChainApp{vals: types.TM2PB.ValidatorUpdates(vals)} @@ -1248,7 +1284,7 @@ func TestHandshakeUpdatesValidators(t *testing.T) { privVal, err := privval.LoadFilePV(cfg.PrivValidator.KeyFile(), cfg.PrivValidator.StateFile()) require.NoError(t, err) - pubKey, err := privVal.GetPubKey(context.Background()) + pubKey, err := privVal.GetPubKey(ctx) require.NoError(t, err) stateDB, state, store := stateAndStore(cfg, pubKey, 0x0) stateStore := sm.NewStore(stateDB) @@ -1262,12 +1298,11 @@ func TestHandshakeUpdatesValidators(t *testing.T) { logger := log.TestingLogger() handshaker := NewHandshaker(logger, stateStore, state, store, eventbus.NopEventBus{}, genDoc) proxyApp := proxy.NewAppConns(clientCreator, logger, proxy.NopMetrics()) - if err := proxyApp.Start(); err != nil { + if err := proxyApp.Start(ctx); err != nil { t.Fatalf("Error starting proxy app connections: %v", err) } - t.Cleanup(func() { require.NoError(t, proxyApp.Stop()) }) - if err := handshaker.Handshake(proxyApp); err != nil { + if err := handshaker.Handshake(ctx, proxyApp); err != nil { t.Fatalf("Error on abci handshake: %v", err) } // reload the state, check the validator set was updated diff --git a/internal/consensus/state.go b/internal/consensus/state.go index 6cef4ac8f..8cf1a7f9f 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -329,11 +329,11 @@ func (cs *State) LoadCommit(height int64) *types.Commit { // OnStart loads the latest state via the WAL, and starts the timeout and // receive routines. -func (cs *State) OnStart() error { +func (cs *State) OnStart(ctx context.Context) error { // We may set the WAL in testing before calling Start, so only OpenWAL if its // still the nilWAL. 
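
The hunks above and below replace the old Start()/Stop() service lifecycle with Start(ctx) followed by canceling the test context and blocking on Wait(), e.g. t.Cleanup(func() { cancel(); proxyApp.Wait() }). A minimal, stdlib-only sketch of that lifecycle pattern; demoService and its methods are illustrative stand-ins, not Tendermint's service.Service:

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

type demoService struct {
	wg sync.WaitGroup
}

// Start launches the service's goroutines; they exit when ctx is canceled.
func (s *demoService) Start(ctx context.Context) error {
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				// periodic work would go here
			}
		}
	}()
	return nil
}

// Wait blocks until every goroutine started by Start has returned.
func (s *demoService) Wait() { s.wg.Wait() }

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	s := &demoService{}
	if err := s.Start(ctx); err != nil {
		panic(err)
	}
	// In a test this is the t.Cleanup(func() { cancel(); s.Wait() }) pattern.
	cancel()
	s.Wait()
	fmt.Println("stopped cleanly")
}
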
if _, ok := cs.wal.(nilWAL); ok { - if err := cs.loadWalFile(); err != nil { + if err := cs.loadWalFile(ctx); err != nil { return err } } @@ -384,13 +384,13 @@ func (cs *State) OnStart() error { cs.Logger.Info("successful WAL repair") // reload WAL file - if err := cs.loadWalFile(); err != nil { + if err := cs.loadWalFile(ctx); err != nil { return err } } } - if err := cs.evsw.Start(); err != nil { + if err := cs.evsw.Start(ctx); err != nil { return err } @@ -399,7 +399,7 @@ func (cs *State) OnStart() error { // NOTE: we will get a build up of garbage go routines // firing on the tockChan until the receiveRoutine is started // to deal with them (by that point, at most one will be valid) - if err := cs.timeoutTicker.Start(); err != nil { + if err := cs.timeoutTicker.Start(ctx); err != nil { return err } @@ -420,8 +420,8 @@ func (cs *State) OnStart() error { // timeoutRoutine: receive requests for timeouts on tickChan and fire timeouts on tockChan // receiveRoutine: serializes processing of proposoals, block parts, votes; coordinates state transitions -func (cs *State) startRoutines(maxSteps int) { - err := cs.timeoutTicker.Start() +func (cs *State) startRoutines(ctx context.Context, maxSteps int) { + err := cs.timeoutTicker.Start(ctx) if err != nil { cs.Logger.Error("failed to start timeout ticker", "err", err) return @@ -431,8 +431,8 @@ func (cs *State) startRoutines(maxSteps int) { } // loadWalFile loads WAL data from file. It overwrites cs.wal. -func (cs *State) loadWalFile() error { - wal, err := cs.OpenWAL(cs.config.WalFile()) +func (cs *State) loadWalFile(ctx context.Context) error { + wal, err := cs.OpenWAL(ctx, cs.config.WalFile()) if err != nil { cs.Logger.Error("failed to load state WAL", "err", err) return err @@ -457,11 +457,15 @@ func (cs *State) OnStop() { close(cs.onStopCh) if err := cs.evsw.Stop(); err != nil { - cs.Logger.Error("failed trying to stop eventSwitch", "error", err) + if !errors.Is(err, service.ErrAlreadyStopped) { + cs.Logger.Error("failed trying to stop eventSwitch", "error", err) + } } if err := cs.timeoutTicker.Stop(); err != nil { - cs.Logger.Error("failed trying to stop timeoutTicket", "error", err) + if !errors.Is(err, service.ErrAlreadyStopped) { + cs.Logger.Error("failed trying to stop timeoutTicket", "error", err) + } } // WAL is stopped in receiveRoutine. } @@ -475,14 +479,14 @@ func (cs *State) Wait() { // OpenWAL opens a file to log all consensus messages and timeouts for // deterministic accountability. 
-func (cs *State) OpenWAL(walFile string) (WAL, error) { +func (cs *State) OpenWAL(ctx context.Context, walFile string) (WAL, error) { wal, err := NewWAL(cs.Logger.With("wal", walFile), walFile) if err != nil { cs.Logger.Error("failed to open WAL", "file", walFile, "err", err) return nil, err } - if err := wal.Start(); err != nil { + if err := wal.Start(ctx); err != nil { cs.Logger.Error("failed to start WAL", "err", err) return nil, err } @@ -762,7 +766,9 @@ func (cs *State) receiveRoutine(maxSteps int) { // close wal now that we're done writing to it if err := cs.wal.Stop(); err != nil { - cs.Logger.Error("failed trying to stop WAL", "error", err) + if !errors.Is(err, service.ErrAlreadyStopped) { + cs.Logger.Error("failed trying to stop WAL", "error", err) + } } cs.wal.Wait() diff --git a/internal/consensus/state_test.go b/internal/consensus/state_test.go index 1d0741b1f..9216577d5 100644 --- a/internal/consensus/state_test.go +++ b/internal/consensus/state_test.go @@ -56,24 +56,26 @@ x * TestHalt1 - if we see +2/3 precommits after timing out into new round, we sh // ProposeSuite func TestStateProposerSelection0(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() config := configSetup(t) - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) height, round := cs1.Height, cs1.Round - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) // Wait for new round so proposer is set. ensureNewRound(newRoundCh, height, round) // Commit a block and ensure proposer for the next height is correct. prop := cs1.GetRoundState().Validators.GetProposer() - pv, err := cs1.privValidator.GetPubKey(context.Background()) + pv, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) address := pv.Address() if !bytes.Equal(prop.Address, address) { @@ -84,13 +86,21 @@ func TestStateProposerSelection0(t *testing.T) { ensureNewProposal(proposalCh, height, round) rs := cs1.GetRoundState() - signAddVotes(config, cs1, tmproto.PrecommitType, rs.ProposalBlock.Hash(), rs.ProposalBlockParts.Header(), vss[1:]...) + signAddVotes( + ctx, + config, + cs1, + tmproto.PrecommitType, + rs.ProposalBlock.Hash(), + rs.ProposalBlockParts.Header(), + vss[1:]..., + ) // Wait for new round so next validator is set. 
ensureNewRound(newRoundCh, height+1, 0) prop = cs1.GetRoundState().Validators.GetProposer() - pv1, err := vss[1].GetPubKey(context.Background()) + pv1, err := vss[1].GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() if !bytes.Equal(prop.Address, addr) { @@ -102,25 +112,28 @@ func TestStateProposerSelection0(t *testing.T) { func TestStateProposerSelection2(t *testing.T) { config := configSetup(t) - cs1, vss, err := randState(config, log.TestingLogger(), 4) // test needs more work for more than 3 validators + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) // test needs more work for more than 3 validators require.NoError(t, err) height := cs1.Height - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) // this time we jump in at round 2 incrementRound(vss[1:]...) incrementRound(vss[1:]...) var round int32 = 2 - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) // wait for the new round // everyone just votes nil. we get a new proposer each round for i := int32(0); int(i) < len(vss); i++ { prop := cs1.GetRoundState().Validators.GetProposer() - pvk, err := vss[int(i+round)%len(vss)].GetPubKey(context.Background()) + pvk, err := vss[int(i+round)%len(vss)].GetPubKey(ctx) require.NoError(t, err) addr := pvk.Address() correctProposer := addr @@ -132,7 +145,7 @@ func TestStateProposerSelection2(t *testing.T) { } rs := cs1.GetRoundState() - signAddVotes(config, cs1, tmproto.PrecommitType, nil, rs.ProposalBlockParts.Header(), vss[1:]...) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, rs.ProposalBlockParts.Header(), vss[1:]...) ensureNewRound(newRoundCh, height, i+round+1) // wait for the new round event each round incrementRound(vss[1:]...) } @@ -142,16 +155,18 @@ func TestStateProposerSelection2(t *testing.T) { // a non-validator should timeout into the prevote round func TestStateEnterProposeNoPrivValidator(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs, _, err := randState(config, log.TestingLogger(), 1) + cs, _, err := randState(ctx, config, log.TestingLogger(), 1) require.NoError(t, err) cs.SetPrivValidator(nil) height, round := cs.Height, cs.Round // Listen for propose timeout event - timeoutCh := subscribe(t, cs.eventBus, types.EventQueryTimeoutPropose) + timeoutCh := subscribe(ctx, t, cs.eventBus, types.EventQueryTimeoutPropose) - startTestRound(cs, height, round) + startTestRound(ctx, cs, height, round) // if we're not a validator, EnterPropose should timeout ensureNewTimeout(timeoutCh, height, round, cs.config.TimeoutPropose.Nanoseconds()) @@ -164,18 +179,20 @@ func TestStateEnterProposeNoPrivValidator(t *testing.T) { // a validator should not timeout of the prevote round (TODO: unless the block is really big!) 
func TestStateEnterProposeYesPrivValidator(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs, _, err := randState(config, log.TestingLogger(), 1) + cs, _, err := randState(ctx, config, log.TestingLogger(), 1) require.NoError(t, err) height, round := cs.Height, cs.Round // Listen for propose timeout event - timeoutCh := subscribe(t, cs.eventBus, types.EventQueryTimeoutPropose) - proposalCh := subscribe(t, cs.eventBus, types.EventQueryCompleteProposal) + timeoutCh := subscribe(ctx, t, cs.eventBus, types.EventQueryTimeoutPropose) + proposalCh := subscribe(ctx, t, cs.eventBus, types.EventQueryCompleteProposal) cs.enterNewRound(height, round) - cs.startRoutines(3) + cs.startRoutines(ctx, 3) ensureNewProposal(proposalCh, height, round) @@ -197,16 +214,18 @@ func TestStateEnterProposeYesPrivValidator(t *testing.T) { func TestStateBadProposal(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 2) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 2) require.NoError(t, err) height, round := cs1.Height, cs1.Round vs2 := vss[1] partSize := types.BlockPartSizeBytes - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - voteCh := subscribe(t, cs1.eventBus, types.EventQueryVote) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + voteCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryVote) propBlock, _ := cs1.createProposalBlock() // changeProposer(t, cs1, vs2) @@ -225,7 +244,7 @@ func TestStateBadProposal(t *testing.T) { blockID := types.BlockID{Hash: propBlock.Hash(), PartSetHeader: propBlockParts.Header()} proposal := types.NewProposal(vs2.Height, round, -1, blockID) p := proposal.ToProto() - if err := vs2.SignProposal(context.Background(), config.ChainID(), p); err != nil { + if err := vs2.SignProposal(ctx, config.ChainID(), p); err != nil { t.Fatal("failed to sign bad proposal", err) } @@ -237,29 +256,31 @@ func TestStateBadProposal(t *testing.T) { } // start the machine - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) // wait for proposal ensureProposal(proposalCh, height, round, blockID) // wait for prevote ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], nil) + validatePrevote(ctx, t, cs1, round, vss[0], nil) // add bad prevote from vs2 and wait for it - signAddVotes(config, cs1, tmproto.PrevoteType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) ensurePrevote(voteCh, height, round) // wait for precommit ensurePrecommit(voteCh, height, round) - validatePrecommit(t, cs1, round, -1, vss[0], nil, nil) - signAddVotes(config, cs1, tmproto.PrecommitType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) + validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) } func TestStateOversizedBlock(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 2) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 2) require.NoError(t, err) cs1.state.ConsensusParams.Block.MaxBytes = 2000 height, round 
:= cs1.Height, cs1.Round @@ -267,8 +288,8 @@ func TestStateOversizedBlock(t *testing.T) { partSize := types.BlockPartSizeBytes - timeoutProposeCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutPropose) - voteCh := subscribe(t, cs1.eventBus, types.EventQueryVote) + timeoutProposeCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutPropose) + voteCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryVote) propBlock, _ := cs1.createProposalBlock() propBlock.Data.Txs = []types.Tx{tmrand.Bytes(2001)} @@ -282,7 +303,7 @@ func TestStateOversizedBlock(t *testing.T) { blockID := types.BlockID{Hash: propBlock.Hash(), PartSetHeader: propBlockParts.Header()} proposal := types.NewProposal(height, round, -1, blockID) p := proposal.ToProto() - if err := vs2.SignProposal(context.Background(), config.ChainID(), p); err != nil { + if err := vs2.SignProposal(ctx, config.ChainID(), p); err != nil { t.Fatal("failed to sign bad proposal", err) } proposal.Signature = p.Signature @@ -298,7 +319,7 @@ func TestStateOversizedBlock(t *testing.T) { } // start the machine - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) t.Log("Block Sizes", "Limit", cs1.state.ConsensusParams.Block.MaxBytes, "Current", totalBytes) @@ -309,12 +330,12 @@ func TestStateOversizedBlock(t *testing.T) { // and then should send nil prevote and precommit regardless of whether other validators prevote and // precommit on it ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], nil) - signAddVotes(config, cs1, tmproto.PrevoteType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) + validatePrevote(ctx, t, cs1, round, vss[0], nil) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) ensurePrevote(voteCh, height, round) ensurePrecommit(voteCh, height, round) - validatePrecommit(t, cs1, round, -1, vss[0], nil, nil) - signAddVotes(config, cs1, tmproto.PrecommitType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) + validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) } //---------------------------------------------------------------------------------------------------- @@ -324,8 +345,10 @@ func TestStateOversizedBlock(t *testing.T) { func TestStateFullRound1(t *testing.T) { config := configSetup(t) logger := log.TestingLogger() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs, vss, err := randState(config, logger, 1) + cs, vss, err := randState(ctx, config, logger, 1) require.NoError(t, err) height, round := cs.Height, cs.Round @@ -337,16 +360,16 @@ func TestStateFullRound1(t *testing.T) { eventBus := eventbus.NewDefault(logger.With("module", "events")) cs.SetEventBus(eventBus) - if err := eventBus.Start(); err != nil { + if err := eventBus.Start(ctx); err != nil { t.Error(err) } - voteCh := subscribe(t, cs.eventBus, types.EventQueryVote) - propCh := subscribe(t, cs.eventBus, types.EventQueryCompleteProposal) - newRoundCh := subscribe(t, cs.eventBus, types.EventQueryNewRound) + voteCh := subscribe(ctx, t, cs.eventBus, types.EventQueryVote) + propCh := subscribe(ctx, t, cs.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(ctx, t, cs.eventBus, types.EventQueryNewRound) // Maybe it would be better to call explicitly startRoutines(4) - startTestRound(cs, height, round) + startTestRound(ctx, cs, height, round) 
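
Throughout these test hunks the subscribe helper gains a leading context, so every subscription is torn down by the per-test cancel rather than by an explicit stop. A rough, stdlib-only sketch of a context-scoped subscription helper in that spirit; the string-channel event source here is a hypothetical stand-in for the real event bus and query types:

package demo

import (
	"context"
	"testing"
)

// subscribe forwards matching events until ctx is canceled or the source
// channel closes; the returned channel is closed on teardown.
func subscribe(ctx context.Context, t *testing.T, events <-chan string, match func(string) bool) <-chan string {
	t.Helper()
	out := make(chan string, 1)
	go func() {
		defer close(out)
		for {
			select {
			case <-ctx.Done():
				return
			case ev, ok := <-events:
				if !ok {
					return
				}
				if !match(ev) {
					continue
				}
				select {
				case out <- ev:
				case <-ctx.Done():
					return
				}
			}
		}
	}()
	return out
}

func TestSubscribeStopsOnCancel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	events := make(chan string, 1)
	newRoundCh := subscribe(ctx, t, events, func(ev string) bool { return ev == "NewRound" })

	events <- "NewRound"
	if got := <-newRoundCh; got != "NewRound" {
		t.Fatalf("unexpected event %q", got)
	}

	cancel()
	for range newRoundCh {
	} // drains until the goroutine observes cancellation and closes the channel
}
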
ensureNewRound(newRoundCh, height, round) @@ -354,51 +377,55 @@ func TestStateFullRound1(t *testing.T) { propBlockHash := cs.GetRoundState().ProposalBlock.Hash() ensurePrevote(voteCh, height, round) // wait for prevote - validatePrevote(t, cs, round, vss[0], propBlockHash) + validatePrevote(ctx, t, cs, round, vss[0], propBlockHash) ensurePrecommit(voteCh, height, round) // wait for precommit // we're going to roll right into new height ensureNewRound(newRoundCh, height+1, 0) - validateLastPrecommit(t, cs, vss[0], propBlockHash) + validateLastPrecommit(ctx, t, cs, vss[0], propBlockHash) } // nil is proposed, so prevote and precommit nil func TestStateFullRoundNil(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs, vss, err := randState(config, log.TestingLogger(), 1) + cs, vss, err := randState(ctx, config, log.TestingLogger(), 1) require.NoError(t, err) height, round := cs.Height, cs.Round - voteCh := subscribe(t, cs.eventBus, types.EventQueryVote) + voteCh := subscribe(ctx, t, cs.eventBus, types.EventQueryVote) cs.enterPrevote(height, round) - cs.startRoutines(4) + cs.startRoutines(ctx, 4) ensurePrevote(voteCh, height, round) // prevote ensurePrecommit(voteCh, height, round) // precommit // should prevote and precommit nil - validatePrevoteAndPrecommit(t, cs, round, -1, vss[0], nil, nil) + validatePrevoteAndPrecommit(ctx, t, cs, round, -1, vss[0], nil, nil) } // run through propose, prevote, precommit commit with two validators // where the first validator has to wait for votes from the second func TestStateFullRound2(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 2) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 2) require.NoError(t, err) vs2 := vss[1] height, round := cs1.Height, cs1.Round - voteCh := subscribe(t, cs1.eventBus, types.EventQueryVote) - newBlockCh := subscribe(t, cs1.eventBus, types.EventQueryNewBlock) + voteCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryVote) + newBlockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewBlock) // start round and wait for propose and prevote - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensurePrevote(voteCh, height, round) // prevote @@ -407,17 +434,17 @@ func TestStateFullRound2(t *testing.T) { propBlockHash, propPartSetHeader := rs.ProposalBlock.Hash(), rs.ProposalBlockParts.Header() // prevote arrives from vs2: - signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash, propPartSetHeader, vs2) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, propBlockHash, propPartSetHeader, vs2) ensurePrevote(voteCh, height, round) // prevote ensurePrecommit(voteCh, height, round) // precommit // the proposed block should now be locked and our precommit added - validatePrecommit(t, cs1, 0, 0, vss[0], propBlockHash, propBlockHash) + validatePrecommit(ctx, t, cs1, 0, 0, vss[0], propBlockHash, propBlockHash) // we should be stuck in limbo waiting for more precommits // precommit arrives from vs2: - signAddVotes(config, cs1, tmproto.PrecommitType, propBlockHash, propPartSetHeader, vs2) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, propBlockHash, propPartSetHeader, vs2) ensurePrecommit(voteCh, height, round) // wait to finish commit, propose in next height @@ -431,19 +458,21 @@ func TestStateFullRound2(t *testing.T) { // two vals take turns proposing. 
val1 locks on first one, precommits nil on everything else func TestStateLockNoPOL(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 2) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 2) require.NoError(t, err) vs2 := vss[1] height, round := cs1.Height, cs1.Round partSize := types.BlockPartSizeBytes - timeoutProposeCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutPropose) - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - voteCh := subscribe(t, cs1.eventBus, types.EventQueryVote) - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) + timeoutProposeCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutPropose) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + voteCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryVote) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) /* Round1 (cs1, B) // B B // B B2 @@ -451,7 +480,7 @@ func TestStateLockNoPOL(t *testing.T) { // start round and wait for prevote cs1.enterNewRound(height, round) - cs1.startRoutines(0) + cs1.startRoutines(ctx, 0) ensureNewRound(newRoundCh, height, round) @@ -464,19 +493,19 @@ func TestStateLockNoPOL(t *testing.T) { // we should now be stuck in limbo forever, waiting for more prevotes // prevote arrives from vs2: - signAddVotes(config, cs1, tmproto.PrevoteType, theBlockHash, thePartSetHeader, vs2) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, theBlockHash, thePartSetHeader, vs2) ensurePrevote(voteCh, height, round) // prevote ensurePrecommit(voteCh, height, round) // precommit // the proposed block should now be locked and our precommit added - validatePrecommit(t, cs1, round, round, vss[0], theBlockHash, theBlockHash) + validatePrecommit(ctx, t, cs1, round, round, vss[0], theBlockHash, theBlockHash) // we should now be stuck in limbo forever, waiting for more precommits // lets add one for a different block hash := make([]byte, len(theBlockHash)) copy(hash, theBlockHash) hash[0] = (hash[0] + 1) % 255 - signAddVotes(config, cs1, tmproto.PrecommitType, hash, thePartSetHeader, vs2) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, hash, thePartSetHeader, vs2) ensurePrecommit(voteCh, height, round) // precommit // (note we're entering precommit for a second time this round) @@ -506,10 +535,10 @@ func TestStateLockNoPOL(t *testing.T) { // wait to finish prevote ensurePrevote(voteCh, height, round) // we should have prevoted our locked block - validatePrevote(t, cs1, round, vss[0], rs.LockedBlock.Hash()) + validatePrevote(ctx, t, cs1, round, vss[0], rs.LockedBlock.Hash()) // add a conflicting prevote from the other validator - signAddVotes(config, cs1, tmproto.PrevoteType, hash, rs.LockedBlock.MakePartSet(partSize).Header(), vs2) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, hash, rs.LockedBlock.MakePartSet(partSize).Header(), vs2) ensurePrevote(voteCh, height, round) // now we're going to enter prevote again, but with invalid args @@ -519,10 +548,10 @@ func TestStateLockNoPOL(t *testing.T) { ensurePrecommit(voteCh, height, round) // precommit // the proposed block should still be locked and our precommit added // we should precommit nil and be locked on the proposal - validatePrecommit(t, cs1, round, 0, 
vss[0], nil, theBlockHash) + validatePrecommit(ctx, t, cs1, round, 0, vss[0], nil, theBlockHash) // add conflicting precommit from vs2 - signAddVotes(config, cs1, tmproto.PrecommitType, hash, rs.LockedBlock.MakePartSet(partSize).Header(), vs2) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, hash, rs.LockedBlock.MakePartSet(partSize).Header(), vs2) ensurePrecommit(voteCh, height, round) // (note we're entering precommit for a second time this round, but with invalid args @@ -550,17 +579,18 @@ func TestStateLockNoPOL(t *testing.T) { } ensurePrevote(voteCh, height, round) // prevote - validatePrevote(t, cs1, round, vss[0], rs.LockedBlock.Hash()) + validatePrevote(ctx, t, cs1, round, vss[0], rs.LockedBlock.Hash()) - signAddVotes(config, cs1, tmproto.PrevoteType, hash, rs.ProposalBlock.MakePartSet(partSize).Header(), vs2) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, hash, rs.ProposalBlock.MakePartSet(partSize).Header(), vs2) ensurePrevote(voteCh, height, round) ensureNewTimeout(timeoutWaitCh, height, round, cs1.config.Prevote(round).Nanoseconds()) ensurePrecommit(voteCh, height, round) // precommit - validatePrecommit(t, cs1, round, 0, vss[0], nil, theBlockHash) // precommit nil but be locked on proposal + validatePrecommit(ctx, t, cs1, round, 0, vss[0], nil, theBlockHash) // precommit nil but be locked on proposal signAddVotes( + ctx, config, cs1, tmproto.PrecommitType, @@ -571,10 +601,11 @@ func TestStateLockNoPOL(t *testing.T) { ensureNewTimeout(timeoutWaitCh, height, round, cs1.config.Precommit(round).Nanoseconds()) - cs2, _, err := randState(config, log.TestingLogger(), 2) // needed so generated block is different than locked block + // needed so generated block is different than locked block + cs2, _, err := randState(ctx, config, log.TestingLogger(), 2) require.NoError(t, err) // before we time out into new round, set next proposal block - prop, propBlock := decideProposal(cs2, vs2, vs2.Height, vs2.Round+1) + prop, propBlock := decideProposal(ctx, cs2, vs2, vs2.Height, vs2.Round+1) if prop == nil || propBlock == nil { t.Fatal("Failed to create proposal block with vs2") } @@ -597,17 +628,18 @@ func TestStateLockNoPOL(t *testing.T) { ensureNewProposal(proposalCh, height, round) ensurePrevote(voteCh, height, round) // prevote // prevote for locked block (not proposal) - validatePrevote(t, cs1, 3, vss[0], cs1.LockedBlock.Hash()) + validatePrevote(ctx, t, cs1, 3, vss[0], cs1.LockedBlock.Hash()) // prevote for proposed block - signAddVotes(config, cs1, tmproto.PrevoteType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) ensurePrevote(voteCh, height, round) ensureNewTimeout(timeoutWaitCh, height, round, cs1.config.Prevote(round).Nanoseconds()) ensurePrecommit(voteCh, height, round) - validatePrecommit(t, cs1, round, 0, vss[0], nil, theBlockHash) // precommit nil but locked on proposal + validatePrecommit(ctx, t, cs1, round, 0, vss[0], nil, theBlockHash) // precommit nil but locked on proposal signAddVotes( + ctx, config, cs1, tmproto.PrecommitType, @@ -625,21 +657,24 @@ func TestStateLockPOLRelock(t *testing.T) { config := configSetup(t) logger := log.TestingLogger() - cs1, vss, err := randState(config, logger, 4) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cs1, vss, err := randState(ctx, config, logger, 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round 
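
The signVotes/signAddVotes helpers in these tests likewise take the test context first, presumably because the signing path they call (the same privval methods that switch from context.Background() to ctx elsewhere in this diff, such as SignProposal and GetPubKey) is context-aware. A small illustrative sketch of threading ctx through a batch-signing helper; vote, signer, and signVotes here are stand-ins, not the consensus or privval types:

package demo

import (
	"context"
	"fmt"
)

type vote struct {
	Height    int64
	Round     int32
	BlockHash []byte
	Signature []byte
}

type signer interface {
	SignVote(ctx context.Context, v *vote) error
}

// signVotes signs one vote per signer, stopping early if ctx is canceled.
func signVotes(ctx context.Context, height int64, round int32, hash []byte, signers ...signer) ([]*vote, error) {
	votes := make([]*vote, 0, len(signers))
	for _, s := range signers {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		v := &vote{Height: height, Round: round, BlockHash: hash}
		if err := s.SignVote(ctx, v); err != nil {
			return nil, fmt.Errorf("sign vote: %w", err)
		}
		votes = append(votes, v)
	}
	return votes, nil
}

// nopSigner shows the call shape; a real signer would produce a signature
// over the canonical vote bytes.
type nopSigner struct{}

func (nopSigner) SignVote(ctx context.Context, v *vote) error {
	if err := ctx.Err(); err != nil {
		return err
	}
	v.Signature = []byte("stub")
	return nil
}
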
partSize := types.BlockPartSizeBytes - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - newBlockCh := subscribe(t, cs1.eventBus, types.EventQueryNewBlockHeader) + voteCh := subscribeToVoter(ctx, t, cs1, addr) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + newBlockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewBlockHeader) // everything done from perspective of cs1 @@ -650,7 +685,7 @@ func TestStateLockPOLRelock(t *testing.T) { */ // start round and wait for propose and prevote - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) ensureNewProposal(proposalCh, height, round) @@ -660,20 +695,20 @@ func TestStateLockPOLRelock(t *testing.T) { ensurePrevote(voteCh, height, round) // prevote - signAddVotes(config, cs1, tmproto.PrevoteType, theBlockHash, theBlockParts, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, theBlockHash, theBlockParts, vs2, vs3, vs4) ensurePrecommit(voteCh, height, round) // our precommit // the proposed block should now be locked and our precommit added - validatePrecommit(t, cs1, round, round, vss[0], theBlockHash, theBlockHash) + validatePrecommit(ctx, t, cs1, round, round, vss[0], theBlockHash, theBlockHash) // add precommits from the rest - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) // before we timeout to the new round set the new proposal - cs2, err := newState(logger, cs1.state, vs2, kvstore.NewApplication()) + cs2, err := newState(ctx, logger, cs1.state, vs2, kvstore.NewApplication()) require.NoError(t, err) - prop, propBlock := decideProposal(cs2, vs2, vs2.Height, vs2.Round+1) + prop, propBlock := decideProposal(ctx, cs2, vs2, vs2.Height, vs2.Round+1) if prop == nil || propBlock == nil { t.Fatal("Failed to create proposal block with vs2") } @@ -707,17 +742,17 @@ func TestStateLockPOLRelock(t *testing.T) { // go to prevote, node should prevote for locked block (not the new proposal) - this is relocking ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], theBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], theBlockHash) // now lets add prevotes from everyone else for the new block - signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) ensurePrecommit(voteCh, height, round) // we should have unlocked and locked on the new block, sending a precommit for this new block - validatePrecommit(t, cs1, round, round, vss[0], propBlockHash, propBlockHash) + validatePrecommit(ctx, t, cs1, round, round, vss[0], propBlockHash, propBlockHash) // more prevote creating a majority on the new block and this is then committed - signAddVotes(config, cs1, tmproto.PrecommitType, propBlockHash, 
propBlockParts.Header(), vs2, vs3) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, propBlockHash, propBlockParts.Header(), vs2, vs3) ensureNewBlockHeader(newBlockCh, height, propBlockHash) ensureNewRound(newRoundCh, height+1, 0) @@ -726,22 +761,24 @@ func TestStateLockPOLRelock(t *testing.T) { // 4 vals, one precommits, other 3 polka at next round, so we unlock and precomit the polka func TestStateLockPOLUnlock(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round partSize := types.BlockPartSizeBytes - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - unlockCh := subscribe(t, cs1.eventBus, types.EventQueryUnlock) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + unlockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryUnlock) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) + voteCh := subscribeToVoter(ctx, t, cs1, addr) // everything done from perspective of cs1 @@ -751,7 +788,7 @@ func TestStateLockPOLUnlock(t *testing.T) { */ // start round and wait for propose and prevote - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) ensureNewProposal(proposalCh, height, round) @@ -760,20 +797,20 @@ func TestStateLockPOLUnlock(t *testing.T) { theBlockParts := rs.ProposalBlockParts.Header() ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], theBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], theBlockHash) - signAddVotes(config, cs1, tmproto.PrevoteType, theBlockHash, theBlockParts, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, theBlockHash, theBlockParts, vs2, vs3, vs4) ensurePrecommit(voteCh, height, round) // the proposed block should now be locked and our precommit added - validatePrecommit(t, cs1, round, round, vss[0], theBlockHash, theBlockHash) + validatePrecommit(ctx, t, cs1, round, round, vss[0], theBlockHash, theBlockHash) // add precommits from the rest - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs4) - signAddVotes(config, cs1, tmproto.PrecommitType, theBlockHash, theBlockParts, vs3) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, theBlockHash, theBlockParts, vs3) // before we time out into new round, set next proposal block - prop, propBlock := decideProposal(cs1, vs2, vs2.Height, vs2.Round+1) + prop, propBlock := decideProposal(ctx, cs1, vs2, vs2.Height, vs2.Round+1) propBlockParts := propBlock.MakePartSet(partSize) // timeout to new round @@ -799,9 +836,9 @@ func TestStateLockPOLUnlock(t *testing.T) { // go to prevote, prevote for locked block (not proposal) ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, 
round, vss[0], lockedBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], lockedBlockHash) // now lets add prevotes from everyone else for nil (a polka!) - signAddVotes(config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs2, vs3, vs4) // the polka makes us unlock and precommit nil ensureNewUnlock(unlockCh, height, round) @@ -809,9 +846,9 @@ func TestStateLockPOLUnlock(t *testing.T) { // we should have unlocked and committed nil // NOTE: since we don't relock on nil, the lock round is -1 - validatePrecommit(t, cs1, round, -1, vss[0], nil, nil) + validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil) - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3) ensureNewRound(newRoundCh, height, round+1) } @@ -822,21 +859,23 @@ func TestStateLockPOLUnlock(t *testing.T) { func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) { config := configSetup(t) logger := log.TestingLogger() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, logger, 4) + cs1, vss, err := randState(ctx, config, logger, 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round partSize := types.BlockPartSizeBytes - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) + voteCh := subscribeToVoter(ctx, t, cs1, addr) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) // everything done from perspective of cs1 /* @@ -844,7 +883,7 @@ func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) { */ // start round and wait for propose and prevote - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) ensureNewProposal(proposalCh, height, round) @@ -854,19 +893,19 @@ func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) { ensurePrevote(voteCh, height, round) // prevote - signAddVotes(config, cs1, tmproto.PrevoteType, firstBlockHash, firstBlockParts, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, firstBlockHash, firstBlockParts, vs2, vs3, vs4) ensurePrecommit(voteCh, height, round) // our precommit // the proposed block should now be locked and our precommit added - validatePrecommit(t, cs1, round, round, vss[0], firstBlockHash, firstBlockHash) + validatePrecommit(ctx, t, cs1, round, round, vss[0], firstBlockHash, firstBlockHash) // add precommits from the rest - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) // before we timeout to the new round set the new proposal - cs2, err := newState(logger, cs1.state, vs2, kvstore.NewApplication()) + cs2, err := newState(ctx, logger, cs1.state, vs2, kvstore.NewApplication()) 
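
The same motivation drives the mechanical context.Background() to ctx replacements at call sites such as privValidator.GetPubKey(ctx) and the newState(ctx, ...) / decideProposal(ctx, ...) fixtures nearby: a call bounded by the test's context can be canceled or timed out instead of hanging the test run. A hedged sketch of that call-site change, where getPubKey is a stand-in for a slow signer call and not the privval API:

package demo

import (
	"context"
	"errors"
	"time"
)

// getPubKey simulates a signer call that may be slow; it honors cancellation.
func getPubKey(ctx context.Context) ([]byte, error) {
	select {
	case <-time.After(50 * time.Millisecond): // stand-in for a slow remote signer
		return []byte{0x01}, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// fetchKeyWithDeadline bounds the call with the caller's context; with a
// bare context.Background() the same call could block a test forever.
func fetchKeyWithDeadline(parent context.Context) ([]byte, error) {
	ctx, cancel := context.WithTimeout(parent, time.Second)
	defer cancel()

	key, err := getPubKey(ctx)
	if errors.Is(err, context.DeadlineExceeded) {
		return nil, errors.New("signer timed out")
	}
	return key, err
}
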
require.NoError(t, err) - prop, propBlock := decideProposal(cs2, vs2, vs2.Height, vs2.Round+1) + prop, propBlock := decideProposal(ctx, cs2, vs2, vs2.Height, vs2.Round+1) if prop == nil || propBlock == nil { t.Fatal("Failed to create proposal block with vs2") } @@ -892,26 +931,26 @@ func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) { // go to prevote, node should prevote for locked block (not the new proposal) - this is relocking ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], firstBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], firstBlockHash) // now lets add prevotes from everyone else for the new block - signAddVotes(config, cs1, tmproto.PrevoteType, secondBlockHash, secondBlockParts.Header(), vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, secondBlockHash, secondBlockParts.Header(), vs2, vs3, vs4) ensurePrecommit(voteCh, height, round) // we should have unlocked and locked on the new block, sending a precommit for this new block - validatePrecommit(t, cs1, round, -1, vss[0], nil, nil) + validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil) if err := cs1.SetProposalAndBlock(prop, propBlock, secondBlockParts, "some peer"); err != nil { t.Fatal(err) } // more prevote creating a majority on the new block and this is then committed - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) // before we timeout to the new round set the new proposal - cs3, err := newState(logger, cs1.state, vs3, kvstore.NewApplication()) + cs3, err := newState(ctx, logger, cs1.state, vs3, kvstore.NewApplication()) require.NoError(t, err) - prop, propBlock = decideProposal(cs3, vs3, vs3.Height, vs3.Round+1) + prop, propBlock = decideProposal(ctx, cs3, vs3, vs3.Height, vs3.Round+1) if prop == nil || propBlock == nil { t.Fatal("Failed to create proposal block with vs2") } @@ -938,13 +977,13 @@ func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) { ensurePrevote(voteCh, height, round) // we are no longer locked to the first block so we should be able to prevote - validatePrevote(t, cs1, round, vss[0], thirdPropBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], thirdPropBlockHash) - signAddVotes(config, cs1, tmproto.PrevoteType, thirdPropBlockHash, thirdPropBlockParts.Header(), vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, thirdPropBlockHash, thirdPropBlockParts.Header(), vs2, vs3, vs4) ensurePrecommit(voteCh, height, round) // we have a majority, now vs1 can change lock to the third block - validatePrecommit(t, cs1, round, round, vss[0], thirdPropBlockHash, thirdPropBlockHash) + validatePrecommit(ctx, t, cs1, round, round, vss[0], thirdPropBlockHash, thirdPropBlockHash) } // 4 vals @@ -953,25 +992,27 @@ func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) { // then we see the polka from round 1 but shouldn't unlock func TestStateLockPOLSafety1(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round partSize := types.BlockPartSizeBytes - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - timeoutProposeCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutPropose) - 
timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + timeoutProposeCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutPropose) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) + voteCh := subscribeToVoter(ctx, t, cs1, addr) // start round and wait for propose and prevote - startTestRound(cs1, cs1.Height, round) + startTestRound(ctx, cs1, cs1.Height, round) ensureNewRound(newRoundCh, height, round) ensureNewProposal(proposalCh, height, round) @@ -979,17 +1020,17 @@ func TestStateLockPOLSafety1(t *testing.T) { propBlock := rs.ProposalBlock ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlock.Hash()) + validatePrevote(ctx, t, cs1, round, vss[0], propBlock.Hash()) // the others sign a polka but we don't see it - prevotes := signVotes(config, tmproto.PrevoteType, + prevotes := signVotes(ctx, config, tmproto.PrevoteType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2, vs3, vs4) t.Logf("old prop hash %v", fmt.Sprintf("%X", propBlock.Hash())) // we do see them precommit nil - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) // cs1 precommit nil ensurePrecommit(voteCh, height, round) @@ -997,7 +1038,7 @@ func TestStateLockPOLSafety1(t *testing.T) { t.Log("### ONTO ROUND 1") - prop, propBlock := decideProposal(cs1, vs2, vs2.Height, vs2.Round+1) + prop, propBlock := decideProposal(ctx, cs1, vs2, vs2.Height, vs2.Round+1) propBlockHash := propBlock.Hash() propBlockParts := propBlock.MakePartSet(partSize) @@ -1026,16 +1067,16 @@ func TestStateLockPOLSafety1(t *testing.T) { // go to prevote, prevote for proposal block ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], propBlockHash) // now we see the others prevote for it, so we should lock on it - signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) ensurePrecommit(voteCh, height, round) // we should have precommitted - validatePrecommit(t, cs1, round, round, vss[0], propBlockHash, propBlockHash) + validatePrecommit(ctx, t, cs1, round, round, vss[0], propBlockHash, propBlockHash) - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) ensureNewTimeout(timeoutWaitCh, height, round, cs1.config.Precommit(round).Nanoseconds()) @@ -1055,9 +1096,9 @@ func TestStateLockPOLSafety1(t *testing.T) { // finish prevote ensurePrevote(voteCh, height, round) // we should prevote what we're locked on - validatePrevote(t, cs1, round, vss[0], propBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], propBlockHash) - newStepCh := subscribe(t, cs1.eventBus, types.EventQueryNewRoundStep) + newStepCh := subscribe(ctx, t, 
cs1.eventBus, types.EventQueryNewRoundStep) // before prevotes from the previous round are added // add prevotes from the earlier round @@ -1077,35 +1118,37 @@ func TestStateLockPOLSafety1(t *testing.T) { // dont see P0, lock on P1 at R1, dont unlock using P0 at R2 func TestStateLockPOLSafety2(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round partSize := types.BlockPartSizeBytes - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - unlockCh := subscribe(t, cs1.eventBus, types.EventQueryUnlock) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + unlockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryUnlock) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) + voteCh := subscribeToVoter(ctx, t, cs1, addr) // the block for R0: gets polkad but we miss it // (even though we signed it, shhh) - _, propBlock0 := decideProposal(cs1, vss[0], height, round) + _, propBlock0 := decideProposal(ctx, cs1, vss[0], height, round) propBlockHash0 := propBlock0.Hash() propBlockParts0 := propBlock0.MakePartSet(partSize) propBlockID0 := types.BlockID{Hash: propBlockHash0, PartSetHeader: propBlockParts0.Header()} // the others sign a polka but we don't see it - prevotes := signVotes(config, tmproto.PrevoteType, propBlockHash0, propBlockParts0.Header(), vs2, vs3, vs4) + prevotes := signVotes(ctx, config, tmproto.PrevoteType, propBlockHash0, propBlockParts0.Header(), vs2, vs3, vs4) // the block for round 1 - prop1, propBlock1 := decideProposal(cs1, vs2, vs2.Height, vs2.Round+1) + prop1, propBlock1 := decideProposal(ctx, cs1, vs2, vs2.Height, vs2.Round+1) propBlockHash1 := propBlock1.Hash() propBlockParts1 := propBlock1.MakePartSet(partSize) @@ -1114,7 +1157,7 @@ func TestStateLockPOLSafety2(t *testing.T) { round++ // moving to the next round t.Log("### ONTO Round 1") // jump in at round 1 - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) if err := cs1.SetProposalAndBlock(prop1, propBlock1, propBlockParts1, "some peer"); err != nil { @@ -1123,17 +1166,17 @@ func TestStateLockPOLSafety2(t *testing.T) { ensureNewProposal(proposalCh, height, round) ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash1) + validatePrevote(ctx, t, cs1, round, vss[0], propBlockHash1) - signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash1, propBlockParts1.Header(), vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, propBlockHash1, propBlockParts1.Header(), vs2, vs3, vs4) ensurePrecommit(voteCh, height, round) // the proposed block should now be locked and our precommit added - validatePrecommit(t, cs1, round, round, vss[0], propBlockHash1, propBlockHash1) + validatePrecommit(ctx, t, cs1, round, round, vss[0], 
propBlockHash1, propBlockHash1) // add precommits from the rest - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs4) - signAddVotes(config, cs1, tmproto.PrecommitType, propBlockHash1, propBlockParts1.Header(), vs3) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, propBlockHash1, propBlockParts1.Header(), vs3) incrementRound(vs2, vs3, vs4) @@ -1144,7 +1187,7 @@ func TestStateLockPOLSafety2(t *testing.T) { // in round 2 we see the polkad block from round 0 newProp := types.NewProposal(height, round, 0, propBlockID0) p := newProp.ToProto() - if err := vs3.SignProposal(context.Background(), config.ChainID(), p); err != nil { + if err := vs3.SignProposal(ctx, config.ChainID(), p); err != nil { t.Fatal(err) } @@ -1166,7 +1209,7 @@ func TestStateLockPOLSafety2(t *testing.T) { ensureNoNewUnlock(unlockCh) ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash1) + validatePrevote(ctx, t, cs1, round, vss[0], propBlockHash1) } @@ -1177,26 +1220,28 @@ func TestStateLockPOLSafety2(t *testing.T) { // P0 proposes B0 at R3. func TestProposeValidBlock(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round partSize := types.BlockPartSizeBytes - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - timeoutProposeCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutPropose) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - unlockCh := subscribe(t, cs1.eventBus, types.EventQueryUnlock) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + timeoutProposeCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutPropose) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + unlockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryUnlock) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) + voteCh := subscribeToVoter(ctx, t, cs1, addr) // start round and wait for propose and prevote - startTestRound(cs1, cs1.Height, round) + startTestRound(ctx, cs1, cs1.Height, round) ensureNewRound(newRoundCh, height, round) ensureNewProposal(proposalCh, height, round) @@ -1205,16 +1250,18 @@ func TestProposeValidBlock(t *testing.T) { propBlockHash := propBlock.Hash() ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], propBlockHash) // the others sign a polka - signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash, propBlock.MakePartSet(partSize).Header(), vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, + propBlockHash, propBlock.MakePartSet(partSize).Header(), vs2, + vs3, vs4) ensurePrecommit(voteCh, height, round) // we should have precommitted - validatePrecommit(t, cs1, round, round, vss[0], propBlockHash, propBlockHash) + validatePrecommit(ctx, t, cs1, 
round, round, vss[0], propBlockHash, propBlockHash) - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) ensureNewTimeout(timeoutWaitCh, height, round, cs1.config.Precommit(round).Nanoseconds()) @@ -1229,20 +1276,20 @@ func TestProposeValidBlock(t *testing.T) { ensureNewTimeout(timeoutProposeCh, height, round, cs1.config.Propose(round).Nanoseconds()) ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], propBlockHash) - signAddVotes(config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs2, vs3, vs4) ensureNewUnlock(unlockCh, height, round) ensurePrecommit(voteCh, height, round) // we should have precommitted - validatePrecommit(t, cs1, round, -1, vss[0], nil, nil) + validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil) incrementRound(vs2, vs3, vs4) incrementRound(vs2, vs3, vs4) - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) round += 2 // moving to the next round @@ -1270,25 +1317,27 @@ func TestProposeValidBlock(t *testing.T) { // P0 miss to lock B but set valid block to B after receiving delayed prevote. func TestSetValidBlockOnDelayedPrevote(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round partSize := types.BlockPartSizeBytes - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - validBlockCh := subscribe(t, cs1.eventBus, types.EventQueryValidBlock) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + validBlockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryValidBlock) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) + voteCh := subscribeToVoter(ctx, t, cs1, addr) // start round and wait for propose and prevote - startTestRound(cs1, cs1.Height, round) + startTestRound(ctx, cs1, cs1.Height, round) ensureNewRound(newRoundCh, height, round) ensureNewProposal(proposalCh, height, round) @@ -1298,19 +1347,19 @@ func TestSetValidBlockOnDelayedPrevote(t *testing.T) { propBlockParts := propBlock.MakePartSet(partSize) ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], propBlockHash) // vs2 send prevote for propBlock - signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs2) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs2) // vs3 send prevote nil - 
signAddVotes(config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs3) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs3) ensureNewTimeout(timeoutWaitCh, height, round, cs1.config.Prevote(round).Nanoseconds()) ensurePrecommit(voteCh, height, round) // we should have precommitted - validatePrecommit(t, cs1, round, -1, vss[0], nil, nil) + validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil) rs = cs1.GetRoundState() @@ -1319,7 +1368,7 @@ func TestSetValidBlockOnDelayedPrevote(t *testing.T) { assert.True(t, rs.ValidRound == -1) // vs2 send (delayed) prevote for propBlock - signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs4) ensureNewValidBlock(validBlockCh, height, round) @@ -1335,47 +1384,49 @@ func TestSetValidBlockOnDelayedPrevote(t *testing.T) { // receiving delayed Block Proposal. func TestSetValidBlockOnDelayedProposal(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round partSize := types.BlockPartSizeBytes - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - timeoutProposeCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutPropose) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - validBlockCh := subscribe(t, cs1.eventBus, types.EventQueryValidBlock) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + timeoutProposeCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutPropose) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + validBlockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryValidBlock) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) + voteCh := subscribeToVoter(ctx, t, cs1, addr) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) round++ // move to round in which P0 is not proposer incrementRound(vs2, vs3, vs4) - startTestRound(cs1, cs1.Height, round) + startTestRound(ctx, cs1, cs1.Height, round) ensureNewRound(newRoundCh, height, round) ensureNewTimeout(timeoutProposeCh, height, round, cs1.config.Propose(round).Nanoseconds()) ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], nil) + validatePrevote(ctx, t, cs1, round, vss[0], nil) - prop, propBlock := decideProposal(cs1, vs2, vs2.Height, vs2.Round+1) + prop, propBlock := decideProposal(ctx, cs1, vs2, vs2.Height, vs2.Round+1) propBlockHash := propBlock.Hash() propBlockParts := propBlock.MakePartSet(partSize) // vs2, vs3 and vs4 send prevote for propBlock - signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) ensureNewValidBlock(validBlockCh, height, round) ensureNewTimeout(timeoutWaitCh, height, round, cs1.config.Prevote(round).Nanoseconds()) ensurePrecommit(voteCh, height, round) - 
validatePrecommit(t, cs1, round, -1, vss[0], nil, nil) + validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil) if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil { t.Fatal(err) @@ -1395,19 +1446,22 @@ func TestSetValidBlockOnDelayedProposal(t *testing.T) { func TestWaitingTimeoutOnNilPolka(t *testing.T) { config := configSetup(t) - cs1, vss, err := randState(config, log.TestingLogger(), 4) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) // start round - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) ensureNewTimeout(timeoutWaitCh, height, round, cs1.config.Precommit(round).Nanoseconds()) ensureNewRound(newRoundCh, height, round+1) @@ -1418,27 +1472,29 @@ func TestWaitingTimeoutOnNilPolka(t *testing.T) { // P0 waits for timeoutPropose in the next round before entering prevote func TestWaitingTimeoutProposeOnNewRound(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutPropose) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutPropose) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) + voteCh := subscribeToVoter(ctx, t, cs1, addr) // start round - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) ensurePrevote(voteCh, height, round) incrementRound(vss[1:]...) - signAddVotes(config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs2, vs3, vs4) round++ // moving to the next round ensureNewRound(newRoundCh, height, round) @@ -1449,7 +1505,7 @@ func TestWaitingTimeoutProposeOnNewRound(t *testing.T) { ensureNewTimeout(timeoutWaitCh, height, round, cs1.config.Propose(round).Nanoseconds()) ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], nil) + validatePrevote(ctx, t, cs1, round, vss[0], nil) } // 4 vals, 3 Precommits for nil from the higher round. 
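// ----------------------------------------------------------------------------
// Editor's note -- illustrative sketch, not part of the patch. The consensus
// test hunks above all converge on the same shape: each test owns a context
// and threads it through this package's helpers (randState, subscribe,
// startTestRound, signAddVotes, ...) instead of reaching for a shared
// context.Background(). Condensed, and assuming it sits in this package next
// to those helpers and their imports; the test name is hypothetical:

func TestContextThreadingShape(t *testing.T) {
	config := configSetup(t)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // cancelling the context, not Stop(), tears the machinery down

	cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4)
	require.NoError(t, err)
	vs2, vs3, vs4 := vss[1], vss[2], vss[3]
	height, round := cs1.Height, cs1.Round

	// event subscriptions are bound to the test's context as well
	newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound)

	startTestRound(ctx, cs1, height, round)
	ensureNewRound(newRoundCh, height, round)

	// signing helpers take the context so vote signing is cancellable too
	signAddVotes(ctx, config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs2, vs3, vs4)
}

// (end of editor's sketch)
// ----------------------------------------------------------------------------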
@@ -1457,33 +1513,35 @@ func TestWaitingTimeoutProposeOnNewRound(t *testing.T) { // P0 jump to higher round, precommit and start precommit wait func TestRoundSkipOnNilPolkaFromHigherRound(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) + voteCh := subscribeToVoter(ctx, t, cs1, addr) // start round - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) ensurePrevote(voteCh, height, round) incrementRound(vss[1:]...) - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2, vs3, vs4) round++ // moving to the next round ensureNewRound(newRoundCh, height, round) ensurePrecommit(voteCh, height, round) - validatePrecommit(t, cs1, round, -1, vss[0], nil, nil) + validatePrecommit(ctx, t, cs1, round, -1, vss[0], nil, nil) ensureNewTimeout(timeoutWaitCh, height, round, cs1.config.Precommit(round).Nanoseconds()) @@ -1496,38 +1554,42 @@ func TestRoundSkipOnNilPolkaFromHigherRound(t *testing.T) { // P0 wait for timeoutPropose to expire before sending prevote. func TestWaitTimeoutProposeOnNilPolkaForTheCurrentRound(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, int32(1) - timeoutProposeCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutPropose) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + timeoutProposeCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutPropose) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) + voteCh := subscribeToVoter(ctx, t, cs1, addr) // start round in which PO is not proposer - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) incrementRound(vss[1:]...) 
- signAddVotes(config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs2, vs3, vs4) ensureNewTimeout(timeoutProposeCh, height, round, cs1.config.Propose(round).Nanoseconds()) ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], nil) + validatePrevote(ctx, t, cs1, round, vss[0], nil) } // What we want: // P0 emit NewValidBlock event upon receiving 2/3+ Precommit for B but hasn't received block B yet func TestEmitNewValidBlockEventOnCommitWithoutBlock(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, int32(1) @@ -1536,19 +1598,19 @@ func TestEmitNewValidBlockEventOnCommitWithoutBlock(t *testing.T) { partSize := types.BlockPartSizeBytes - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - validBlockCh := subscribe(t, cs1.eventBus, types.EventQueryValidBlock) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + validBlockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryValidBlock) - _, propBlock := decideProposal(cs1, vs2, vs2.Height, vs2.Round) + _, propBlock := decideProposal(ctx, cs1, vs2, vs2.Height, vs2.Round) propBlockHash := propBlock.Hash() propBlockParts := propBlock.MakePartSet(partSize) // start round in which PO is not proposer - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) // vs2, vs3 and vs4 send precommit for propBlock - signAddVotes(config, cs1, tmproto.PrecommitType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) ensureNewValidBlock(validBlockCh, height, round) rs := cs1.GetRoundState() @@ -1563,28 +1625,30 @@ func TestEmitNewValidBlockEventOnCommitWithoutBlock(t *testing.T) { // After receiving block, it executes block and moves to the next height. 
func TestCommitFromPreviousRound(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, int32(1) partSize := types.BlockPartSizeBytes - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - validBlockCh := subscribe(t, cs1.eventBus, types.EventQueryValidBlock) - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + validBlockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryValidBlock) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - prop, propBlock := decideProposal(cs1, vs2, vs2.Height, vs2.Round) + prop, propBlock := decideProposal(ctx, cs1, vs2, vs2.Height, vs2.Round) propBlockHash := propBlock.Hash() propBlockParts := propBlock.MakePartSet(partSize) // start round in which PO is not proposer - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) // vs2, vs3 and vs4 send precommit for propBlock for the previous round - signAddVotes(config, cs1, tmproto.PrecommitType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) ensureNewValidBlock(validBlockCh, height, round) @@ -1619,28 +1683,30 @@ func (n *fakeTxNotifier) Notify() { // start of the next round func TestStartNextHeightCorrectlyAfterTimeout(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() config.Consensus.SkipTimeoutCommit = false - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) cs1.txNotifier = &fakeTxNotifier{ch: make(chan struct{})} vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - timeoutProposeCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutPropose) - precommitTimeoutCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + timeoutProposeCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutPropose) + precommitTimeoutCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - newBlockHeader := subscribe(t, cs1.eventBus, types.EventQueryNewBlockHeader) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + newBlockHeader := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewBlockHeader) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) + voteCh := subscribeToVoter(ctx, t, cs1, addr) // start round and wait for propose and prevote - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) ensureNewProposal(proposalCh, height, round) @@ -1649,17 +1715,17 @@ func TestStartNextHeightCorrectlyAfterTimeout(t *testing.T) { theBlockParts := 
rs.ProposalBlockParts.Header() ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], theBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], theBlockHash) - signAddVotes(config, cs1, tmproto.PrevoteType, theBlockHash, theBlockParts, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, theBlockHash, theBlockParts, vs2, vs3, vs4) ensurePrecommit(voteCh, height, round) // the proposed block should now be locked and our precommit added - validatePrecommit(t, cs1, round, round, vss[0], theBlockHash, theBlockHash) + validatePrecommit(ctx, t, cs1, round, round, vss[0], theBlockHash, theBlockHash) // add precommits - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2) - signAddVotes(config, cs1, tmproto.PrecommitType, theBlockHash, theBlockParts, vs3) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, theBlockHash, theBlockParts, vs3) // wait till timeout occurs ensurePrecommitTimeout(precommitTimeoutCh) @@ -1667,7 +1733,7 @@ func TestStartNextHeightCorrectlyAfterTimeout(t *testing.T) { ensureNewRound(newRoundCh, height, round+1) // majority is now reached - signAddVotes(config, cs1, tmproto.PrecommitType, theBlockHash, theBlockParts, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, theBlockHash, theBlockParts, vs4) ensureNewBlockHeader(newBlockHeader, height, theBlockHash) @@ -1683,9 +1749,11 @@ func TestStartNextHeightCorrectlyAfterTimeout(t *testing.T) { func TestResetTimeoutPrecommitUponNewHeight(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() config.Consensus.SkipTimeoutCommit = false - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] @@ -1693,17 +1761,17 @@ func TestResetTimeoutPrecommitUponNewHeight(t *testing.T) { partSize := types.BlockPartSizeBytes - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - newBlockHeader := subscribe(t, cs1.eventBus, types.EventQueryNewBlockHeader) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + newBlockHeader := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewBlockHeader) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) + voteCh := subscribeToVoter(ctx, t, cs1, addr) // start round and wait for propose and prevote - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) ensureNewProposal(proposalCh, height, round) @@ -1712,21 +1780,21 @@ func TestResetTimeoutPrecommitUponNewHeight(t *testing.T) { theBlockParts := rs.ProposalBlockParts.Header() ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], theBlockHash) + validatePrevote(ctx, t, cs1, round, vss[0], theBlockHash) - signAddVotes(config, cs1, tmproto.PrevoteType, theBlockHash, theBlockParts, vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, theBlockHash, theBlockParts, vs2, vs3, vs4) ensurePrecommit(voteCh, height, round) - validatePrecommit(t, cs1, round, round, 
vss[0], theBlockHash, theBlockHash) + validatePrecommit(ctx, t, cs1, round, round, vss[0], theBlockHash, theBlockHash) // add precommits - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2) - signAddVotes(config, cs1, tmproto.PrecommitType, theBlockHash, theBlockParts, vs3) - signAddVotes(config, cs1, tmproto.PrecommitType, theBlockHash, theBlockParts, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, theBlockHash, theBlockParts, vs3) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, theBlockHash, theBlockParts, vs4) ensureNewBlockHeader(newBlockHeader, height, theBlockHash) - prop, propBlock := decideProposal(cs1, vs2, height+1, 0) + prop, propBlock := decideProposal(ctx, cs1, vs2, height+1, 0) propBlockParts := propBlock.MakePartSet(partSize) if err := cs1.SetProposalAndBlock(prop, propBlock, propBlockParts, "some peer"); err != nil { @@ -1751,10 +1819,10 @@ func TestStateSlashingPrevotes(t *testing.T) { vs2 := vss[1] - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - voteCh := subscribeToVoter(t, cs1, cs1.privValidator.GetAddress()) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + voteCh := subscribeToVoter(ctx, t, cs1, cs1.privValidator.GetAddress()) // start round and wait for propose and prevote startTestRound(cs1, cs1.Height, 0) @@ -1786,10 +1854,10 @@ func TestStateSlashingPrecommits(t *testing.T) { vs2 := vss[1] - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - voteCh := subscribeToVoter(t, cs1, cs1.privValidator.GetAddress()) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + voteCh := subscribeToVoter(ctx, t, cs1, cs1.privValidator.GetAddress()) // start round and wait for propose and prevote startTestRound(cs1, cs1.Height, 0) @@ -1828,24 +1896,26 @@ func TestStateSlashingPrecommits(t *testing.T) { // we receive a final precommit after going into next round, but others might have gone to commit already! 
func TestStateHalt1(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs1, vss, err := randState(config, log.TestingLogger(), 4) + cs1, vss, err := randState(ctx, config, log.TestingLogger(), 4) require.NoError(t, err) vs2, vs3, vs4 := vss[1], vss[2], vss[3] height, round := cs1.Height, cs1.Round partSize := types.BlockPartSizeBytes - proposalCh := subscribe(t, cs1.eventBus, types.EventQueryCompleteProposal) - timeoutWaitCh := subscribe(t, cs1.eventBus, types.EventQueryTimeoutWait) - newRoundCh := subscribe(t, cs1.eventBus, types.EventQueryNewRound) - newBlockCh := subscribe(t, cs1.eventBus, types.EventQueryNewBlock) - pv1, err := cs1.privValidator.GetPubKey(context.Background()) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + timeoutWaitCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutWait) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + newBlockCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewBlock) + pv1, err := cs1.privValidator.GetPubKey(ctx) require.NoError(t, err) addr := pv1.Address() - voteCh := subscribeToVoter(t, cs1, addr) + voteCh := subscribeToVoter(ctx, t, cs1, addr) // start round and wait for propose and prevote - startTestRound(cs1, height, round) + startTestRound(ctx, cs1, height, round) ensureNewRound(newRoundCh, height, round) ensureNewProposal(proposalCh, height, round) @@ -1855,17 +1925,17 @@ func TestStateHalt1(t *testing.T) { ensurePrevote(voteCh, height, round) - signAddVotes(config, cs1, tmproto.PrevoteType, propBlock.Hash(), propBlockParts.Header(), vs2, vs3, vs4) + signAddVotes(ctx, config, cs1, tmproto.PrevoteType, propBlock.Hash(), propBlockParts.Header(), vs2, vs3, vs4) ensurePrecommit(voteCh, height, round) // the proposed block should now be locked and our precommit added - validatePrecommit(t, cs1, round, round, vss[0], propBlock.Hash(), propBlock.Hash()) + validatePrecommit(ctx, t, cs1, round, round, vss[0], propBlock.Hash(), propBlock.Hash()) // add precommits from the rest - signAddVotes(config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2) // didnt receive proposal - signAddVotes(config, cs1, tmproto.PrecommitType, propBlock.Hash(), propBlockParts.Header(), vs3) + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, nil, types.PartSetHeader{}, vs2) // didnt receive proposal + signAddVotes(ctx, config, cs1, tmproto.PrecommitType, propBlock.Hash(), propBlockParts.Header(), vs3) // we receive this later, but vs3 might receive it earlier and with ours will go to commit! 
- precommit4 := signVote(vs4, config, tmproto.PrecommitType, propBlock.Hash(), propBlockParts.Header()) + precommit4 := signVote(ctx, vs4, config, tmproto.PrecommitType, propBlock.Hash(), propBlockParts.Header()) incrementRound(vs2, vs3, vs4) @@ -1885,7 +1955,7 @@ func TestStateHalt1(t *testing.T) { // go to prevote, prevote for locked block ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], rs.LockedBlock.Hash()) + validatePrevote(ctx, t, cs1, round, vss[0], rs.LockedBlock.Hash()) // now we receive the precommit from the previous round addVotes(cs1, precommit4) @@ -1898,9 +1968,11 @@ func TestStateHalt1(t *testing.T) { func TestStateOutputsBlockPartsStats(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // create dummy peer - cs, _, err := randState(config, log.TestingLogger(), 1) + cs, _, err := randState(ctx, config, log.TestingLogger(), 1) require.NoError(t, err) peerID, err := types.NewNodeID("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") require.NoError(t, err) @@ -1945,8 +2017,10 @@ func TestStateOutputsBlockPartsStats(t *testing.T) { func TestStateOutputVoteStats(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - cs, vss, err := randState(config, log.TestingLogger(), 2) + cs, vss, err := randState(ctx, config, log.TestingLogger(), 2) require.NoError(t, err) // create dummy peer peerID, err := types.NewNodeID("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") @@ -1954,7 +2028,7 @@ func TestStateOutputVoteStats(t *testing.T) { randBytes := tmrand.Bytes(tmhash.Size) - vote := signVote(vss[1], config, tmproto.PrecommitType, randBytes, types.PartSetHeader{}) + vote := signVote(ctx, vss[1], config, tmproto.PrecommitType, randBytes, types.PartSetHeader{}) voteMessage := &VoteMessage{vote} cs.handleMsg(msgInfo{voteMessage, peerID}) @@ -1968,7 +2042,7 @@ func TestStateOutputVoteStats(t *testing.T) { // sending the vote for the bigger height incrementHeight(vss[1]) - vote = signVote(vss[1], config, tmproto.PrecommitType, randBytes, types.PartSetHeader{}) + vote = signVote(ctx, vss[1], config, tmproto.PrecommitType, randBytes, types.PartSetHeader{}) cs.handleMsg(msgInfo{&VoteMessage{vote}, peerID}) @@ -1982,20 +2056,26 @@ func TestStateOutputVoteStats(t *testing.T) { func TestSignSameVoteTwice(t *testing.T) { config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - _, vss, err := randState(config, log.TestingLogger(), 2) + _, vss, err := randState(ctx, config, log.TestingLogger(), 2) require.NoError(t, err) randBytes := tmrand.Bytes(tmhash.Size) - vote := signVote(vss[1], + vote := signVote( + ctx, + vss[1], config, tmproto.PrecommitType, randBytes, types.PartSetHeader{Total: 10, Hash: randBytes}, ) - vote2 := signVote(vss[1], + vote2 := signVote( + ctx, + vss[1], config, tmproto.PrecommitType, randBytes, @@ -2006,9 +2086,14 @@ func TestSignSameVoteTwice(t *testing.T) { } // subscribe subscribes test client to the given query and returns a channel with cap = 1. 
-func subscribe(t *testing.T, eventBus *eventbus.EventBus, q tmpubsub.Query) <-chan tmpubsub.Message { +func subscribe( + ctx context.Context, + t *testing.T, + eventBus *eventbus.EventBus, + q tmpubsub.Query, +) <-chan tmpubsub.Message { t.Helper() - sub, err := eventBus.SubscribeWithArgs(context.Background(), tmpubsub.SubscribeArgs{ + sub, err := eventBus.SubscribeWithArgs(ctx, tmpubsub.SubscribeArgs{ ClientID: testSubscriber, Query: q, }) @@ -2018,8 +2103,11 @@ func subscribe(t *testing.T, eventBus *eventbus.EventBus, q tmpubsub.Query) <-ch ch := make(chan tmpubsub.Message) go func() { for { - next, err := sub.Next(context.Background()) + next, err := sub.Next(ctx) if err != nil { + if ctx.Err() != nil { + return + } t.Errorf("Subscription for %v unexpectedly terminated: %v", q, err) return } diff --git a/internal/consensus/ticker.go b/internal/consensus/ticker.go index 0226889c9..e8583932d 100644 --- a/internal/consensus/ticker.go +++ b/internal/consensus/ticker.go @@ -1,6 +1,7 @@ package consensus import ( + "context" "time" "github.com/tendermint/tendermint/libs/log" @@ -15,7 +16,7 @@ var ( // conditional on the height/round/step in the timeoutInfo. // The timeoutInfo.Duration may be non-positive. type TimeoutTicker interface { - Start() error + Start(context.Context) error Stop() error Chan() <-chan timeoutInfo // on which to receive a timeout ScheduleTimeout(ti timeoutInfo) // reset the timer @@ -47,8 +48,7 @@ func NewTimeoutTicker(logger log.Logger) TimeoutTicker { } // OnStart implements service.Service. It starts the timeout routine. -func (t *timeoutTicker) OnStart() error { - +func (t *timeoutTicker) OnStart(gctx context.Context) error { go t.timeoutRoutine() return nil diff --git a/internal/consensus/types/height_vote_set_test.go b/internal/consensus/types/height_vote_set_test.go index cafb1ed1a..94c36ee3e 100644 --- a/internal/consensus/types/height_vote_set_test.go +++ b/internal/consensus/types/height_vote_set_test.go @@ -29,23 +29,26 @@ func TestMain(m *testing.M) { } func TestPeerCatchupRounds(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + valSet, privVals := factory.RandValidatorSet(10, 1) hvs := NewHeightVoteSet(cfg.ChainID(), 1, valSet) - vote999_0 := makeVoteHR(t, 1, 0, 999, privVals) + vote999_0 := makeVoteHR(ctx, t, 1, 0, 999, privVals) added, err := hvs.AddVote(vote999_0, "peer1") if !added || err != nil { t.Error("Expected to successfully add vote from peer", added, err) } - vote1000_0 := makeVoteHR(t, 1, 0, 1000, privVals) + vote1000_0 := makeVoteHR(ctx, t, 1, 0, 1000, privVals) added, err = hvs.AddVote(vote1000_0, "peer1") if !added || err != nil { t.Error("Expected to successfully add vote from peer", added, err) } - vote1001_0 := makeVoteHR(t, 1, 0, 1001, privVals) + vote1001_0 := makeVoteHR(ctx, t, 1, 0, 1001, privVals) added, err = hvs.AddVote(vote1001_0, "peer1") if err != ErrGotVoteFromUnwantedRound { t.Errorf("expected GotVoteFromUnwantedRoundError, but got %v", err) @@ -61,9 +64,15 @@ func TestPeerCatchupRounds(t *testing.T) { } -func makeVoteHR(t *testing.T, height int64, valIndex, round int32, privVals []types.PrivValidator) *types.Vote { +func makeVoteHR( + ctx context.Context, + t *testing.T, + height int64, + valIndex, round int32, + privVals []types.PrivValidator, +) *types.Vote { privVal := privVals[valIndex] - pubKey, err := privVal.GetPubKey(context.Background()) + pubKey, err := privVal.GetPubKey(ctx) if err != nil { panic(err) } @@ -82,7 +91,7 @@ func makeVoteHR(t *testing.T, height int64, 
valIndex, round int32, privVals []ty chainID := cfg.ChainID() v := vote.ToProto() - err = privVal.SignVote(context.Background(), chainID, v) + err = privVal.SignVote(ctx, chainID, v) if err != nil { panic(fmt.Sprintf("Error signing vote: %v", err)) } diff --git a/internal/consensus/wal.go b/internal/consensus/wal.go index 24fef294d..13f29a202 100644 --- a/internal/consensus/wal.go +++ b/internal/consensus/wal.go @@ -1,6 +1,7 @@ package consensus import ( + "context" "encoding/binary" "errors" "fmt" @@ -63,7 +64,7 @@ type WAL interface { SearchForEndHeight(height int64, options *WALSearchOptions) (rd io.ReadCloser, found bool, err error) // service methods - Start() error + Start(context.Context) error Stop() error Wait() } @@ -116,7 +117,7 @@ func (wal *BaseWAL) Group() *auto.Group { return wal.group } -func (wal *BaseWAL) OnStart() error { +func (wal *BaseWAL) OnStart(ctx context.Context) error { size, err := wal.group.Head.Size() if err != nil { return err @@ -125,7 +126,7 @@ func (wal *BaseWAL) OnStart() error { return err } } - err = wal.group.Start() + err = wal.group.Start(ctx) if err != nil { return err } @@ -159,10 +160,14 @@ func (wal *BaseWAL) FlushAndSync() error { func (wal *BaseWAL) OnStop() { wal.flushTicker.Stop() if err := wal.FlushAndSync(); err != nil { - wal.Logger.Error("error on flush data to disk", "error", err) + if !errors.Is(err, service.ErrAlreadyStopped) { + wal.Logger.Error("error on flush data to disk", "error", err) + } } if err := wal.group.Stop(); err != nil { - wal.Logger.Error("error trying to stop wal", "error", err) + if !errors.Is(err, service.ErrAlreadyStopped) { + wal.Logger.Error("error trying to stop wal", "error", err) + } } wal.group.Close() } @@ -423,6 +428,6 @@ func (nilWAL) FlushAndSync() error { return nil } func (nilWAL) SearchForEndHeight(height int64, options *WALSearchOptions) (rd io.ReadCloser, found bool, err error) { return nil, false, nil } -func (nilWAL) Start() error { return nil } -func (nilWAL) Stop() error { return nil } -func (nilWAL) Wait() {} +func (nilWAL) Start(context.Context) error { return nil } +func (nilWAL) Stop() error { return nil } +func (nilWAL) Wait() {} diff --git a/internal/consensus/wal_generator.go b/internal/consensus/wal_generator.go index 20cf0fae2..35a539d64 100644 --- a/internal/consensus/wal_generator.go +++ b/internal/consensus/wal_generator.go @@ -3,6 +3,7 @@ package consensus import ( "bufio" "bytes" + "context" "fmt" "io" mrand "math/rand" @@ -30,7 +31,7 @@ import ( // persistent kvstore application and special consensus wal instance // (byteBufferWAL) and waits until numBlocks are created. // If the node fails to produce given numBlocks, it returns an error. 
-func WALGenerateNBlocks(t *testing.T, wr io.Writer, numBlocks int) (err error) { +func WALGenerateNBlocks(ctx context.Context, t *testing.T, wr io.Writer, numBlocks int) (err error) { cfg := getConfig(t) app := kvstore.NewPersistentKVStoreApplication(filepath.Join(cfg.DBDir(), "wal_generator")) @@ -67,24 +68,15 @@ func WALGenerateNBlocks(t *testing.T, wr io.Writer, numBlocks int) (err error) { blockStore := store.NewBlockStore(blockStoreDB) proxyApp := proxy.NewAppConns(abciclient.NewLocalCreator(app), logger.With("module", "proxy"), proxy.NopMetrics()) - if err := proxyApp.Start(); err != nil { + if err := proxyApp.Start(ctx); err != nil { return fmt.Errorf("failed to start proxy app connections: %w", err) } - t.Cleanup(func() { - if err := proxyApp.Stop(); err != nil { - t.Error(err) - } - }) eventBus := eventbus.NewDefault(logger.With("module", "events")) - if err := eventBus.Start(); err != nil { + if err := eventBus.Start(ctx); err != nil { return fmt.Errorf("failed to start event bus: %w", err) } - t.Cleanup(func() { - if err := eventBus.Stop(); err != nil { - t.Error(err) - } - }) + mempool := emptyMempool{} evpool := sm.EmptyEvidencePool{} blockExec := sm.NewBlockExecutor(stateStore, log.TestingLogger(), proxyApp.Consensus(), mempool, evpool, blockStore) @@ -105,7 +97,7 @@ func WALGenerateNBlocks(t *testing.T, wr io.Writer, numBlocks int) (err error) { consensusState.wal = wal - if err := consensusState.Start(); err != nil { + if err := consensusState.Start(ctx); err != nil { return fmt.Errorf("failed to start consensus state: %w", err) } @@ -124,11 +116,11 @@ func WALGenerateNBlocks(t *testing.T, wr io.Writer, numBlocks int) (err error) { } // WALWithNBlocks returns a WAL content with numBlocks. -func WALWithNBlocks(t *testing.T, numBlocks int) (data []byte, err error) { +func WALWithNBlocks(ctx context.Context, t *testing.T, numBlocks int) (data []byte, err error) { var b bytes.Buffer wr := bufio.NewWriter(&b) - if err := WALGenerateNBlocks(t, wr, numBlocks); err != nil { + if err := WALGenerateNBlocks(ctx, t, wr, numBlocks); err != nil { return []byte{}, err } @@ -227,6 +219,6 @@ func (w *byteBufferWAL) SearchForEndHeight( return nil, false, nil } -func (w *byteBufferWAL) Start() error { return nil } -func (w *byteBufferWAL) Stop() error { return nil } -func (w *byteBufferWAL) Wait() {} +func (w *byteBufferWAL) Start(context.Context) error { return nil } +func (w *byteBufferWAL) Stop() error { return nil } +func (w *byteBufferWAL) Wait() {} diff --git a/internal/consensus/wal_test.go b/internal/consensus/wal_test.go index 6c1feb670..c0290fcf8 100644 --- a/internal/consensus/wal_test.go +++ b/internal/consensus/wal_test.go @@ -2,6 +2,7 @@ package consensus import ( "bytes" + "context" "path/filepath" "testing" @@ -18,15 +19,16 @@ import ( tmtypes "github.com/tendermint/tendermint/types" ) -const ( - walTestFlushInterval = time.Duration(100) * time.Millisecond -) +const walTestFlushInterval = 100 * time.Millisecond func TestWALTruncate(t *testing.T) { walDir := t.TempDir() walFile := filepath.Join(walDir, "wal") logger := log.TestingLogger() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // this magic number 4K can truncate the content when RotateFile. // defaultHeadSizeLimit(10M) is hard to simulate. // this magic number 1 * time.Millisecond make RotateFile check frequently. 
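// ----------------------------------------------------------------------------
// Editor's note -- illustrative sketch, not part of the patch. With Start now
// taking a context, the WAL tests below drop their explicit Stop() cleanups:
// the deferred cancel() ends the WAL's goroutines and t.Cleanup(wal.Wait)
// blocks until shutdown finishes before the temp dir is removed. A condensed
// version of that lifecycle, using only helpers shown in this patch (the test
// name is hypothetical):

func TestWALLifecycleShape(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // replaces the old wal.Stop() cleanup

	walFile := filepath.Join(t.TempDir(), "wal")
	wal, err := NewWAL(log.TestingLogger(), walFile)
	require.NoError(t, err)

	require.NoError(t, wal.Start(ctx))
	t.Cleanup(wal.Wait) // wait for shutdown before the temp dir is removed

	// generation helpers now also take the context and stop when it is cancelled
	require.NoError(t, WALGenerateNBlocks(ctx, t, wal.Group(), 5))
}

// (end of editor's sketch)
// ----------------------------------------------------------------------------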
@@ -36,21 +38,14 @@ func TestWALTruncate(t *testing.T) { autofile.GroupCheckDuration(1*time.Millisecond), ) require.NoError(t, err) - err = wal.Start() + err = wal.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := wal.Stop(); err != nil { - t.Error(err) - } - // wait for the wal to finish shutting down so we - // can safely remove the directory - wal.Wait() - }) + t.Cleanup(wal.Wait) // 60 block's size nearly 70K, greater than group's headBuf size(4096 * 10), // when headBuf is full, truncate content will Flush to the file. at this // time, RotateFile is called, truncate content exist in each file. - err = WALGenerateNBlocks(t, wal.Group(), 60) + err = WALGenerateNBlocks(ctx, t, wal.Group(), 60) require.NoError(t, err) time.Sleep(1 * time.Millisecond) // wait groupCheckDuration, make sure RotateFile run @@ -105,18 +100,14 @@ func TestWALWrite(t *testing.T) { walDir := t.TempDir() walFile := filepath.Join(walDir, "wal") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + wal, err := NewWAL(log.TestingLogger(), walFile) require.NoError(t, err) - err = wal.Start() + err = wal.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := wal.Stop(); err != nil { - t.Error(err) - } - // wait for the wal to finish shutting down so we - // can safely remove the directory - wal.Wait() - }) + t.Cleanup(wal.Wait) // 1) Write returns an error if msg is too big msg := &BlockPartMessage{ @@ -142,7 +133,10 @@ func TestWALWrite(t *testing.T) { } func TestWALSearchForEndHeight(t *testing.T) { - walBody, err := WALWithNBlocks(t, 6) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + walBody, err := WALWithNBlocks(ctx, t, 6) if err != nil { t.Fatal(err) } @@ -171,18 +165,21 @@ func TestWALPeriodicSync(t *testing.T) { walFile := filepath.Join(walDir, "wal") wal, err := NewWAL(log.TestingLogger(), walFile, autofile.GroupCheckDuration(1*time.Millisecond)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + require.NoError(t, err) wal.SetFlushInterval(walTestFlushInterval) // Generate some data - err = WALGenerateNBlocks(t, wal.Group(), 5) + err = WALGenerateNBlocks(ctx, t, wal.Group(), 5) require.NoError(t, err) // We should have data in the buffer now assert.NotZero(t, wal.Group().Buffered()) - require.NoError(t, wal.Start()) + require.NoError(t, wal.Start(ctx)) t.Cleanup(func() { if err := wal.Stop(); err != nil { t.Error(err) diff --git a/internal/eventbus/event_bus.go b/internal/eventbus/event_bus.go index 95c8876d4..4b28f6fcd 100644 --- a/internal/eventbus/event_bus.go +++ b/internal/eventbus/event_bus.go @@ -2,6 +2,7 @@ package eventbus import ( "context" + "errors" "fmt" "strings" @@ -38,13 +39,15 @@ func NewDefault(l log.Logger) *EventBus { return b } -func (b *EventBus) OnStart() error { - return b.pubsub.Start() +func (b *EventBus) OnStart(ctx context.Context) error { + return b.pubsub.Start(ctx) } func (b *EventBus) OnStop() { if err := b.pubsub.Stop(); err != nil { - b.pubsub.Logger.Error("error trying to stop eventBus", "error", err) + if !errors.Is(err, service.ErrAlreadyStopped) { + b.pubsub.Logger.Error("error trying to stop eventBus", "error", err) + } } } diff --git a/internal/eventbus/event_bus_test.go b/internal/eventbus/event_bus_test.go index 64e6f3344..3e9069718 100644 --- a/internal/eventbus/event_bus_test.go +++ b/internal/eventbus/event_bus_test.go @@ -19,14 +19,12 @@ import ( ) func TestEventBusPublishEventTx(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + 
defer cancel() + eventBus := eventbus.NewDefault(log.TestingLogger()) - err := eventBus.Start() + err := eventBus.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := eventBus.Stop(); err != nil { - t.Error(err) - } - }) tx := types.Tx("foo") result := abci.ResponseDeliverTx{ @@ -37,7 +35,6 @@ func TestEventBusPublishEventTx(t *testing.T) { } // PublishEventTx adds 3 composite keys, so the query below should work - ctx := context.Background() query := fmt.Sprintf("tm.event='Tx' AND tx.height=1 AND tx.hash='%X' AND testType.baz=1", tx.Hash()) txsSub, err := eventBus.SubscribeWithArgs(ctx, tmpubsub.SubscribeArgs{ ClientID: "test", @@ -76,14 +73,11 @@ func TestEventBusPublishEventTx(t *testing.T) { } func TestEventBusPublishEventNewBlock(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() eventBus := eventbus.NewDefault(log.TestingLogger()) - err := eventBus.Start() + err := eventBus.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := eventBus.Stop(); err != nil { - t.Error(err) - } - }) block := types.MakeBlock(0, []types.Tx{}, nil, []types.Evidence{}) blockID := types.BlockID{Hash: block.Hash(), PartSetHeader: block.MakePartSet(types.BlockPartSizeBytes).Header()} @@ -99,7 +93,6 @@ func TestEventBusPublishEventNewBlock(t *testing.T) { } // PublishEventNewBlock adds the tm.event compositeKey, so the query below should work - ctx := context.Background() query := "tm.event='NewBlock' AND testType.baz=1 AND testType.foz=2" blocksSub, err := eventBus.SubscribeWithArgs(ctx, tmpubsub.SubscribeArgs{ ClientID: "test", @@ -136,14 +129,11 @@ func TestEventBusPublishEventNewBlock(t *testing.T) { } func TestEventBusPublishEventTxDuplicateKeys(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() eventBus := eventbus.NewDefault(log.TestingLogger()) - err := eventBus.Start() + err := eventBus.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := eventBus.Stop(); err != nil { - t.Error(err) - } - }) tx := types.Tx("foo") result := abci.ResponseDeliverTx{ @@ -203,54 +193,65 @@ func TestEventBusPublishEventTxDuplicateKeys(t *testing.T) { } for i, tc := range testCases { - ctx := context.Background() - sub, err := eventBus.SubscribeWithArgs(ctx, tmpubsub.SubscribeArgs{ - ClientID: fmt.Sprintf("client-%d", i), - Query: tmquery.MustParse(tc.query), - }) - require.NoError(t, err) + var name string - gotResult := make(chan bool, 1) - go func() { - defer close(gotResult) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - msg, err := sub.Next(ctx) - if err == nil { - data := msg.Data().(types.EventDataTx) - assert.Equal(t, int64(1), data.Height) - assert.Equal(t, uint32(0), data.Index) - assert.EqualValues(t, tx, data.Tx) - assert.Equal(t, result, data.Result) - gotResult <- true - } - }() - - assert.NoError(t, eventBus.PublishEventTx(types.EventDataTx{ - TxResult: abci.TxResult{ - Height: 1, - Index: 0, - Tx: tx, - Result: result, - }, - })) - - if got := <-gotResult; got != tc.expectResults { - require.Failf(t, "Wrong transaction result", - "got a tx: %v, wanted a tx: %v", got, tc.expectResults) + if tc.expectResults { + name = fmt.Sprintf("ExpetedResultsCase%d", i) + } else { + name = fmt.Sprintf("NoResultsCase%d", i) } + + t.Run(name, func(t *testing.T) { + + sub, err := eventBus.SubscribeWithArgs(ctx, tmpubsub.SubscribeArgs{ + ClientID: fmt.Sprintf("client-%d", i), + Query: tmquery.MustParse(tc.query), + }) + require.NoError(t, err) + + gotResult := 
make(chan bool, 1) + go func() { + defer close(gotResult) + tctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + msg, err := sub.Next(tctx) + if err == nil { + data := msg.Data().(types.EventDataTx) + assert.Equal(t, int64(1), data.Height) + assert.Equal(t, uint32(0), data.Index) + assert.EqualValues(t, tx, data.Tx) + assert.Equal(t, result, data.Result) + gotResult <- true + } + }() + + assert.NoError(t, eventBus.PublishEventTx(types.EventDataTx{ + TxResult: abci.TxResult{ + Height: 1, + Index: 0, + Tx: tx, + Result: result, + }, + })) + + require.NoError(t, ctx.Err(), "context should not have been canceled") + + if got := <-gotResult; got != tc.expectResults { + require.Failf(t, "Wrong transaction result", + "got a tx: %v, wanted a tx: %v", got, tc.expectResults) + } + }) + } } func TestEventBusPublishEventNewBlockHeader(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + eventBus := eventbus.NewDefault(log.TestingLogger()) - err := eventBus.Start() + err := eventBus.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := eventBus.Stop(); err != nil { - t.Error(err) - } - }) block := types.MakeBlock(0, []types.Tx{}, nil, []types.Evidence{}) resultBeginBlock := abci.ResponseBeginBlock{ @@ -265,7 +266,6 @@ func TestEventBusPublishEventNewBlockHeader(t *testing.T) { } // PublishEventNewBlockHeader adds the tm.event compositeKey, so the query below should work - ctx := context.Background() query := "tm.event='NewBlockHeader' AND testType.baz=1 AND testType.foz=2" headersSub, err := eventBus.SubscribeWithArgs(ctx, tmpubsub.SubscribeArgs{ ClientID: "test", @@ -300,18 +300,15 @@ func TestEventBusPublishEventNewBlockHeader(t *testing.T) { } func TestEventBusPublishEventNewEvidence(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + eventBus := eventbus.NewDefault(log.TestingLogger()) - err := eventBus.Start() + err := eventBus.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := eventBus.Stop(); err != nil { - t.Error(err) - } - }) ev := types.NewMockDuplicateVoteEvidence(1, time.Now(), "test-chain-id") - ctx := context.Background() const query = `tm.event='NewEvidence'` evSub, err := eventBus.SubscribeWithArgs(ctx, tmpubsub.SubscribeArgs{ ClientID: "test", @@ -344,18 +341,15 @@ func TestEventBusPublishEventNewEvidence(t *testing.T) { } func TestEventBusPublish(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + eventBus := eventbus.NewDefault(log.TestingLogger()) - err := eventBus.Start() + err := eventBus.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := eventBus.Stop(); err != nil { - t.Error(err) - } - }) const numEventsExpected = 14 - ctx := context.Background() sub, err := eventBus.SubscribeWithArgs(ctx, tmpubsub.SubscribeArgs{ ClientID: "test", Query: tmquery.Empty{}, @@ -434,8 +428,11 @@ func benchmarkEventBus(numClients int, randQueries bool, randEvents bool, b *tes // for random* functions mrand.Seed(time.Now().Unix()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + eventBus := eventbus.NewDefault(log.TestingLogger()) // set buffer capacity to 0 so we are not testing cache - err := eventBus.Start() + err := eventBus.Start(ctx) if err != nil { b.Error(err) } @@ -445,7 +442,6 @@ func benchmarkEventBus(numClients int, randQueries bool, randEvents bool, b *tes } }) - ctx := context.Background() q := types.EventQueryNewBlock for i := 0; i < numClients; i++ { diff --git 
a/internal/evidence/reactor.go b/internal/evidence/reactor.go index c2f25bd36..4e37e1d17 100644 --- a/internal/evidence/reactor.go +++ b/internal/evidence/reactor.go @@ -1,6 +1,7 @@ package evidence import ( + "context" "fmt" "runtime/debug" "sync" @@ -81,7 +82,7 @@ func NewReactor( // envelopes on each. In addition, it also listens for peer updates and handles // messages on that p2p channel accordingly. The caller must be sure to execute // OnStop to ensure the outbound p2p Channels are closed. No error is returned. -func (r *Reactor) OnStart() error { +func (r *Reactor) OnStart(ctx context.Context) error { go r.processEvidenceCh() go r.processPeerUpdates() diff --git a/internal/evidence/reactor_test.go b/internal/evidence/reactor_test.go index d1ae55803..764450cd6 100644 --- a/internal/evidence/reactor_test.go +++ b/internal/evidence/reactor_test.go @@ -1,6 +1,7 @@ package evidence_test import ( + "context" "encoding/hex" "math/rand" "sync" @@ -44,7 +45,7 @@ type reactorTestSuite struct { numStateStores int } -func setup(t *testing.T, stateStores []sm.Store, chBuf uint) *reactorTestSuite { +func setup(ctx context.Context, t *testing.T, stateStores []sm.Store, chBuf uint) *reactorTestSuite { t.Helper() pID := make([]byte, 16) @@ -55,7 +56,7 @@ func setup(t *testing.T, stateStores []sm.Store, chBuf uint) *reactorTestSuite { rts := &reactorTestSuite{ numStateStores: numStateStores, logger: log.TestingLogger().With("testCase", t.Name()), - network: p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: numStateStores}), + network: p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: numStateStores}), reactors: make(map[types.NodeID]*evidence.Reactor, numStateStores), pools: make(map[types.NodeID]*evidence.Pool, numStateStores), peerUpdates: make(map[types.NodeID]*p2p.PeerUpdates, numStateStores), @@ -93,7 +94,7 @@ func setup(t *testing.T, stateStores []sm.Store, chBuf uint) *reactorTestSuite { rts.peerUpdates[nodeID], rts.pools[nodeID]) - require.NoError(t, rts.reactors[nodeID].Start()) + require.NoError(t, rts.reactors[nodeID].Start(ctx)) require.True(t, rts.reactors[nodeID].IsRunning()) idx++ @@ -233,13 +234,16 @@ func createEvidenceList( } func TestReactorMultiDisconnect(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + val := types.NewMockPV() height := int64(numEvidence) + 10 stateDB1 := initializeValidatorState(t, val, height) stateDB2 := initializeValidatorState(t, val, height) - rts := setup(t, []sm.Store{stateDB1, stateDB2}, 20) + rts := setup(ctx, t, []sm.Store{stateDB1, stateDB2}, 20) primary := rts.nodes[0] secondary := rts.nodes[1] @@ -281,7 +285,10 @@ func TestReactorBroadcastEvidence(t *testing.T) { stateDBs[i] = initializeValidatorState(t, val, height) } - rts := setup(t, stateDBs, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, stateDBs, 0) rts.start(t) // Create a series of fixtures where each suite contains a reactor and @@ -335,7 +342,10 @@ func TestReactorBroadcastEvidence_Lagging(t *testing.T) { stateDB1 := initializeValidatorState(t, val, height1) stateDB2 := initializeValidatorState(t, val, height2) - rts := setup(t, []sm.Store{stateDB1, stateDB2}, 100) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, []sm.Store{stateDB1, stateDB2}, 100) rts.start(t) primary := rts.nodes[0] @@ -368,7 +378,10 @@ func TestReactorBroadcastEvidence_Pending(t *testing.T) { stateDB1 := initializeValidatorState(t, val, height) stateDB2 
:= initializeValidatorState(t, val, height) - rts := setup(t, []sm.Store{stateDB1, stateDB2}, 100) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, []sm.Store{stateDB1, stateDB2}, 100) primary := rts.nodes[0] secondary := rts.nodes[1] @@ -405,7 +418,10 @@ func TestReactorBroadcastEvidence_Committed(t *testing.T) { stateDB1 := initializeValidatorState(t, val, height) stateDB2 := initializeValidatorState(t, val, height) - rts := setup(t, []sm.Store{stateDB1, stateDB2}, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, []sm.Store{stateDB1, stateDB2}, 0) primary := rts.nodes[0] secondary := rts.nodes[1] @@ -460,7 +476,10 @@ func TestReactorBroadcastEvidence_FullyConnected(t *testing.T) { stateDBs[i] = initializeValidatorState(t, val, height) } - rts := setup(t, stateDBs, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, stateDBs, 0) rts.start(t) evList := createEvidenceList(t, rts.pools[rts.network.RandomNode().NodeID], val, numEvidence) diff --git a/internal/inspect/inspect.go b/internal/inspect/inspect.go index 59f0d92d2..3bc744e4c 100644 --- a/internal/inspect/inspect.go +++ b/internal/inspect/inspect.go @@ -85,26 +85,18 @@ func NewFromConfig(logger log.Logger, cfg *config.Config) (*Inspector, error) { // Run starts the Inspector servers and blocks until the servers shut down. The passed // in context is used to control the lifecycle of the servers. func (ins *Inspector) Run(ctx context.Context) error { - err := ins.eventBus.Start() + err := ins.eventBus.Start(ctx) if err != nil { return fmt.Errorf("error starting event bus: %s", err) } - defer func() { - err := ins.eventBus.Stop() - if err != nil { - ins.logger.Error("event bus stopped with error", "err", err) - } - }() - err = ins.indexerService.Start() + defer ins.eventBus.Wait() + + err = ins.indexerService.Start(ctx) if err != nil { return fmt.Errorf("error starting indexer service: %s", err) } - defer func() { - err := ins.indexerService.Stop() - if err != nil { - ins.logger.Error("indexer service stopped with error", "err", err) - } - }() + defer ins.indexerService.Wait() + return startRPCServers(ctx, ins.config, ins.logger, ins.routes) } diff --git a/internal/inspect/inspect_test.go b/internal/inspect/inspect_test.go index 5527bf2ac..15a555ab0 100644 --- a/internal/inspect/inspect_test.go +++ b/internal/inspect/inspect_test.go @@ -101,7 +101,7 @@ func TestBlock(t *testing.T) { requireConnect(t, rpcConfig.ListenAddress, 20) cli, err := httpclient.New(rpcConfig.ListenAddress) require.NoError(t, err) - resultBlock, err := cli.Block(context.Background(), &testHeight) + resultBlock, err := cli.Block(ctx, &testHeight) require.NoError(t, err) require.Equal(t, testBlock.Height, resultBlock.Block.Height) require.Equal(t, testBlock.LastCommitHash, resultBlock.Block.LastCommitHash) @@ -153,7 +153,7 @@ func TestTxSearch(t *testing.T) { require.NoError(t, err) var page = 1 - resultTxSearch, err := cli.TxSearch(context.Background(), testQuery, false, &page, &page, "") + resultTxSearch, err := cli.TxSearch(ctx, testQuery, false, &page, &page, "") require.NoError(t, err) require.Len(t, resultTxSearch.Txs, 1) require.Equal(t, types.Tx(testTx), resultTxSearch.Txs[0].Tx) @@ -199,7 +199,7 @@ func TestTx(t *testing.T) { cli, err := httpclient.New(rpcConfig.ListenAddress) require.NoError(t, err) - res, err := cli.Tx(context.Background(), testHash, false) + res, err := cli.Tx(ctx, testHash, false) 
require.NoError(t, err) require.Equal(t, types.Tx(testTx), res.Tx) @@ -247,7 +247,7 @@ func TestConsensusParams(t *testing.T) { requireConnect(t, rpcConfig.ListenAddress, 20) cli, err := httpclient.New(rpcConfig.ListenAddress) require.NoError(t, err) - params, err := cli.ConsensusParams(context.Background(), &testHeight) + params, err := cli.ConsensusParams(ctx, &testHeight) require.NoError(t, err) require.Equal(t, params.ConsensusParams.Block.MaxGas, testMaxGas) @@ -300,7 +300,7 @@ func TestBlockResults(t *testing.T) { requireConnect(t, rpcConfig.ListenAddress, 20) cli, err := httpclient.New(rpcConfig.ListenAddress) require.NoError(t, err) - res, err := cli.BlockResults(context.Background(), &testHeight) + res, err := cli.BlockResults(ctx, &testHeight) require.NoError(t, err) require.Equal(t, res.TotalGasUsed, testGasUsed) @@ -348,7 +348,7 @@ func TestCommit(t *testing.T) { requireConnect(t, rpcConfig.ListenAddress, 20) cli, err := httpclient.New(rpcConfig.ListenAddress) require.NoError(t, err) - res, err := cli.Commit(context.Background(), &testHeight) + res, err := cli.Commit(ctx, &testHeight) require.NoError(t, err) require.NotNil(t, res) require.Equal(t, res.SignedHeader.Commit.Round, testRound) @@ -402,7 +402,7 @@ func TestBlockByHash(t *testing.T) { requireConnect(t, rpcConfig.ListenAddress, 20) cli, err := httpclient.New(rpcConfig.ListenAddress) require.NoError(t, err) - res, err := cli.BlockByHash(context.Background(), testHash) + res, err := cli.BlockByHash(ctx, testHash) require.NoError(t, err) require.NotNil(t, res) require.Equal(t, []byte(res.BlockID.Hash), testHash) @@ -455,7 +455,7 @@ func TestBlockchain(t *testing.T) { requireConnect(t, rpcConfig.ListenAddress, 20) cli, err := httpclient.New(rpcConfig.ListenAddress) require.NoError(t, err) - res, err := cli.BlockchainInfo(context.Background(), 0, 100) + res, err := cli.BlockchainInfo(ctx, 0, 100) require.NoError(t, err) require.NotNil(t, res) require.Equal(t, testBlockHash, []byte(res.BlockMetas[0].BlockID.Hash)) @@ -511,7 +511,7 @@ func TestValidators(t *testing.T) { testPage := 1 testPerPage := 100 - res, err := cli.Validators(context.Background(), &testHeight, &testPage, &testPerPage) + res, err := cli.Validators(ctx, &testHeight, &testPage, &testPerPage) require.NoError(t, err) require.NotNil(t, res) require.Equal(t, testVotingPower, res.Validators[0].VotingPower) @@ -571,7 +571,7 @@ func TestBlockSearch(t *testing.T) { testPage := 1 testPerPage := 100 testOrderBy := "desc" - res, err := cli.BlockSearch(context.Background(), testQuery, &testPage, &testPerPage, testOrderBy) + res, err := cli.BlockSearch(ctx, testQuery, &testPage, &testPerPage, testOrderBy) require.NoError(t, err) require.NotNil(t, res) require.Equal(t, testBlockHash, []byte(res.Blocks[0].BlockID.Hash)) diff --git a/internal/libs/autofile/cmd/logjack.go b/internal/libs/autofile/cmd/logjack.go index 1aa8b6a11..0f412a366 100644 --- a/internal/libs/autofile/cmd/logjack.go +++ b/internal/libs/autofile/cmd/logjack.go @@ -1,15 +1,17 @@ package main import ( + "context" "flag" "fmt" "io" "os" + "os/signal" "strconv" "strings" + "syscall" auto "github.com/tendermint/tendermint/internal/libs/autofile" - tmos "github.com/tendermint/tendermint/libs/os" ) const Version = "0.0.1" @@ -32,21 +34,10 @@ func parseFlags() (headPath string, chopSize int64, limitSize int64, version boo return } -type fmtLogger struct{} - -func (fmtLogger) Info(msg string, keyvals ...interface{}) { - strs := make([]string, len(keyvals)) - for i, kv := range keyvals { - strs[i] = 
fmt.Sprintf("%v", kv) - } - fmt.Printf("%s %s\n", msg, strings.Join(strs, ",")) -} - func main() { - // Stop upon receiving SIGTERM or CTRL-C. - tmos.TrapSignal(fmtLogger{}, func() { - fmt.Println("logjack shutting down") - }) + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM) + defer cancel() + defer func() { fmt.Println("logjack shutting down") }() // Read options headPath, chopSize, limitSize, version := parseFlags() @@ -62,7 +53,7 @@ func main() { os.Exit(1) } - if err = group.Start(); err != nil { + if err = group.Start(ctx); err != nil { fmt.Printf("logjack couldn't start with file %v\n", headPath) os.Exit(1) } diff --git a/internal/libs/autofile/group.go b/internal/libs/autofile/group.go index 23f27c59b..0e208d8e9 100644 --- a/internal/libs/autofile/group.go +++ b/internal/libs/autofile/group.go @@ -2,6 +2,7 @@ package autofile import ( "bufio" + "context" "errors" "fmt" "io" @@ -135,7 +136,7 @@ func GroupTotalSizeLimit(limit int64) func(*Group) { // OnStart implements service.Service by starting the goroutine that checks file // and group limits. -func (g *Group) OnStart() error { +func (g *Group) OnStart(ctx context.Context) error { g.ticker = time.NewTicker(g.groupCheckDuration) go g.processTicks() return nil diff --git a/internal/mempool/mempool_bench_test.go b/internal/mempool/mempool_bench_test.go index 8a7938fd5..ed4740011 100644 --- a/internal/mempool/mempool_bench_test.go +++ b/internal/mempool/mempool_bench_test.go @@ -11,7 +11,10 @@ import ( ) func BenchmarkTxMempool_CheckTx(b *testing.B) { - txmp := setup(b, 10000) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txmp := setup(ctx, b, 10000) rng := rand.New(rand.NewSource(time.Now().UnixNano())) b.ResetTimer() diff --git a/internal/mempool/mempool_test.go b/internal/mempool/mempool_test.go index 428204214..f06ee18d9 100644 --- a/internal/mempool/mempool_test.go +++ b/internal/mempool/mempool_test.go @@ -72,9 +72,12 @@ func (app *application) CheckTx(req abci.RequestCheckTx) abci.ResponseCheckTx { } } -func setup(t testing.TB, cacheSize int, options ...TxMempoolOption) *TxMempool { +func setup(ctx context.Context, t testing.TB, cacheSize int, options ...TxMempoolOption) *TxMempool { t.Helper() + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + app := &application{kvstore.NewApplication()} cc := abciclient.NewLocalCreator(app) logger := log.TestingLogger() @@ -84,11 +87,12 @@ func setup(t testing.TB, cacheSize int, options ...TxMempoolOption) *TxMempool { cfg.Mempool.CacheSize = cacheSize appConnMem, err := cc(logger) require.NoError(t, err) - require.NoError(t, appConnMem.Start()) + require.NoError(t, appConnMem.Start(ctx)) t.Cleanup(func() { os.RemoveAll(cfg.RootDir) - require.NoError(t, appConnMem.Stop()) + cancel() + appConnMem.Wait() }) return NewTxMempool(logger.With("test", t.Name()), cfg.Mempool, appConnMem, 0, options...) 
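The mempool test hunks above all land on the same lifecycle idiom: create a test-scoped context, pass it to Start, and let cancellation plus Wait replace the old explicit Stop call during cleanup. A minimal sketch of that idiom follows, assuming only the Start(ctx)/Wait() surface visible at these call sites; the helper and interface names are invented for illustration and are not part of the patch.

package example_test

import (
	"context"
	"testing"

	"github.com/stretchr/testify/require"
)

// contextService captures only the two methods the call sites above rely on;
// the concrete service interface in the patch may expose more than this.
type contextService interface {
	Start(context.Context) error
	Wait()
}

// startForTest is a hypothetical helper showing the lifecycle idiom: start the
// service with a test-scoped context, and on cleanup cancel that context and
// wait for shutdown instead of calling an explicit Stop.
func startForTest(ctx context.Context, t *testing.T, svc contextService) {
	t.Helper()

	ctx, cancel := context.WithCancel(ctx)
	require.NoError(t, svc.Start(ctx))

	t.Cleanup(func() {
		cancel()   // cancellation replaces the old Stop() call
		svc.Wait() // block until the service's goroutines have exited
	})
}

Cancelling the context only requests shutdown; the Wait call is what makes teardown deterministic, since it blocks the test until the background goroutines have actually stopped.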
@@ -128,7 +132,10 @@ func convertTex(in []testTx) types.Txs { } func TestTxMempool_TxsAvailable(t *testing.T) { - txmp := setup(t, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txmp := setup(ctx, t, 0) txmp.EnableTxsAvailable() ensureNoTxFire := func() { @@ -182,7 +189,10 @@ func TestTxMempool_TxsAvailable(t *testing.T) { } func TestTxMempool_Size(t *testing.T) { - txmp := setup(t, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txmp := setup(ctx, t, 0) txs := checkTxs(t, txmp, 100, 0) require.Equal(t, len(txs), txmp.Size()) require.Equal(t, int64(5690), txmp.SizeBytes()) @@ -206,7 +216,10 @@ func TestTxMempool_Size(t *testing.T) { } func TestTxMempool_Flush(t *testing.T) { - txmp := setup(t, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txmp := setup(ctx, t, 0) txs := checkTxs(t, txmp, 100, 0) require.Equal(t, len(txs), txmp.Size()) require.Equal(t, int64(5690), txmp.SizeBytes()) @@ -231,7 +244,10 @@ func TestTxMempool_Flush(t *testing.T) { } func TestTxMempool_ReapMaxBytesMaxGas(t *testing.T) { - txmp := setup(t, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txmp := setup(ctx, t, 0) tTxs := checkTxs(t, txmp, 100, 0) // all txs request 1 gas unit require.Equal(t, len(tTxs), txmp.Size()) require.Equal(t, int64(5690), txmp.SizeBytes()) @@ -281,7 +297,10 @@ func TestTxMempool_ReapMaxBytesMaxGas(t *testing.T) { } func TestTxMempool_ReapMaxTxs(t *testing.T) { - txmp := setup(t, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txmp := setup(ctx, t, 0) tTxs := checkTxs(t, txmp, 100, 0) require.Equal(t, len(tTxs), txmp.Size()) require.Equal(t, int64(5690), txmp.SizeBytes()) @@ -330,7 +349,10 @@ func TestTxMempool_ReapMaxTxs(t *testing.T) { } func TestTxMempool_CheckTxExceedsMaxSize(t *testing.T) { - txmp := setup(t, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txmp := setup(ctx, t, 0) rng := rand.New(rand.NewSource(time.Now().UnixNano())) tx := make([]byte, txmp.config.MaxTxBytes+1) @@ -347,7 +369,10 @@ func TestTxMempool_CheckTxExceedsMaxSize(t *testing.T) { } func TestTxMempool_CheckTxSamePeer(t *testing.T) { - txmp := setup(t, 100) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txmp := setup(ctx, t, 100) peerID := uint16(1) rng := rand.New(rand.NewSource(time.Now().UnixNano())) @@ -362,7 +387,10 @@ func TestTxMempool_CheckTxSamePeer(t *testing.T) { } func TestTxMempool_CheckTxSameSender(t *testing.T) { - txmp := setup(t, 100) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txmp := setup(ctx, t, 100) peerID := uint16(1) rng := rand.New(rand.NewSource(time.Now().UnixNano())) @@ -384,7 +412,10 @@ func TestTxMempool_CheckTxSameSender(t *testing.T) { } func TestTxMempool_ConcurrentTxs(t *testing.T) { - txmp := setup(t, 100) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txmp := setup(ctx, t, 100) rng := rand.New(rand.NewSource(time.Now().UnixNano())) checkTxDone := make(chan struct{}) @@ -448,7 +479,10 @@ func TestTxMempool_ConcurrentTxs(t *testing.T) { } func TestTxMempool_ExpiredTxs_NumBlocks(t *testing.T) { - txmp := setup(t, 500) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txmp := setup(ctx, t, 500) txmp.height = 100 txmp.config.TTLNumBlocks = 10 @@ -498,6 +532,9 @@ func TestTxMempool_ExpiredTxs_NumBlocks(t *testing.T) { } func 
TestTxMempool_CheckTxPostCheckError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cases := []struct { name string err error @@ -514,10 +551,13 @@ func TestTxMempool_CheckTxPostCheckError(t *testing.T) { for _, tc := range cases { testCase := tc t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + postCheckFn := func(_ types.Tx, _ *abci.ResponseCheckTx) error { return testCase.err } - txmp := setup(t, 0, WithPostCheck(postCheckFn)) + txmp := setup(ctx, t, 0, WithPostCheck(postCheckFn)) rng := rand.New(rand.NewSource(time.Now().UnixNano())) tx := make([]byte, txmp.config.MaxTxBytes-1) _, err := rng.Read(tx) @@ -532,7 +572,7 @@ func TestTxMempool_CheckTxPostCheckError(t *testing.T) { } require.Equal(t, expectedErrString, checkTxRes.CheckTx.MempoolError) } - require.NoError(t, txmp.CheckTx(context.Background(), tx, callback, TxInfo{SenderID: 0})) + require.NoError(t, txmp.CheckTx(ctx, tx, callback, TxInfo{SenderID: 0})) }) } } diff --git a/internal/mempool/reactor.go b/internal/mempool/reactor.go index 6d86b3712..2ddb44e7a 100644 --- a/internal/mempool/reactor.go +++ b/internal/mempool/reactor.go @@ -112,7 +112,7 @@ func GetChannelDescriptor(cfg *config.MempoolConfig) *p2p.ChannelDescriptor { // envelopes on each. In addition, it also listens for peer updates and handles // messages on that p2p channel accordingly. The caller must be sure to execute // OnStop to ensure the outbound p2p Channels are closed. -func (r *Reactor) OnStart() error { +func (r *Reactor) OnStart(ctx context.Context) error { if !r.cfg.Broadcast { r.Logger.Info("tx broadcasting is disabled") } diff --git a/internal/mempool/reactor_test.go b/internal/mempool/reactor_test.go index 68753cc69..4456424b5 100644 --- a/internal/mempool/reactor_test.go +++ b/internal/mempool/reactor_test.go @@ -36,7 +36,7 @@ type reactorTestSuite struct { nodes []types.NodeID } -func setupReactors(t *testing.T, numNodes int, chBuf uint) *reactorTestSuite { +func setupReactors(ctx context.Context, t *testing.T, numNodes int, chBuf uint) *reactorTestSuite { t.Helper() cfg, err := config.ResetTestRoot(strings.ReplaceAll(t.Name(), "/", "|")) @@ -45,7 +45,7 @@ func setupReactors(t *testing.T, numNodes int, chBuf uint) *reactorTestSuite { rts := &reactorTestSuite{ logger: log.TestingLogger().With("testCase", t.Name()), - network: p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: numNodes}), + network: p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: numNodes}), reactors: make(map[types.NodeID]*Reactor, numNodes), mempoolChannels: make(map[types.NodeID]*p2p.Channel, numNodes), mempools: make(map[types.NodeID]*TxMempool, numNodes), @@ -60,7 +60,7 @@ func setupReactors(t *testing.T, numNodes int, chBuf uint) *reactorTestSuite { for nodeID := range rts.network.Nodes { rts.kvstores[nodeID] = kvstore.NewApplication() - mempool := setup(t, 0) + mempool := setup(ctx, t, 0) rts.mempools[nodeID] = mempool rts.peerChans[nodeID] = make(chan p2p.PeerUpdate) @@ -78,7 +78,7 @@ func setupReactors(t *testing.T, numNodes int, chBuf uint) *reactorTestSuite { rts.nodes = append(rts.nodes, nodeID) - require.NoError(t, rts.reactors[nodeID].Start()) + require.NoError(t, rts.reactors[nodeID].Start(ctx)) require.True(t, rts.reactors[nodeID].IsRunning()) } @@ -147,8 +147,11 @@ func (rts *reactorTestSuite) waitForTxns(t *testing.T, txs []types.Tx, ids ...ty } func TestReactorBroadcastDoesNotPanic(t *testing.T) { - numNodes := 2 - rts := setupReactors(t, numNodes, 0) 
+ ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const numNodes = 2 + rts := setupReactors(ctx, t, numNodes, 0) observePanic := func(r interface{}) { t.Fatal("panic detected in reactor") @@ -192,8 +195,10 @@ func TestReactorBroadcastDoesNotPanic(t *testing.T) { func TestReactorBroadcastTxs(t *testing.T) { numTxs := 1000 numNodes := 10 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - rts := setupReactors(t, numNodes, 0) + rts := setupReactors(ctx, t, numNodes, 0) primary := rts.nodes[0] secondaries := rts.nodes[1:] @@ -215,7 +220,10 @@ func TestReactorConcurrency(t *testing.T) { numTxs := 5 numNodes := 2 - rts := setupReactors(t, numNodes, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setupReactors(ctx, t, numNodes, 0) primary := rts.nodes[0] secondary := rts.nodes[1] @@ -273,7 +281,10 @@ func TestReactorNoBroadcastToSender(t *testing.T) { numTxs := 1000 numNodes := 2 - rts := setupReactors(t, numNodes, uint(numTxs)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setupReactors(ctx, t, numNodes, uint(numTxs)) primary := rts.nodes[0] secondary := rts.nodes[1] @@ -296,7 +307,10 @@ func TestReactor_MaxTxBytes(t *testing.T) { numNodes := 2 cfg := config.TestConfig() - rts := setupReactors(t, numNodes, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setupReactors(ctx, t, numNodes, 0) primary := rts.nodes[0] secondary := rts.nodes[1] @@ -305,7 +319,7 @@ func TestReactor_MaxTxBytes(t *testing.T) { // second reactor. tx1 := tmrand.Bytes(cfg.Mempool.MaxTxBytes) err := rts.reactors[primary].mempool.CheckTx( - context.Background(), + ctx, tx1, nil, TxInfo{ @@ -321,7 +335,7 @@ func TestReactor_MaxTxBytes(t *testing.T) { // broadcast a tx, which is beyond the max size and ensure it's not sent tx2 := tmrand.Bytes(cfg.Mempool.MaxTxBytes + 1) - err = rts.mempools[primary].CheckTx(context.Background(), tx2, nil, TxInfo{SenderID: UnknownPeerID}) + err = rts.mempools[primary].CheckTx(ctx, tx2, nil, TxInfo{SenderID: UnknownPeerID}) require.Error(t, err) rts.assertMempoolChannelsDrained(t) @@ -330,7 +344,11 @@ func TestReactor_MaxTxBytes(t *testing.T) { func TestDontExhaustMaxActiveIDs(t *testing.T) { // we're creating a single node network, but not starting the // network. 
- rts := setupReactors(t, 1, MaxActiveIDs+1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setupReactors(ctx, t, 1, MaxActiveIDs+1) nodeID := rts.nodes[0] @@ -395,7 +413,10 @@ func TestBroadcastTxForPeerStopsWhenPeerStops(t *testing.T) { t.Skip("skipping test in short mode") } - rts := setupReactors(t, 2, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setupReactors(ctx, t, 2, 0) primary := rts.nodes[0] secondary := rts.nodes[1] diff --git a/internal/p2p/address_test.go b/internal/p2p/address_test.go index 2745faf73..d5f9e498e 100644 --- a/internal/p2p/address_test.go +++ b/internal/p2p/address_test.go @@ -1,6 +1,7 @@ package p2p_test import ( + "context" "net" "strings" "testing" @@ -204,6 +205,9 @@ func TestParseNodeAddress(t *testing.T) { func TestNodeAddress_Resolve(t *testing.T) { id := types.NodeID("00112233445566778899aabbccddeeff00112233") + bctx, bcancel := context.WithCancel(context.Background()) + defer bcancel() + testcases := []struct { address p2p.NodeAddress expect p2p.Endpoint @@ -275,6 +279,9 @@ func TestNodeAddress_Resolve(t *testing.T) { for _, tc := range testcases { tc := tc t.Run(tc.address.String(), func(t *testing.T) { + ctx, cancel := context.WithCancel(bctx) + defer cancel() + endpoints, err := tc.address.Resolve(ctx) if !tc.ok { require.Error(t, err) diff --git a/internal/p2p/conn/connection.go b/internal/p2p/conn/connection.go index a7efe54e3..9fb330286 100644 --- a/internal/p2p/conn/connection.go +++ b/internal/p2p/conn/connection.go @@ -2,6 +2,7 @@ package conn import ( "bufio" + "context" "errors" "fmt" "io" @@ -209,8 +210,8 @@ func NewMConnectionWithConfig( } // OnStart implements BaseService -func (c *MConnection) OnStart() error { - if err := c.BaseService.OnStart(); err != nil { +func (c *MConnection) OnStart(ctx context.Context) error { + if err := c.BaseService.OnStart(ctx); err != nil { return err } c.flushTimer = timer.NewThrottleTimer("flush", c.config.FlushThrottle) diff --git a/internal/p2p/conn/connection_test.go b/internal/p2p/conn/connection_test.go index 5c4cdb156..dc198d8bd 100644 --- a/internal/p2p/conn/connection_test.go +++ b/internal/p2p/conn/connection_test.go @@ -1,6 +1,7 @@ package conn import ( + "context" "encoding/hex" "net" "testing" @@ -47,8 +48,11 @@ func TestMConnectionSendFlushStop(t *testing.T) { server, client := NetPipe() t.Cleanup(closeAll(t, client, server)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + clientConn := createTestMConnection(log.TestingLogger(), client) - err := clientConn.Start() + err := clientConn.Start(ctx) require.Nil(t, err) t.Cleanup(stopAll(t, clientConn)) @@ -81,8 +85,11 @@ func TestMConnectionSend(t *testing.T) { server, client := NetPipe() t.Cleanup(closeAll(t, client, server)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mconn := createTestMConnection(log.TestingLogger(), client) - err := mconn.Start() + err := mconn.Start(ctx) require.Nil(t, err) t.Cleanup(stopAll(t, mconn)) @@ -118,13 +125,17 @@ func TestMConnectionReceive(t *testing.T) { errorsCh <- r } logger := log.TestingLogger() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mconn1 := createMConnectionWithCallbacks(logger, client, onReceive, onError) - err := mconn1.Start() + err := mconn1.Start(ctx) require.Nil(t, err) t.Cleanup(stopAll(t, mconn1)) mconn2 := createTestMConnection(logger, server) - err = mconn2.Start() + err = mconn2.Start(ctx) require.Nil(t, err) 
t.Cleanup(stopAll(t, mconn2)) @@ -153,8 +164,12 @@ func TestMConnectionPongTimeoutResultsInError(t *testing.T) { onError := func(r interface{}) { errorsCh <- r } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError) - err := mconn.Start() + err := mconn.Start(ctx) require.Nil(t, err) t.Cleanup(stopAll(t, mconn)) @@ -191,8 +206,11 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { onError := func(r interface{}) { errorsCh <- r } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError) - err := mconn.Start() + err := mconn.Start(ctx) require.Nil(t, err) t.Cleanup(stopAll(t, mconn)) @@ -245,8 +263,11 @@ func TestMConnectionMultiplePings(t *testing.T) { onError := func(r interface{}) { errorsCh <- r } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError) - err := mconn.Start() + err := mconn.Start(ctx) require.Nil(t, err) t.Cleanup(stopAll(t, mconn)) @@ -292,8 +313,12 @@ func TestMConnectionPingPongs(t *testing.T) { onError := func(r interface{}) { errorsCh <- r } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError) - err := mconn.Start() + err := mconn.Start(ctx) require.Nil(t, err) t.Cleanup(stopAll(t, mconn)) @@ -349,8 +374,11 @@ func TestMConnectionStopsAndReturnsError(t *testing.T) { onError := func(r interface{}) { errorsCh <- r } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError) - err := mconn.Start() + err := mconn.Start(ctx) require.Nil(t, err) t.Cleanup(stopAll(t, mconn)) @@ -369,7 +397,11 @@ func TestMConnectionStopsAndReturnsError(t *testing.T) { } } -func newClientAndServerConnsForReadErrors(t *testing.T, chOnErr chan struct{}) (*MConnection, *MConnection) { +func newClientAndServerConnsForReadErrors( + ctx context.Context, + t *testing.T, + chOnErr chan struct{}, +) (*MConnection, *MConnection) { server, client := NetPipe() onReceive := func(chID ChannelID, msgBytes []byte) {} @@ -381,8 +413,9 @@ func newClientAndServerConnsForReadErrors(t *testing.T, chOnErr chan struct{}) ( {ID: 0x02, Priority: 1, SendQueueCapacity: 1}, } logger := log.TestingLogger() + mconnClient := NewMConnection(logger.With("module", "client"), client, chDescs, onReceive, onError) - err := mconnClient.Start() + err := mconnClient.Start(ctx) require.Nil(t, err) // create server conn with 1 channel @@ -391,8 +424,9 @@ func newClientAndServerConnsForReadErrors(t *testing.T, chOnErr chan struct{}) ( onError = func(r interface{}) { chOnErr <- struct{}{} } + mconnServer := createMConnectionWithCallbacks(serverLogger, server, onReceive, onError) - err = mconnServer.Start() + err = mconnServer.Start(ctx) require.Nil(t, err) return mconnClient, mconnServer } @@ -408,8 +442,11 @@ func expectSend(ch chan struct{}) bool { } func TestMConnectionReadErrorBadEncoding(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + chOnErr := make(chan struct{}) - mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr) + mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) 
client := mconnClient.conn @@ -421,8 +458,11 @@ func TestMConnectionReadErrorBadEncoding(t *testing.T) { } func TestMConnectionReadErrorUnknownChannel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + chOnErr := make(chan struct{}) - mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr) + mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) msg := []byte("Ant-Man") @@ -440,7 +480,10 @@ func TestMConnectionReadErrorLongMessage(t *testing.T) { chOnErr := make(chan struct{}) chOnRcv := make(chan struct{}) - mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) t.Cleanup(stopAll(t, mconnClient, mconnServer)) mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) { @@ -474,8 +517,11 @@ func TestMConnectionReadErrorLongMessage(t *testing.T) { } func TestMConnectionReadErrorUnknownMsgType(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + chOnErr := make(chan struct{}) - mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr) + mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) t.Cleanup(stopAll(t, mconnClient, mconnServer)) // send msg with unknown msg type @@ -487,9 +533,11 @@ func TestMConnectionReadErrorUnknownMsgType(t *testing.T) { func TestMConnectionTrySend(t *testing.T) { server, client := NetPipe() t.Cleanup(closeAll(t, client, server)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() mconn := createTestMConnection(log.TestingLogger(), client) - err := mconn.Start() + err := mconn.Start(ctx) require.Nil(t, err) t.Cleanup(stopAll(t, mconn)) @@ -535,7 +583,10 @@ func TestMConnectionChannelOverflow(t *testing.T) { chOnErr := make(chan struct{}) chOnRcv := make(chan struct{}) - mconnClient, mconnServer := newClientAndServerConnsForReadErrors(t, chOnErr) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mconnClient, mconnServer := newClientAndServerConnsForReadErrors(ctx, t, chOnErr) t.Cleanup(stopAll(t, mconnClient, mconnServer)) mconnServer.onReceive = func(chID ChannelID, msgBytes []byte) { diff --git a/internal/p2p/p2p_test.go b/internal/p2p/p2p_test.go index 642114a1d..d8657b774 100644 --- a/internal/p2p/p2p_test.go +++ b/internal/p2p/p2p_test.go @@ -1,8 +1,6 @@ package p2p_test import ( - "context" - "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/internal/p2p" @@ -13,7 +11,6 @@ import ( // Common setup for P2P tests. var ( - ctx = context.Background() chID = p2p.ChannelID(1) chDesc = &p2p.ChannelDescriptor{ ID: chID, diff --git a/internal/p2p/p2ptest/network.go b/internal/p2p/p2ptest/network.go index 0d92b2619..6ee253b3c 100644 --- a/internal/p2p/p2ptest/network.go +++ b/internal/p2p/p2ptest/network.go @@ -47,7 +47,7 @@ func (opts *NetworkOptions) setDefaults() { // MakeNetwork creates a test network with the given number of nodes and // connects them to each other. 
-func MakeNetwork(t *testing.T, opts NetworkOptions) *Network { +func MakeNetwork(ctx context.Context, t *testing.T, opts NetworkOptions) *Network { opts.setDefaults() logger := log.TestingLogger() network := &Network{ @@ -57,7 +57,7 @@ func MakeNetwork(t *testing.T, opts NetworkOptions) *Network { } for i := 0; i < opts.NumNodes; i++ { - node := network.MakeNode(t, opts.NodeOpts) + node := network.MakeNode(ctx, t, opts.NodeOpts) network.Nodes[node.NodeID] = node } @@ -221,7 +221,7 @@ type Node struct { // MakeNode creates a new Node configured for the network with a // running peer manager, but does not add it to the existing // network. Callers are responsible for updating peering relationships. -func (n *Network) MakeNode(t *testing.T, opts NodeOptions) *Node { +func (n *Network) MakeNode(ctx context.Context, t *testing.T, opts NodeOptions) *Node { privKey := ed25519.GenPrivKey() nodeID := types.NodeIDFromPubKey(privKey.PubKey()) nodeInfo := types.NodeInfo{ @@ -252,8 +252,9 @@ func (n *Network) MakeNode(t *testing.T, opts NodeOptions) *Node { transport.Endpoints(), p2p.RouterOptions{DialSleep: func(_ context.Context) {}}, ) + require.NoError(t, err) - require.NoError(t, router.Start()) + require.NoError(t, router.Start(ctx)) t.Cleanup(func() { if router.IsRunning() { @@ -304,12 +305,12 @@ func (n *Node) MakeChannelNoCleanup( // MakePeerUpdates opens a peer update subscription, with automatic cleanup. // It checks that all updates have been consumed during cleanup. -func (n *Node) MakePeerUpdates(t *testing.T) *p2p.PeerUpdates { +func (n *Node) MakePeerUpdates(ctx context.Context, t *testing.T) *p2p.PeerUpdates { t.Helper() sub := n.PeerManager.Subscribe() t.Cleanup(func() { t.Helper() - RequireNoUpdates(t, sub) + RequireNoUpdates(ctx, t, sub) sub.Close() }) diff --git a/internal/p2p/p2ptest/require.go b/internal/p2p/p2ptest/require.go index a9fc16a34..106063bbd 100644 --- a/internal/p2p/p2ptest/require.go +++ b/internal/p2p/p2ptest/require.go @@ -1,6 +1,7 @@ package p2ptest import ( + "context" "testing" "time" @@ -95,11 +96,14 @@ func RequireSendReceive( } // RequireNoUpdates requires that a PeerUpdates subscription is empty. 
-func RequireNoUpdates(t *testing.T, peerUpdates *p2p.PeerUpdates) { +func RequireNoUpdates(ctx context.Context, t *testing.T, peerUpdates *p2p.PeerUpdates) { t.Helper() select { case update := <-peerUpdates.Updates(): - require.Fail(t, "unexpected peer updates", "got %v", update) + if ctx.Err() == nil { + require.Fail(t, "unexpected peer updates", "got %v", update) + } + case <-ctx.Done(): default: } } diff --git a/internal/p2p/peermanager_test.go b/internal/p2p/peermanager_test.go index 69c798d2d..28efe63dd 100644 --- a/internal/p2p/peermanager_test.go +++ b/internal/p2p/peermanager_test.go @@ -273,6 +273,9 @@ func TestPeerManager_Add(t *testing.T) { } func TestPeerManager_DialNext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) @@ -296,6 +299,9 @@ func TestPeerManager_DialNext(t *testing.T) { } func TestPeerManager_DialNext_Retry(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} options := p2p.PeerManagerOptions{ @@ -311,7 +317,7 @@ func TestPeerManager_DialNext_Retry(t *testing.T) { // Do five dial retries (six dials total). The retry time should double for // each failure. At the forth retry, MaxRetryTime should kick in. - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + ctx, cancel = context.WithTimeout(ctx, 5*time.Second) defer cancel() for i := 0; i <= 5; i++ { @@ -342,6 +348,9 @@ func TestPeerManager_DialNext_Retry(t *testing.T) { } func TestPeerManager_DialNext_WakeOnAdd(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) @@ -356,7 +365,7 @@ func TestPeerManager_DialNext_WakeOnAdd(t *testing.T) { }() // This will block until peer is added above. - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel = context.WithTimeout(ctx, 3*time.Second) defer cancel() dial, err := peerManager.DialNext(ctx) require.NoError(t, err) @@ -364,6 +373,9 @@ func TestPeerManager_DialNext_WakeOnAdd(t *testing.T) { } func TestPeerManager_DialNext_WakeOnDialFailed(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{ MaxConnected: 1, }) @@ -395,7 +407,7 @@ func TestPeerManager_DialNext_WakeOnDialFailed(t *testing.T) { }() // This should make b available for dialing (not a, retries are disabled). 
- ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel = context.WithTimeout(ctx, 3*time.Second) defer cancel() dial, err = peerManager.DialNext(ctx) require.NoError(t, err) @@ -403,6 +415,9 @@ func TestPeerManager_DialNext_WakeOnDialFailed(t *testing.T) { } func TestPeerManager_DialNext_WakeOnDialFailedRetry(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + options := p2p.PeerManagerOptions{MinRetryTime: 200 * time.Millisecond} peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), options) require.NoError(t, err) @@ -421,7 +436,7 @@ func TestPeerManager_DialNext_WakeOnDialFailedRetry(t *testing.T) { // The retry timer should unblock DialNext and make a available again after // the retry time passes. - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel = context.WithTimeout(ctx, 3*time.Second) defer cancel() dial, err = peerManager.DialNext(ctx) require.NoError(t, err) @@ -430,6 +445,9 @@ func TestPeerManager_DialNext_WakeOnDialFailedRetry(t *testing.T) { } func TestPeerManager_DialNext_WakeOnDisconnected(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) @@ -450,7 +468,7 @@ func TestPeerManager_DialNext_WakeOnDisconnected(t *testing.T) { peerManager.Disconnected(a.NodeID) }() - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel = context.WithTimeout(ctx, 3*time.Second) defer cancel() dial, err = peerManager.DialNext(ctx) require.NoError(t, err) @@ -1289,6 +1307,9 @@ func TestPeerManager_Ready(t *testing.T) { // See TryEvictNext for most tests, this just tests blocking behavior. func TestPeerManager_EvictNext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) @@ -1322,6 +1343,9 @@ func TestPeerManager_EvictNext(t *testing.T) { } func TestPeerManager_EvictNext_WakeOnError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} peerManager, err := p2p.NewPeerManager(selfID, dbm.NewMemDB(), p2p.PeerManagerOptions{}) @@ -1340,7 +1364,7 @@ func TestPeerManager_EvictNext_WakeOnError(t *testing.T) { }() // This will block until peer errors above. - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel = context.WithTimeout(ctx, 3*time.Second) defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) @@ -1348,6 +1372,9 @@ func TestPeerManager_EvictNext_WakeOnError(t *testing.T) { } func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} @@ -1378,7 +1405,7 @@ func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { }() // This will block until peer is upgraded above. 
- ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel = context.WithTimeout(ctx, 3*time.Second) defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) @@ -1386,6 +1413,9 @@ func TestPeerManager_EvictNext_WakeOnUpgradeDialed(t *testing.T) { } func TestPeerManager_EvictNext_WakeOnUpgradeAccepted(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "memory", NodeID: types.NodeID(strings.Repeat("b", 40))} @@ -1410,7 +1440,7 @@ func TestPeerManager_EvictNext_WakeOnUpgradeAccepted(t *testing.T) { }() // This will block until peer is upgraded above. - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel = context.WithTimeout(ctx, 3*time.Second) defer cancel() evict, err := peerManager.EvictNext(ctx) require.NoError(t, err) diff --git a/internal/p2p/pex/reactor.go b/internal/p2p/pex/reactor.go index 9a5535a91..f6fcad5e1 100644 --- a/internal/p2p/pex/reactor.go +++ b/internal/p2p/pex/reactor.go @@ -1,6 +1,7 @@ package pex import ( + "context" "fmt" "runtime/debug" "sync" @@ -139,7 +140,7 @@ func NewReactor( // envelopes on each. In addition, it also listens for peer updates and handles // messages on that p2p channel accordingly. The caller must be sure to execute // OnStop to ensure the outbound p2p Channels are closed. -func (r *Reactor) OnStart() error { +func (r *Reactor) OnStart(ctx context.Context) error { go r.processPexCh() go r.processPeerUpdates() return nil diff --git a/internal/p2p/pex/reactor_test.go b/internal/p2p/pex/reactor_test.go index 04b347cb5..63b182fc0 100644 --- a/internal/p2p/pex/reactor_test.go +++ b/internal/p2p/pex/reactor_test.go @@ -1,6 +1,7 @@ package pex_test import ( + "context" "strings" "testing" "time" @@ -29,13 +30,15 @@ const ( ) func TestReactorBasic(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // start a network with one mock reactor and one "real" reactor - testNet := setupNetwork(t, testOptions{ + testNet := setupNetwork(ctx, t, testOptions{ MockNodes: 1, TotalNodes: 2, }) testNet.connectAll(t) - testNet.start(t) + testNet.start(ctx, t) // assert that the mock node receives a request from the real node testNet.listenForRequest(t, secondNode, firstNode, shortWait) @@ -47,14 +50,17 @@ func TestReactorBasic(t *testing.T) { } func TestReactorConnectFullNetwork(t *testing.T) { - testNet := setupNetwork(t, testOptions{ + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testNet := setupNetwork(ctx, t, testOptions{ TotalNodes: 4, }) // make every node be only connected with one other node (it actually ends up // being two because of two way connections but oh well) testNet.connectN(t, 1) - testNet.start(t) + testNet.start(ctx, t) // assert that all nodes add each other in the network for idx := 0; idx < len(testNet.nodes); idx++ { @@ -63,7 +69,10 @@ func TestReactorConnectFullNetwork(t *testing.T) { } func TestReactorSendsRequestsTooOften(t *testing.T) { - r := setupSingle(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r := setupSingle(ctx, t) badNode := newNodeID(t, "b") @@ -90,12 +99,15 @@ func TestReactorSendsRequestsTooOften(t *testing.T) { } func TestReactorSendsResponseWithoutRequest(t *testing.T) { - testNet := setupNetwork(t, testOptions{ + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testNet := 
setupNetwork(ctx, t, testOptions{ MockNodes: 1, TotalNodes: 3, }) testNet.connectAll(t) - testNet.start(t) + testNet.start(ctx, t) // firstNode sends the secondNode an unrequested response // NOTE: secondNode will send a request by default during startup so we send @@ -108,14 +120,17 @@ func TestReactorSendsResponseWithoutRequest(t *testing.T) { } func TestReactorNeverSendsTooManyPeers(t *testing.T) { - testNet := setupNetwork(t, testOptions{ + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testNet := setupNetwork(ctx, t, testOptions{ MockNodes: 1, TotalNodes: 2, }) testNet.connectAll(t) - testNet.start(t) + testNet.start(ctx, t) - testNet.addNodes(t, 110) + testNet.addNodes(ctx, t, 110) nodes := make([]int, 110) for i := 0; i < len(nodes); i++ { nodes[i] = i + 2 @@ -128,7 +143,10 @@ func TestReactorNeverSendsTooManyPeers(t *testing.T) { } func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { - r := setupSingle(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r := setupSingle(ctx, t) peer := p2p.NodeAddress{Protocol: p2p.MemoryProtocol, NodeID: randomNodeID(t)} added, err := r.manager.Add(peer) require.NoError(t, err) @@ -172,14 +190,17 @@ func TestReactorErrorsOnReceivingTooManyPeers(t *testing.T) { } func TestReactorSmallPeerStoreInALargeNetwork(t *testing.T) { - testNet := setupNetwork(t, testOptions{ + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testNet := setupNetwork(ctx, t, testOptions{ TotalNodes: 8, MaxPeers: 4, MaxConnected: 3, BufferSize: 8, }) testNet.connectN(t, 1) - testNet.start(t) + testNet.start(ctx, t) // test that all nodes reach full capacity for _, nodeID := range testNet.nodes { @@ -191,14 +212,17 @@ func TestReactorSmallPeerStoreInALargeNetwork(t *testing.T) { } func TestReactorLargePeerStoreInASmallNetwork(t *testing.T) { - testNet := setupNetwork(t, testOptions{ + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testNet := setupNetwork(ctx, t, testOptions{ TotalNodes: 3, MaxPeers: 25, MaxConnected: 25, BufferSize: 5, }) testNet.connectN(t, 1) - testNet.start(t) + testNet.start(ctx, t) // assert that all nodes add each other in the network for idx := 0; idx < len(testNet.nodes); idx++ { @@ -207,12 +231,15 @@ func TestReactorLargePeerStoreInASmallNetwork(t *testing.T) { } func TestReactorWithNetworkGrowth(t *testing.T) { - testNet := setupNetwork(t, testOptions{ + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testNet := setupNetwork(ctx, t, testOptions{ TotalNodes: 5, BufferSize: 5, }) testNet.connectAll(t) - testNet.start(t) + testNet.start(ctx, t) // assert that all nodes add each other in the network for idx := 0; idx < len(testNet.nodes); idx++ { @@ -220,10 +247,10 @@ func TestReactorWithNetworkGrowth(t *testing.T) { } // now we inject 10 more nodes - testNet.addNodes(t, 10) + testNet.addNodes(ctx, t, 10) for i := 5; i < testNet.total; i++ { node := testNet.nodes[i] - require.NoError(t, testNet.reactors[node].Start()) + require.NoError(t, testNet.reactors[node].Start(ctx)) require.True(t, testNet.reactors[node].IsRunning()) // we connect all new nodes to a single entry point and check that the // node can distribute the addresses to all the others @@ -247,7 +274,7 @@ type singleTestReactor struct { manager *p2p.PeerManager } -func setupSingle(t *testing.T) *singleTestReactor { +func setupSingle(ctx context.Context, t *testing.T) *singleTestReactor { t.Helper() nodeID := newNodeID(t, "a") chBuf := 2 @@ 
-268,14 +295,11 @@ func setupSingle(t *testing.T) *singleTestReactor { require.NoError(t, err) reactor := pex.NewReactor(log.TestingLogger(), peerManager, pexCh, peerUpdates) - require.NoError(t, reactor.Start()) + require.NoError(t, reactor.Start(ctx)) t.Cleanup(func() { - err := reactor.Stop() - if err != nil { - t.Fatal(err) - } pexCh.Close() peerUpdates.Close() + reactor.Wait() }) return &singleTestReactor{ @@ -315,7 +339,7 @@ type testOptions struct { // setup setups a test suite with a network of nodes. Mocknodes represent the // hollow nodes that the test can listen and send on -func setupNetwork(t *testing.T, opts testOptions) *reactorTestSuite { +func setupNetwork(ctx context.Context, t *testing.T, opts testOptions) *reactorTestSuite { t.Helper() require.Greater(t, opts.TotalNodes, opts.MockNodes) @@ -335,7 +359,7 @@ func setupNetwork(t *testing.T, opts testOptions) *reactorTestSuite { rts := &reactorTestSuite{ logger: log.TestingLogger().With("testCase", t.Name()), - network: p2ptest.MakeNetwork(t, networkOpts), + network: p2ptest.MakeNetwork(ctx, t, networkOpts), reactors: make(map[types.NodeID]*pex.Reactor, realNodes), pexChannels: make(map[types.NodeID]*p2p.Channel, opts.TotalNodes), peerChans: make(map[types.NodeID]chan p2p.PeerUpdate, opts.TotalNodes), @@ -375,7 +399,7 @@ func setupNetwork(t *testing.T, opts testOptions) *reactorTestSuite { t.Cleanup(func() { for nodeID, reactor := range rts.reactors { if reactor.IsRunning() { - require.NoError(t, reactor.Stop()) + reactor.Wait() require.False(t, reactor.IsRunning()) } rts.pexChannels[nodeID].Close() @@ -391,20 +415,20 @@ func setupNetwork(t *testing.T, opts testOptions) *reactorTestSuite { } // starts up the pex reactors for each node -func (r *reactorTestSuite) start(t *testing.T) { +func (r *reactorTestSuite) start(ctx context.Context, t *testing.T) { t.Helper() for _, reactor := range r.reactors { - require.NoError(t, reactor.Start()) + require.NoError(t, reactor.Start(ctx)) require.True(t, reactor.IsRunning()) } } -func (r *reactorTestSuite) addNodes(t *testing.T, nodes int) { +func (r *reactorTestSuite) addNodes(ctx context.Context, t *testing.T, nodes int) { t.Helper() for i := 0; i < nodes; i++ { - node := r.network.MakeNode(t, p2ptest.NodeOptions{ + node := r.network.MakeNode(ctx, t, p2ptest.NodeOptions{ MaxPeers: r.opts.MaxPeers, MaxConnected: r.opts.MaxConnected, }) diff --git a/internal/p2p/router.go b/internal/p2p/router.go index efd29c0d4..29646e327 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -1023,7 +1023,7 @@ func (r *Router) NodeInfo() types.NodeInfo { } // OnStart implements service.Service. -func (r *Router) OnStart() error { +func (r *Router) OnStart(ctx context.Context) error { for _, transport := range r.transports { for _, endpoint := range r.endpoints { if err := transport.Listen(endpoint); err != nil { diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 77c6f768e..c77e9e44d 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -44,10 +44,13 @@ func echoReactor(channel *p2p.Channel) { } func TestRouter_Network(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + t.Cleanup(leaktest.Check(t)) // Create a test network and open a channel where all peers run echoReactor. 
- network := p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: 8}) + network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 8}) local := network.RandomNode() peers := network.Peers(local.NodeID) channels := network.MakeChannels(t, chDesc) @@ -114,10 +117,11 @@ func TestRouter_Channel_Basic(t *testing.T) { ) require.NoError(t, err) - require.NoError(t, router.Start()) - t.Cleanup(func() { - require.NoError(t, router.Stop()) - }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + require.NoError(t, router.Start(ctx)) + t.Cleanup(router.Wait) // Opening a channel should work. channel, err := router.OpenChannel(chDesc) @@ -158,10 +162,13 @@ func TestRouter_Channel_Basic(t *testing.T) { // Channel tests are hairy to mock, so we use an in-memory network instead. func TestRouter_Channel_SendReceive(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + t.Cleanup(leaktest.Check(t)) // Create a test network and open a channel on all nodes. - network := p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: 3}) + network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 3}) ids := network.NodeIDs() aID, bID, cID := ids[0], ids[1], ids[2] @@ -219,8 +226,11 @@ func TestRouter_Channel_SendReceive(t *testing.T) { func TestRouter_Channel_Broadcast(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Create a test network and open a channel on all nodes. - network := p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: 4}) + network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 4}) ids := network.NodeIDs() aID, bID, cID, dID := ids[0], ids[1], ids[2], ids[3] @@ -247,8 +257,11 @@ func TestRouter_Channel_Broadcast(t *testing.T) { func TestRouter_Channel_Wrapper(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Create a test network and open a channel on all nodes. - network := p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: 2}) + network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 2}) ids := network.NodeIDs() aID, bID := ids[0], ids[1] @@ -314,8 +327,11 @@ func (w *wrapperMessage) Unwrap() (proto.Message, error) { func TestRouter_Channel_Error(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Create a test network and open a channel on all nodes. - network := p2ptest.MakeNetwork(t, p2ptest.NetworkOptions{NumNodes: 3}) + network := p2ptest.MakeNetwork(ctx, t, p2ptest.NetworkOptions{NumNodes: 3}) network.Start(t) ids := network.NodeIDs() @@ -324,7 +340,7 @@ func TestRouter_Channel_Error(t *testing.T) { a := channels[aID] // Erroring b should cause it to be disconnected. It will reconnect shortly after. - sub := network.Nodes[aID].MakePeerUpdates(t) + sub := network.Nodes[aID].MakePeerUpdates(ctx, t) p2ptest.RequireError(t, a, p2p.PeerError{NodeID: bID, Err: errors.New("boom")}) p2ptest.RequireUpdates(t, sub, []p2p.PeerUpdate{ {NodeID: bID, Status: p2p.PeerStatusDown}, @@ -353,9 +369,16 @@ func TestRouter_AcceptPeers(t *testing.T) { false, }, } + + bctx, bcancel := context.WithCancel(context.Background()) + defer bcancel() + for name, tc := range testcases { tc := tc t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(bctx) + defer cancel() + t.Cleanup(leaktest.Check(t)) // Set up a mock transport that handshakes. 
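The router table tests above add a second layer to the same pattern: a base context (bctx) owned by the whole table-driven test, with each t.Run subtest deriving its own cancellable child so a subtest tears down only what it started. A standalone sketch of that shape, with invented test-case names (nothing below comes from the patch itself):

package example_test

import (
	"context"
	"testing"
)

// TestPerSubtestContexts sketches the per-subtest context pattern used by the
// router tests above: one base context for the table, one cancellable child
// per subtest.
func TestPerSubtestContexts(t *testing.T) {
	bctx, bcancel := context.WithCancel(context.Background())
	defer bcancel()

	testcases := map[string]struct{ ok bool }{
		"accepts": {ok: true},
		"rejects": {ok: false},
	}

	for name, tc := range testcases {
		tc := tc
		t.Run(name, func(t *testing.T) {
			ctx, cancel := context.WithCancel(bctx)
			defer cancel() // stops anything started with ctx when this subtest ends

			_ = ctx // a real subtest would pass ctx to router.Start, setup helpers, etc.
			_ = tc
		})
	}
}

Deriving each subtest context from bctx rather than context.Background() means the single deferred bcancel still shuts down every subtest's services if the enclosing test bails out early.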
@@ -398,7 +421,7 @@ func TestRouter_AcceptPeers(t *testing.T) { p2p.RouterOptions{}, ) require.NoError(t, err) - require.NoError(t, router.Start()) + require.NoError(t, router.Start(ctx)) if tc.ok { p2ptest.RequireUpdate(t, sub, p2p.PeerUpdate{ @@ -427,6 +450,9 @@ func TestRouter_AcceptPeers(t *testing.T) { func TestRouter_AcceptPeers_Error(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Set up a mock transport that returns an error, which should prevent // the router from calling Accept again. mockTransport := &mocks.Transport{} @@ -452,7 +478,7 @@ func TestRouter_AcceptPeers_Error(t *testing.T) { ) require.NoError(t, err) - require.NoError(t, router.Start()) + require.NoError(t, router.Start(ctx)) time.Sleep(time.Second) require.NoError(t, router.Stop()) @@ -487,7 +513,10 @@ func TestRouter_AcceptPeers_ErrorEOF(t *testing.T) { ) require.NoError(t, err) - require.NoError(t, router.Start()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + require.NoError(t, router.Start(ctx)) time.Sleep(time.Second) require.NoError(t, router.Stop()) @@ -497,6 +526,9 @@ func TestRouter_AcceptPeers_ErrorEOF(t *testing.T) { func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Set up a mock transport that returns a connection that blocks during the // handshake. It should be able to accept several of these in parallel, i.e. // a single connection can't halt other connections being accepted. @@ -535,7 +567,7 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { p2p.RouterOptions{}, ) require.NoError(t, err) - require.NoError(t, router.Start()) + require.NoError(t, router.Start(ctx)) require.Eventually(t, func() bool { return len(acceptCh) == 3 @@ -574,10 +606,16 @@ func TestRouter_DialPeers(t *testing.T) { false, }, } + + bctx, bcancel := context.WithCancel(context.Background()) + defer bcancel() + for name, tc := range testcases { tc := tc t.Run(name, func(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(bctx) + defer cancel() address := p2p.NodeAddress{Protocol: "mock", NodeID: tc.dialID} endpoint := p2p.Endpoint{Protocol: "mock", Path: string(tc.dialID)} @@ -635,7 +673,7 @@ func TestRouter_DialPeers(t *testing.T) { p2p.RouterOptions{}, ) require.NoError(t, err) - require.NoError(t, router.Start()) + require.NoError(t, router.Start(ctx)) if tc.ok { p2ptest.RequireUpdate(t, sub, p2p.PeerUpdate{ @@ -664,6 +702,9 @@ func TestRouter_DialPeers(t *testing.T) { func TestRouter_DialPeers_Parallel(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + a := p2p.NodeAddress{Protocol: "mock", NodeID: types.NodeID(strings.Repeat("a", 40))} b := p2p.NodeAddress{Protocol: "mock", NodeID: types.NodeID(strings.Repeat("b", 40))} c := p2p.NodeAddress{Protocol: "mock", NodeID: types.NodeID(strings.Repeat("c", 40))} @@ -729,7 +770,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { ) require.NoError(t, err) - require.NoError(t, router.Start()) + require.NoError(t, router.Start(ctx)) require.Eventually(t, func() bool { @@ -750,6 +791,9 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { func TestRouter_EvictPeers(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Set up a mock transport that we can evict. 
closeCh := make(chan time.Time) closeOnce := sync.Once{} @@ -792,7 +836,7 @@ func TestRouter_EvictPeers(t *testing.T) { p2p.RouterOptions{}, ) require.NoError(t, err) - require.NoError(t, router.Start()) + require.NoError(t, router.Start(ctx)) // Wait for the mock peer to connect, then evict it by reporting an error. p2ptest.RequireUpdate(t, sub, p2p.PeerUpdate{ @@ -815,6 +859,8 @@ func TestRouter_EvictPeers(t *testing.T) { func TestRouter_ChannelCompatability(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() incompatiblePeer := types.NodeInfo{ NodeID: peerID, @@ -854,7 +900,7 @@ func TestRouter_ChannelCompatability(t *testing.T) { p2p.RouterOptions{}, ) require.NoError(t, err) - require.NoError(t, router.Start()) + require.NoError(t, router.Start(ctx)) time.Sleep(1 * time.Second) require.NoError(t, router.Stop()) require.Empty(t, peerManager.Peers()) @@ -865,6 +911,8 @@ func TestRouter_ChannelCompatability(t *testing.T) { func TestRouter_DontSendOnInvalidChannel(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() peer := types.NodeInfo{ NodeID: peerID, @@ -909,7 +957,7 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { p2p.RouterOptions{}, ) require.NoError(t, err) - require.NoError(t, router.Start()) + require.NoError(t, router.Start(ctx)) p2ptest.RequireUpdate(t, sub, p2p.PeerUpdate{ NodeID: peerInfo.NodeID, diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index 736f5360a..0580ce1bf 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -291,7 +291,7 @@ func (c *mConnConnection) Handshake( } c.mconn = mconn c.logger = mconn.Logger - if err = c.mconn.Start(); err != nil { + if err = c.mconn.Start(ctx); err != nil { return types.NodeInfo{}, nil, err } return peerInfo, peerKey, nil diff --git a/internal/p2p/transport_mconn_test.go b/internal/p2p/transport_mconn_test.go index d33438109..4d9a945cb 100644 --- a/internal/p2p/transport_mconn_test.go +++ b/internal/p2p/transport_mconn_test.go @@ -1,6 +1,7 @@ package p2p_test import ( + "context" "io" "net" "testing" @@ -58,6 +59,9 @@ func TestMConnTransport_AcceptBeforeListen(t *testing.T) { } func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + transport := p2p.NewMConnTransport( log.TestingLogger(), conn.DefaultMConnConfig(), @@ -124,6 +128,9 @@ func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { } func TestMConnTransport_Listen(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + testcases := []struct { endpoint p2p.Endpoint ok bool @@ -145,6 +152,9 @@ func TestMConnTransport_Listen(t *testing.T) { t.Run(tc.endpoint.String(), func(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel = context.WithCancel(ctx) + defer cancel() + transport := p2p.NewMConnTransport( log.TestingLogger(), conn.DefaultMConnConfig(), @@ -185,6 +195,9 @@ func TestMConnTransport_Listen(t *testing.T) { go func() { // Dialing the endpoint should work. 
var err error + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + peerConn, err = transport.Dial(ctx, endpoint) require.NoError(t, err) close(dialedChan) diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index cdfb57c70..a53be251d 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -25,20 +25,26 @@ var testTransports = map[string]transportFactory{} // withTransports is a test helper that runs a test against all transports // registered in testTransports. -func withTransports(t *testing.T, tester func(*testing.T, transportFactory)) { +func withTransports(ctx context.Context, t *testing.T, tester func(context.Context, *testing.T, transportFactory)) { t.Helper() for name, transportFactory := range testTransports { transportFactory := transportFactory t.Run(name, func(t *testing.T) { t.Cleanup(leaktest.Check(t)) - tester(t, transportFactory) + tctx, cancel := context.WithCancel(ctx) + defer cancel() + + tester(tctx, t, transportFactory) }) } } func TestTransport_AcceptClose(t *testing.T) { // Just test accept unblock on close, happy path is tested widely elsewhere. - withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) // In-progress Accept should error on concurrent close. @@ -75,7 +81,10 @@ func TestTransport_DialEndpoints(t *testing.T) { {[]byte{1, 2, 3, 4, 5}, false}, } - withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) endpoints := a.Endpoints() require.NotEmpty(t, endpoints) @@ -149,8 +158,11 @@ func TestTransport_DialEndpoints(t *testing.T) { } func TestTransport_Dial(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Most just tests dial failures, happy path is tested widely elsewhere. 
- withTransports(t, func(t *testing.T, makeTransport transportFactory) { + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) b := makeTransport(t) @@ -190,7 +202,10 @@ func TestTransport_Dial(t *testing.T) { } func TestTransport_Endpoints(t *testing.T) { - withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) b := makeTransport(t) @@ -214,7 +229,10 @@ func TestTransport_Endpoints(t *testing.T) { } func TestTransport_Protocols(t *testing.T) { - withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) protocols := a.Protocols() endpoints := a.Endpoints() @@ -228,17 +246,23 @@ func TestTransport_Protocols(t *testing.T) { } func TestTransport_String(t *testing.T) { - withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) require.NotEmpty(t, a.String()) }) } func TestConnection_Handshake(t *testing.T) { - withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) b := makeTransport(t) - ab, ba := dialAccept(t, a, b) + ab, ba := dialAccept(ctx, t, a, b) // A handshake should pass the given keys and NodeInfo. aKey := ed25519.GenPrivKey() @@ -270,7 +294,10 @@ func TestConnection_Handshake(t *testing.T) { assert.Equal(t, aInfo, peerInfo) assert.Equal(t, aKey.PubKey(), peerKey) } - errCh <- err + select { + case errCh <- err: + case <-ctx.Done(): + } }() peerInfo, peerKey, err := ab.Handshake(ctx, aInfo, aKey) @@ -283,12 +310,15 @@ func TestConnection_Handshake(t *testing.T) { } func TestConnection_HandshakeCancel(t *testing.T) { - withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) b := makeTransport(t) // Handshake should error on context cancellation. - ab, ba := dialAccept(t, a, b) + ab, ba := dialAccept(ctx, t, a, b) timeoutCtx, cancel := context.WithTimeout(ctx, 1*time.Minute) cancel() _, _, err := ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey()) @@ -298,7 +328,7 @@ func TestConnection_HandshakeCancel(t *testing.T) { _ = ba.Close() // Handshake should error on context timeout. 
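Editor's sketch: TestConnection_Handshake now hands the goroutine's error back through a select that also watches the test context, so a canceled test cannot strand the sender. A standalone sketch of that shape (reportErr is an illustrative name):

package example

import "context"

// reportErr delivers a goroutine's error to the listener unless the test
// context has already been canceled; an abandoned receiver therefore cannot
// leak the sending goroutine.
func reportErr(ctx context.Context, errCh chan<- error, err error) {
	select {
	case errCh <- err:
	case <-ctx.Done():
	}
}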
- ab, ba = dialAccept(t, a, b) + ab, ba = dialAccept(ctx, t, a, b) timeoutCtx, cancel = context.WithTimeout(ctx, 200*time.Millisecond) defer cancel() _, _, err = ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey()) @@ -310,10 +340,13 @@ func TestConnection_HandshakeCancel(t *testing.T) { } func TestConnection_FlushClose(t *testing.T) { - withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) b := makeTransport(t) - ab, _ := dialAcceptHandshake(t, a, b) + ab, _ := dialAcceptHandshake(ctx, t, a, b) err := ab.Close() require.NoError(t, err) @@ -329,10 +362,13 @@ func TestConnection_FlushClose(t *testing.T) { } func TestConnection_LocalRemoteEndpoint(t *testing.T) { - withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) b := makeTransport(t) - ab, ba := dialAcceptHandshake(t, a, b) + ab, ba := dialAcceptHandshake(ctx, t, a, b) // Local and remote connection endpoints correspond to each other. require.NotEmpty(t, ab.LocalEndpoint()) @@ -343,10 +379,13 @@ func TestConnection_LocalRemoteEndpoint(t *testing.T) { } func TestConnection_SendReceive(t *testing.T) { - withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) b := makeTransport(t) - ab, ba := dialAcceptHandshake(t, a, b) + ab, ba := dialAcceptHandshake(ctx, t, a, b) // Can send and receive a to b. err := ab.SendMessage(chID, []byte("foo")) @@ -402,10 +441,13 @@ func TestConnection_SendReceive(t *testing.T) { } func TestConnection_String(t *testing.T) { - withTransports(t, func(t *testing.T, makeTransport transportFactory) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) { a := makeTransport(t) b := makeTransport(t) - ab, _ := dialAccept(t, a, b) + ab, _ := dialAccept(ctx, t, a, b) require.NotEmpty(t, ab.String()) }) } @@ -552,7 +594,7 @@ func TestEndpoint_Validate(t *testing.T) { // dialAccept is a helper that dials b from a and returns both sides of the // connection. -func dialAccept(t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) { +func dialAccept(ctx context.Context, t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) { t.Helper() endpoints := b.Endpoints() @@ -585,13 +627,10 @@ func dialAccept(t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connectio // dialAcceptHandshake is a helper that dials and handshakes b from a and // returns both sides of the connection. 
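Editor's sketch: the handshake-cancel cases now derive their timeout contexts from the test's context instead of context.Background(), so canceling the test also aborts the in-flight operation. A small illustration under that assumption (withDeadline and op are stand-ins, not the real API):

package example

import (
	"context"
	"time"
)

// withDeadline derives a timeout from the caller's context; op stands in
// for something like Connection.Handshake.
func withDeadline(ctx context.Context, d time.Duration, op func(context.Context) error) error {
	tctx, cancel := context.WithTimeout(ctx, d)
	defer cancel()
	return op(tctx)
}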
-func dialAcceptHandshake(t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) { +func dialAcceptHandshake(ctx context.Context, t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) { t.Helper() - ab, ba := dialAccept(t, a, b) - - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() + ab, ba := dialAccept(ctx, t, a, b) errCh := make(chan error, 1) go func() { diff --git a/internal/p2p/trust/metric.go b/internal/p2p/trust/metric.go index aa0ff5298..8ad814f10 100644 --- a/internal/p2p/trust/metric.go +++ b/internal/p2p/trust/metric.go @@ -4,6 +4,7 @@ package trust import ( + "context" "math" "time" @@ -109,8 +110,8 @@ func NewMetricWithConfig(tmc MetricConfig) *Metric { } // OnStart implements Service -func (tm *Metric) OnStart() error { - if err := tm.BaseService.OnStart(); err != nil { +func (tm *Metric) OnStart(ctx context.Context) error { + if err := tm.BaseService.OnStart(ctx); err != nil { return err } go tm.processRequests() diff --git a/internal/p2p/trust/metric_test.go b/internal/p2p/trust/metric_test.go index 65caf38a2..b7d19da23 100644 --- a/internal/p2p/trust/metric_test.go +++ b/internal/p2p/trust/metric_test.go @@ -1,6 +1,7 @@ package trust import ( + "context" "testing" "time" @@ -9,8 +10,11 @@ import ( ) func TestTrustMetricScores(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tm := NewMetric() - err := tm.Start() + err := tm.Start(ctx) require.NoError(t, err) // Perfect score @@ -27,6 +31,9 @@ func TestTrustMetricScores(t *testing.T) { } func TestTrustMetricConfig(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // 7 days window := time.Minute * 60 * 24 * 7 config := MetricConfig{ @@ -35,7 +42,7 @@ func TestTrustMetricConfig(t *testing.T) { } tm := NewMetricWithConfig(config) - err := tm.Start() + err := tm.Start(ctx) require.NoError(t, err) // The max time intervals should be the TrackingWindow / IntervalLen @@ -52,7 +59,7 @@ func TestTrustMetricConfig(t *testing.T) { config.ProportionalWeight = 0.3 config.IntegralWeight = 0.7 tm = NewMetricWithConfig(config) - err = tm.Start() + err = tm.Start(ctx) require.NoError(t, err) // These weights should be equal to our custom values @@ -74,12 +81,15 @@ func TestTrustMetricCopyNilPointer(t *testing.T) { // XXX: This test fails non-deterministically //nolint:unused,deadcode func _TestTrustMetricStopPause(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // The TestTicker will provide manual control over // the passing of time within the metric tt := NewTestTicker() tm := NewMetric() tm.SetTicker(tt) - err := tm.Start() + err := tm.Start(ctx) require.NoError(t, err) // Allow some time intervals to pass and pause tt.NextTick() diff --git a/internal/p2p/trust/store.go b/internal/p2p/trust/store.go index 158354e14..a01cbab2e 100644 --- a/internal/p2p/trust/store.go +++ b/internal/p2p/trust/store.go @@ -4,6 +4,7 @@ package trust import ( + "context" "encoding/json" "fmt" "time" @@ -51,15 +52,15 @@ func NewTrustMetricStore(db dbm.DB, tmc MetricConfig, logger log.Logger) *Metric } // OnStart implements Service -func (tms *MetricStore) OnStart() error { - if err := tms.BaseService.OnStart(); err != nil { +func (tms *MetricStore) OnStart(ctx context.Context) error { + if err := tms.BaseService.OnStart(ctx); err != nil { return err } tms.mtx.Lock() defer tms.mtx.Unlock() - tms.loadFromDB() + tms.loadFromDB(ctx) go tms.saveRoutine() return nil } @@ -103,7 +104,7 @@ func 
(tms *MetricStore) AddPeerTrustMetric(key string, tm *Metric) { } // GetPeerTrustMetric returns a trust metric by peer key -func (tms *MetricStore) GetPeerTrustMetric(key string) *Metric { +func (tms *MetricStore) GetPeerTrustMetric(ctx context.Context, key string) *Metric { tms.mtx.Lock() defer tms.mtx.Unlock() @@ -111,7 +112,7 @@ func (tms *MetricStore) GetPeerTrustMetric(key string) *Metric { if !ok { // If the metric is not available, we will create it tm = NewMetricWithConfig(tms.config) - if err := tm.Start(); err != nil { + if err := tm.Start(ctx); err != nil { tms.Logger.Error("unable to start metric store", "error", err) } // The metric needs to be in the map @@ -152,7 +153,7 @@ func (tms *MetricStore) size() int { // Loads the history data for all peers from the store DB // cmn.Panics if file is corrupt -func (tms *MetricStore) loadFromDB() bool { +func (tms *MetricStore) loadFromDB(ctx context.Context) bool { // Obtain the history data we have so far bytes, err := tms.db.Get(trustMetricKey) if err != nil { @@ -173,7 +174,7 @@ func (tms *MetricStore) loadFromDB() bool { for key, p := range peers { tm := NewMetricWithConfig(tms.config) - if err := tm.Start(); err != nil { + if err := tm.Start(ctx); err != nil { tms.Logger.Error("unable to start metric", "error", err) } tm.Init(p) diff --git a/internal/p2p/trust/store_test.go b/internal/p2p/trust/store_test.go index d1420b1dc..a6178459f 100644 --- a/internal/p2p/trust/store_test.go +++ b/internal/p2p/trust/store_test.go @@ -4,6 +4,7 @@ package trust import ( + "context" "fmt" "testing" @@ -15,6 +16,9 @@ import ( ) func TestTrustMetricStoreSaveLoad(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dir := t.TempDir() logger := log.TestingLogger() @@ -26,7 +30,7 @@ func TestTrustMetricStoreSaveLoad(t *testing.T) { store.saveToDB() // Load the data from the file store = NewTrustMetricStore(historyDB, DefaultConfig(), logger) - err = store.Start() + err = store.Start(ctx) require.NoError(t, err) // Make sure we still have 0 entries assert.Zero(t, store.Size()) @@ -44,7 +48,7 @@ func TestTrustMetricStoreSaveLoad(t *testing.T) { tm := NewMetric() tm.SetTicker(tt[i]) - err = tm.Start() + err = tm.Start(ctx) require.NoError(t, err) store.AddPeerTrustMetric(key, tm) @@ -65,7 +69,7 @@ func TestTrustMetricStoreSaveLoad(t *testing.T) { // Load the data from the DB store = NewTrustMetricStore(historyDB, DefaultConfig(), logger) - err = store.Start() + err = store.Start(ctx) require.NoError(t, err) // Check that we still have 100 peers with imperfect trust values @@ -79,6 +83,9 @@ func TestTrustMetricStoreSaveLoad(t *testing.T) { } func TestTrustMetricStoreConfig(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + historyDB, err := dbm.NewDB("", "memdb", "") require.NoError(t, err) @@ -91,11 +98,11 @@ func TestTrustMetricStoreConfig(t *testing.T) { // Create a store with custom config store := NewTrustMetricStore(historyDB, config, logger) - err = store.Start() + err = store.Start(ctx) require.NoError(t, err) // Have the store make us a metric with the config - tm := store.GetPeerTrustMetric("TestKey") + tm := store.GetPeerTrustMetric(ctx, "TestKey") // Check that the options made it to the metric assert.Equal(t, 0.5, tm.proportionalWeight) @@ -105,18 +112,21 @@ func TestTrustMetricStoreConfig(t *testing.T) { } func TestTrustMetricStoreLookup(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + historyDB, err := 
dbm.NewDB("", "memdb", "") require.NoError(t, err) store := NewTrustMetricStore(historyDB, DefaultConfig(), log.TestingLogger()) - err = store.Start() + err = store.Start(ctx) require.NoError(t, err) // Create 100 peers in the trust metric store for i := 0; i < 100; i++ { key := fmt.Sprintf("peer_%d", i) - store.GetPeerTrustMetric(key) + store.GetPeerTrustMetric(ctx, key) // Check that the trust metric was successfully entered ktm := store.peerMetrics[key] @@ -128,16 +138,19 @@ func TestTrustMetricStoreLookup(t *testing.T) { } func TestTrustMetricStorePeerScore(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + historyDB, err := dbm.NewDB("", "memdb", "") require.NoError(t, err) store := NewTrustMetricStore(historyDB, DefaultConfig(), log.TestingLogger()) - err = store.Start() + err = store.Start(ctx) require.NoError(t, err) key := "TestKey" - tm := store.GetPeerTrustMetric(key) + tm := store.GetPeerTrustMetric(ctx, key) // This peer is innocent so far first := tm.TrustScore() @@ -156,7 +169,7 @@ func TestTrustMetricStorePeerScore(t *testing.T) { store.PeerDisconnected(key) // We will remember our experiences with this peer - tm = store.GetPeerTrustMetric(key) + tm = store.GetPeerTrustMetric(ctx, key) assert.NotEqual(t, 100, tm.TrustScore()) err = store.Stop() require.NoError(t, err) diff --git a/internal/proxy/app_conn_test.go b/internal/proxy/app_conn_test.go index 4b4abe607..5eb810bd6 100644 --- a/internal/proxy/app_conn_test.go +++ b/internal/proxy/app_conn_test.go @@ -51,16 +51,15 @@ func TestEcho(t *testing.T) { logger := log.TestingLogger() clientCreator := abciclient.NewRemoteCreator(logger, sockPath, SOCKET, true) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Start server s := server.NewSocketServer(logger.With("module", "abci-server"), sockPath, kvstore.NewApplication()) - if err := s.Start(); err != nil { + if err := s.Start(ctx); err != nil { t.Fatalf("Error starting socket server: %v", err.Error()) } - t.Cleanup(func() { - if err := s.Stop(); err != nil { - t.Error(err) - } - }) + t.Cleanup(func() { cancel(); s.Wait() }) // Start client cli, err := clientCreator(logger.With("module", "abci-client")) @@ -68,14 +67,13 @@ func TestEcho(t *testing.T) { t.Fatalf("Error creating ABCI client: %v", err.Error()) } - if err := cli.Start(); err != nil { + if err := cli.Start(ctx); err != nil { t.Fatalf("Error starting ABCI client: %v", err.Error()) } proxy := newAppConnTest(cli) t.Log("Connected") - ctx := context.Background() for i := 0; i < 1000; i++ { _, err = proxy.EchoAsync(ctx, fmt.Sprintf("echo-%v", i)) if err != nil { @@ -99,16 +97,15 @@ func BenchmarkEcho(b *testing.B) { logger := log.TestingLogger() clientCreator := abciclient.NewRemoteCreator(logger, sockPath, SOCKET, true) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Start server s := server.NewSocketServer(logger.With("module", "abci-server"), sockPath, kvstore.NewApplication()) - if err := s.Start(); err != nil { + if err := s.Start(ctx); err != nil { b.Fatalf("Error starting socket server: %v", err.Error()) } - b.Cleanup(func() { - if err := s.Stop(); err != nil { - b.Error(err) - } - }) + b.Cleanup(func() { cancel(); s.Wait() }) // Start client cli, err := clientCreator(logger.With("module", "abci-client")) @@ -116,7 +113,7 @@ func BenchmarkEcho(b *testing.B) { b.Fatalf("Error creating ABCI client: %v", err.Error()) } - if err := cli.Start(); err != nil { + if err := cli.Start(ctx); err != nil { b.Fatalf("Error 
starting ABCI client: %v", err.Error()) } @@ -125,7 +122,6 @@ func BenchmarkEcho(b *testing.B) { echoString := strings.Repeat(" ", 200) b.StartTimer() // Start benchmarking tests - ctx := context.Background() for i := 0; i < b.N; i++ { _, err = proxy.EchoAsync(ctx, echoString) if err != nil { @@ -152,16 +148,15 @@ func TestInfo(t *testing.T) { logger := log.TestingLogger() clientCreator := abciclient.NewRemoteCreator(logger, sockPath, SOCKET, true) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Start server s := server.NewSocketServer(logger.With("module", "abci-server"), sockPath, kvstore.NewApplication()) - if err := s.Start(); err != nil { + if err := s.Start(ctx); err != nil { t.Fatalf("Error starting socket server: %v", err.Error()) } - t.Cleanup(func() { - if err := s.Stop(); err != nil { - t.Error(err) - } - }) + t.Cleanup(func() { cancel(); s.Wait() }) // Start client cli, err := clientCreator(logger.With("module", "abci-client")) @@ -169,7 +164,7 @@ func TestInfo(t *testing.T) { t.Fatalf("Error creating ABCI client: %v", err.Error()) } - if err := cli.Start(); err != nil { + if err := cli.Start(ctx); err != nil { t.Fatalf("Error starting ABCI client: %v", err.Error()) } diff --git a/internal/proxy/multi_app_conn.go b/internal/proxy/multi_app_conn.go index 62862d66e..5e5920f24 100644 --- a/internal/proxy/multi_app_conn.go +++ b/internal/proxy/multi_app_conn.go @@ -1,6 +1,8 @@ package proxy import ( + "context" + "errors" "fmt" "os" "syscall" @@ -51,14 +53,22 @@ type multiAppConn struct { queryConn AppConnQuery snapshotConn AppConnSnapshot - consensusConnClient abciclient.Client - mempoolConnClient abciclient.Client - queryConnClient abciclient.Client - snapshotConnClient abciclient.Client + consensusConnClient stoppableClient + mempoolConnClient stoppableClient + queryConnClient stoppableClient + snapshotConnClient stoppableClient clientCreator abciclient.Creator } +// TODO: this is a totally internal and quasi permanent shim for +// clients. eventually we can have a single client and have some kind +// of reasonable lifecycle witout needing an explicit stop method. +type stoppableClient interface { + abciclient.Client + Stop() error +} + // NewMultiAppConn makes all necessary abci connections to the application. 
func NewMultiAppConn(clientCreator abciclient.Creator, logger log.Logger, metrics *Metrics) AppConns { multiAppConn := &multiAppConn{ @@ -85,36 +95,36 @@ func (app *multiAppConn) Snapshot() AppConnSnapshot { return app.snapshotConn } -func (app *multiAppConn) OnStart() error { - c, err := app.abciClientFor(connQuery) +func (app *multiAppConn) OnStart(ctx context.Context) error { + c, err := app.abciClientFor(ctx, connQuery) if err != nil { return err } - app.queryConnClient = c + app.queryConnClient = c.(stoppableClient) app.queryConn = NewAppConnQuery(c, app.metrics) - c, err = app.abciClientFor(connSnapshot) + c, err = app.abciClientFor(ctx, connSnapshot) if err != nil { app.stopAllClients() return err } - app.snapshotConnClient = c + app.snapshotConnClient = c.(stoppableClient) app.snapshotConn = NewAppConnSnapshot(c, app.metrics) - c, err = app.abciClientFor(connMempool) + c, err = app.abciClientFor(ctx, connMempool) if err != nil { app.stopAllClients() return err } - app.mempoolConnClient = c + app.mempoolConnClient = c.(stoppableClient) app.mempoolConn = NewAppConnMempool(c, app.metrics) - c, err = app.abciClientFor(connConsensus) + c, err = app.abciClientFor(ctx, connConsensus) if err != nil { app.stopAllClients() return err } - app.consensusConnClient = c + app.consensusConnClient = c.(stoppableClient) app.consensusConn = NewAppConnConsensus(c, app.metrics) // Kill Tendermint if the ABCI application crashes. @@ -160,34 +170,42 @@ func (app *multiAppConn) killTMOnClientError() { func (app *multiAppConn) stopAllClients() { if app.consensusConnClient != nil { if err := app.consensusConnClient.Stop(); err != nil { - app.Logger.Error("error while stopping consensus client", "error", err) + if !errors.Is(err, service.ErrAlreadyStopped) { + app.Logger.Error("error while stopping consensus client", "error", err) + } } } if app.mempoolConnClient != nil { if err := app.mempoolConnClient.Stop(); err != nil { - app.Logger.Error("error while stopping mempool client", "error", err) + if !errors.Is(err, service.ErrAlreadyStopped) { + app.Logger.Error("error while stopping mempool client", "error", err) + } } } if app.queryConnClient != nil { if err := app.queryConnClient.Stop(); err != nil { - app.Logger.Error("error while stopping query client", "error", err) + if !errors.Is(err, service.ErrAlreadyStopped) { + app.Logger.Error("error while stopping query client", "error", err) + } } } if app.snapshotConnClient != nil { if err := app.snapshotConnClient.Stop(); err != nil { - app.Logger.Error("error while stopping snapshot client", "error", err) + if !errors.Is(err, service.ErrAlreadyStopped) { + app.Logger.Error("error while stopping snapshot client", "error", err) + } } } } -func (app *multiAppConn) abciClientFor(conn string) (abciclient.Client, error) { +func (app *multiAppConn) abciClientFor(ctx context.Context, conn string) (abciclient.Client, error) { c, err := app.clientCreator(app.Logger.With( "module", "abci-client", "connection", conn)) if err != nil { return nil, fmt.Errorf("error creating ABCI client (%s connection): %w", conn, err) } - if err := c.Start(); err != nil { + if err := c.Start(ctx); err != nil { return nil, fmt.Errorf("error starting ABCI client (%s connection): %w", conn, err) } return c, nil diff --git a/internal/proxy/multi_app_conn_test.go b/internal/proxy/multi_app_conn_test.go index af9c30091..9ad39cb3b 100644 --- a/internal/proxy/multi_app_conn_test.go +++ b/internal/proxy/multi_app_conn_test.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "errors" "os" 
"os/signal" @@ -17,31 +18,42 @@ import ( "github.com/tendermint/tendermint/libs/log" ) +type noopStoppableClientImpl struct { + abciclient.Client + count int +} + +func (c *noopStoppableClientImpl) Stop() error { c.count++; return nil } + func TestAppConns_Start_Stop(t *testing.T) { quitCh := make(<-chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + clientMock := &abcimocks.Client{} - clientMock.On("Start").Return(nil).Times(4) - clientMock.On("Stop").Return(nil).Times(4) + clientMock.On("Start", mock.Anything).Return(nil).Times(4) clientMock.On("Quit").Return(quitCh).Times(4) + cl := &noopStoppableClientImpl{Client: clientMock} creatorCallCount := 0 creator := func(logger log.Logger) (abciclient.Client, error) { creatorCallCount++ - return clientMock, nil + return cl, nil } appConns := NewAppConns(creator, log.TestingLogger(), NopMetrics()) - err := appConns.Start() + err := appConns.Start(ctx) require.NoError(t, err) time.Sleep(100 * time.Millisecond) - err = appConns.Stop() - require.NoError(t, err) + cancel() + appConns.Wait() clientMock.AssertExpectations(t) + assert.Equal(t, 4, cl.count) assert.Equal(t, 4, creatorCallCount) } @@ -56,31 +68,30 @@ func TestAppConns_Failure(t *testing.T) { } }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + quitCh := make(chan struct{}) var recvQuitCh <-chan struct{} // nolint:gosimple recvQuitCh = quitCh clientMock := &abcimocks.Client{} clientMock.On("SetLogger", mock.Anything).Return() - clientMock.On("Start").Return(nil) - clientMock.On("Stop").Return(nil) + clientMock.On("Start", mock.Anything).Return(nil) clientMock.On("Quit").Return(recvQuitCh) clientMock.On("Error").Return(errors.New("EOF")).Once() + cl := &noopStoppableClientImpl{Client: clientMock} creator := func(log.Logger) (abciclient.Client, error) { - return clientMock, nil + return cl, nil } appConns := NewAppConns(creator, log.TestingLogger(), NopMetrics()) - err := appConns.Start() + err := appConns.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := appConns.Stop(); err != nil { - t.Error(err) - } - }) + t.Cleanup(func() { cancel(); appConns.Wait() }) // simulate failure close(quitCh) diff --git a/internal/state/execution.go b/internal/state/execution.go index dc64e6e3d..e4a1ba6c3 100644 --- a/internal/state/execution.go +++ b/internal/state/execution.go @@ -332,7 +332,7 @@ func execBlockOnProxyApp( byzVals = append(byzVals, evidence.ABCI()...) } - ctx := context.Background() + ctx := context.TODO() // Begin block var err error diff --git a/internal/state/execution_test.go b/internal/state/execution_test.go index 80b1b6e58..5da5adbb5 100644 --- a/internal/state/execution_test.go +++ b/internal/state/execution_test.go @@ -40,9 +40,12 @@ func TestApplyBlock(t *testing.T) { cc := abciclient.NewLocalCreator(app) logger := log.TestingLogger() proxyApp := proxy.NewAppConns(cc, logger, proxy.NopMetrics()) - err := proxyApp.Start() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := proxyApp.Start(ctx) require.Nil(t, err) - defer proxyApp.Stop() //nolint:errcheck // ignore for tests state, stateDB, _ := makeState(1, 1) stateStore := sm.NewStore(stateDB) @@ -62,12 +65,15 @@ func TestApplyBlock(t *testing.T) { // TestBeginBlockValidators ensures we send absent validators list. 
func TestBeginBlockValidators(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + app := &testApp{} cc := abciclient.NewLocalCreator(app) proxyApp := proxy.NewAppConns(cc, log.TestingLogger(), proxy.NopMetrics()) - err := proxyApp.Start() + + err := proxyApp.Start(ctx) require.Nil(t, err) - defer proxyApp.Stop() //nolint:errcheck // no need to check error again state, stateDB, _ := makeState(2, 2) stateStore := sm.NewStore(stateDB) @@ -125,12 +131,14 @@ func TestBeginBlockValidators(t *testing.T) { // TestBeginBlockByzantineValidators ensures we send byzantine validators list. func TestBeginBlockByzantineValidators(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + app := &testApp{} cc := abciclient.NewLocalCreator(app) proxyApp := proxy.NewAppConns(cc, log.TestingLogger(), proxy.NopMetrics()) - err := proxyApp.Start() + err := proxyApp.Start(ctx) require.Nil(t, err) - defer proxyApp.Stop() //nolint:errcheck // ignore for tests state, stateDB, privVals := makeState(1, 1) stateStore := sm.NewStore(stateDB) @@ -350,13 +358,15 @@ func TestUpdateValidators(t *testing.T) { // TestEndBlockValidatorUpdates ensures we update validator set and send an event. func TestEndBlockValidatorUpdates(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + app := &testApp{} cc := abciclient.NewLocalCreator(app) logger := log.TestingLogger() proxyApp := proxy.NewAppConns(cc, logger, proxy.NopMetrics()) - err := proxyApp.Start() + err := proxyApp.Start(ctx) require.Nil(t, err) - defer proxyApp.Stop() //nolint:errcheck // ignore for tests state, stateDB, _ := makeState(1, 1) stateStore := sm.NewStore(stateDB) @@ -372,7 +382,7 @@ func TestEndBlockValidatorUpdates(t *testing.T) { ) eventBus := eventbus.NewDefault(logger) - err = eventBus.Start() + err = eventBus.Start(ctx) require.NoError(t, err) defer eventBus.Stop() //nolint:errcheck // ignore for tests @@ -405,7 +415,7 @@ func TestEndBlockValidatorUpdates(t *testing.T) { } // test we threw an event - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel = context.WithTimeout(ctx, 1*time.Second) defer cancel() msg, err := updatesSub.Next(ctx) require.NoError(t, err) @@ -420,13 +430,15 @@ func TestEndBlockValidatorUpdates(t *testing.T) { // TestEndBlockValidatorUpdatesResultingInEmptySet checks that processing validator updates that // would result in empty set causes no panic, an error is raised and NextValidators is not updated func TestEndBlockValidatorUpdatesResultingInEmptySet(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + app := &testApp{} cc := abciclient.NewLocalCreator(app) logger := log.TestingLogger() proxyApp := proxy.NewAppConns(cc, logger, proxy.NopMetrics()) - err := proxyApp.Start() + err := proxyApp.Start(ctx) require.Nil(t, err) - defer proxyApp.Stop() //nolint:errcheck // ignore for tests state, stateDB, _ := makeState(1, 1) stateStore := sm.NewStore(stateDB) diff --git a/internal/state/indexer/indexer_service.go b/internal/state/indexer/indexer_service.go index d5c230e81..80c4adf02 100644 --- a/internal/state/indexer/indexer_service.go +++ b/internal/state/indexer/indexer_service.go @@ -116,7 +116,7 @@ func (is *Service) publish(msg pubsub.Message) error { // indexer if the underlying event sinks support indexing. // // TODO(creachadair): Can we get rid of the "enabled" check? 
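Editor's sketch: the recurring shape in this change is OnStart gaining a context, with background loops keying off ctx.Done() instead of an explicit Stop signal. A minimal illustration (worker is hypothetical and not the service.BaseService contract):

package example

import "context"

// worker sketches the post-change OnStart shape: spawn the loop, and let
// the caller's context end it.
type worker struct {
	jobs chan func()
}

func (w *worker) OnStart(ctx context.Context) error {
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case job := <-w.jobs:
				job() // handle one unit of work
			}
		}
	}()
	return nil
}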
-func (is *Service) OnStart() error { +func (is *Service) OnStart(ctx context.Context) error { // If the event sinks support indexing, register an observer to capture // block header data for the indexer. if IndexingEnabled(is.eventSinks) { diff --git a/internal/state/indexer/indexer_service_test.go b/internal/state/indexer/indexer_service_test.go index a986530f0..879cf8820 100644 --- a/internal/state/indexer/indexer_service_test.go +++ b/internal/state/indexer/indexer_service_test.go @@ -1,6 +1,7 @@ package indexer_test import ( + "context" "database/sql" "fmt" "os" @@ -47,16 +48,15 @@ func NewIndexerService(es []indexer.EventSink, eventBus *eventbus.EventBus) *ind } func TestIndexerServiceIndexesBlocks(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := tmlog.TestingLogger() // event bus eventBus := eventbus.NewDefault(logger) - err := eventBus.Start() + err := eventBus.Start(ctx) require.NoError(t, err) - t.Cleanup(func() { - if err := eventBus.Stop(); err != nil { - t.Error(err) - } - }) + t.Cleanup(eventBus.Wait) assert.False(t, indexer.KVSinkEnabled([]indexer.EventSink{})) assert.False(t, indexer.IndexingEnabled([]indexer.EventSink{})) @@ -71,13 +71,8 @@ func TestIndexerServiceIndexesBlocks(t *testing.T) { assert.True(t, indexer.IndexingEnabled(eventSinks)) service := NewIndexerService(eventSinks, eventBus) - err = service.Start() - require.NoError(t, err) - t.Cleanup(func() { - if err := service.Stop(); err != nil { - t.Error(err) - } - }) + require.NoError(t, service.Start(ctx)) + t.Cleanup(service.Wait) // publish block with txs err = eventBus.PublishEventNewBlockHeader(types.EventDataNewBlockHeader{ diff --git a/internal/state/indexer/sink/psql/psql_test.go b/internal/state/indexer/sink/psql/psql_test.go index f5306801f..650579f9b 100644 --- a/internal/state/indexer/sink/psql/psql_test.go +++ b/internal/state/indexer/sink/psql/psql_test.go @@ -143,6 +143,9 @@ func TestType(t *testing.T) { } func TestIndexing(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + t.Run("IndexBlockEvents", func(t *testing.T) { indexer := &EventSink{store: testDB(), chainID: chainID} require.NoError(t, indexer.IndexBlockEvents(newTestBlockHeader())) @@ -154,7 +157,7 @@ func TestIndexing(t *testing.T) { verifyNotImplemented(t, "hasBlock", func() (bool, error) { return indexer.HasBlock(2) }) verifyNotImplemented(t, "block search", func() (bool, error) { - v, err := indexer.SearchBlockEvents(context.Background(), nil) + v, err := indexer.SearchBlockEvents(ctx, nil) return v != nil, err }) @@ -188,7 +191,7 @@ func TestIndexing(t *testing.T) { return txr != nil, err }) verifyNotImplemented(t, "tx search", func() (bool, error) { - txr, err := indexer.SearchTxEvents(context.Background(), nil) + txr, err := indexer.SearchTxEvents(ctx, nil) return txr != nil, err }) diff --git a/internal/state/state_test.go b/internal/state/state_test.go index fdf681294..3f989536a 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -310,6 +310,8 @@ func TestOneValidatorChangesSaveLoad(t *testing.T) { } func TestProposerFrequency(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // some explicit test cases testCases := []struct { @@ -370,7 +372,7 @@ func TestProposerFrequency(t *testing.T) { votePower := int64(mrand.Int()%maxPower) + 1 totalVotePower += votePower privVal := types.NewMockPV() - pubKey, err := privVal.GetPubKey(context.Background()) + pubKey, err := 
privVal.GetPubKey(ctx) require.NoError(t, err) val := types.NewValidator(pubKey, votePower) val.ProposerPriority = mrand.Int63() diff --git a/internal/state/validation_test.go b/internal/state/validation_test.go index eb0cebbb7..65c0648d4 100644 --- a/internal/state/validation_test.go +++ b/internal/state/validation_test.go @@ -28,9 +28,11 @@ import ( const validationTestsStopHeight int64 = 10 func TestValidateBlockHeader(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + proxyApp := newTestApp() - require.NoError(t, proxyApp.Start()) - defer proxyApp.Stop() //nolint:errcheck // ignore for tests + require.NoError(t, proxyApp.Start(ctx)) state, stateDB, privVals := makeState(3, 1) stateStore := sm.NewStore(stateDB) @@ -115,9 +117,11 @@ func TestValidateBlockHeader(t *testing.T) { } func TestValidateBlockCommit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + proxyApp := newTestApp() - require.NoError(t, proxyApp.Start()) - defer proxyApp.Stop() //nolint:errcheck // ignore for tests + require.NoError(t, proxyApp.Start(ctx)) state, stateDB, privVals := makeState(1, 1) stateStore := sm.NewStore(stateDB) @@ -207,7 +211,7 @@ func TestValidateBlockCommit(t *testing.T) { ) require.NoError(t, err, "height %d", height) - bpvPubKey, err := badPrivVal.GetPubKey(context.Background()) + bpvPubKey, err := badPrivVal.GetPubKey(ctx) require.NoError(t, err) badVote := &types.Vote{ @@ -223,9 +227,9 @@ func TestValidateBlockCommit(t *testing.T) { g := goodVote.ToProto() b := badVote.ToProto() - err = badPrivVal.SignVote(context.Background(), chainID, g) + err = badPrivVal.SignVote(ctx, chainID, g) require.NoError(t, err, "height %d", height) - err = badPrivVal.SignVote(context.Background(), chainID, b) + err = badPrivVal.SignVote(ctx, chainID, b) require.NoError(t, err, "height %d", height) goodVote.Signature, badVote.Signature = g.Signature, b.Signature @@ -236,9 +240,11 @@ func TestValidateBlockCommit(t *testing.T) { } func TestValidateBlockEvidence(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + proxyApp := newTestApp() - require.NoError(t, proxyApp.Start()) - defer proxyApp.Stop() //nolint:errcheck // ignore for tests + require.NoError(t, proxyApp.Start(ctx)) state, stateDB, privVals := makeState(4, 1) stateStore := sm.NewStore(stateDB) diff --git a/internal/statesync/dispatcher_test.go b/internal/statesync/dispatcher_test.go index e5a6a85cd..e717dad12 100644 --- a/internal/statesync/dispatcher_test.go +++ b/internal/statesync/dispatcher_test.go @@ -114,6 +114,10 @@ func TestDispatcherProviders(t *testing.T) { func TestPeerListBasic(t *testing.T) { t.Cleanup(leaktest.Check(t)) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + peerList := newPeerList() assert.Zero(t, peerList.Len()) numPeers := 10 @@ -199,6 +203,9 @@ func TestEmptyPeerListReturnsWhenContextCanceled(t *testing.T) { func TestPeerListConcurrent(t *testing.T) { t.Cleanup(leaktest.Check(t)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + peerList := newPeerList() numPeers := 10 @@ -229,7 +236,6 @@ func TestPeerListConcurrent(t *testing.T) { // we use a context with cancel and a separate go routine to wait for all // the other goroutines to close. 
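Editor's sketch: the concurrent peer-list and block-provider tests reuse the test-scoped context as their completion signal -- a helper goroutine cancels it once the WaitGroup drains, and the test races that against a deadline. A standalone sketch of the pattern (waitOrTimeout is an illustrative name):

package example

import (
	"context"
	"sync"
	"testing"
	"time"
)

// waitOrTimeout cancels the test-scoped context once every goroutine in the
// group has returned, then races that cancellation against a deadline.
func waitOrTimeout(ctx context.Context, cancel context.CancelFunc, t *testing.T, wg *sync.WaitGroup, d time.Duration) {
	t.Helper()
	go func() { wg.Wait(); cancel() }()
	select {
	case <-ctx.Done():
	case <-time.After(d):
		t.Fatal("timed out waiting for goroutines to finish")
	}
}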
- ctx, cancel := context.WithCancel(context.Background()) go func() { wg.Wait(); cancel() }() select { diff --git a/internal/statesync/reactor.go b/internal/statesync/reactor.go index 939fb409c..6566f823b 100644 --- a/internal/statesync/reactor.go +++ b/internal/statesync/reactor.go @@ -210,15 +210,11 @@ func NewReactor( // handle individual envelopes as to not have to deal with bounding workers or pools. // The caller must be sure to execute OnStop to ensure the outbound p2p Channels are // closed. No error is returned. -func (r *Reactor) OnStart() error { - go r.processSnapshotCh() - - go r.processChunkCh() - - go r.processBlockCh() - - go r.processParamsCh() - +func (r *Reactor) OnStart(ctx context.Context) error { + go r.processCh(r.snapshotCh, "snapshot") + go r.processCh(r.chunkCh, "chunk") + go r.processCh(r.blockCh, "light block") + go r.processCh(r.paramsCh, "consensus params") go r.processPeerUpdates() return nil @@ -607,7 +603,7 @@ func (r *Reactor) handleChunkMessage(envelope p2p.Envelope) error { "chunk", msg.Index, "peer", envelope.From, ) - resp, err := r.conn.LoadSnapshotChunkSync(context.Background(), abci.RequestLoadSnapshotChunk{ + resp, err := r.conn.LoadSnapshotChunkSync(context.TODO(), abci.RequestLoadSnapshotChunk{ Height: msg.Height, Format: msg.Format, Chunk: msg.Index, @@ -815,28 +811,6 @@ func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err return err } -// processSnapshotCh initiates a blocking process where we listen for and handle -// envelopes on the SnapshotChannel. -func (r *Reactor) processSnapshotCh() { - r.processCh(r.snapshotCh, "snapshot") -} - -// processChunkCh initiates a blocking process where we listen for and handle -// envelopes on the ChunkChannel. -func (r *Reactor) processChunkCh() { - r.processCh(r.chunkCh, "chunk") -} - -// processBlockCh initiates a blocking process where we listen for and handle -// envelopes on the LightBlockChannel. -func (r *Reactor) processBlockCh() { - r.processCh(r.blockCh, "light block") -} - -func (r *Reactor) processParamsCh() { - r.processCh(r.paramsCh, "consensus params") -} - // processCh routes state sync messages to their respective handlers. Any error // encountered during message execution will result in a PeerError being sent on // the respective channel. 
When the reactor is stopped, we will catch the signal @@ -848,8 +822,11 @@ func (r *Reactor) processCh(ch *p2p.Channel, chName string) { select { case envelope := <-ch.In: if err := r.handleMessage(ch.ID, envelope); err != nil { - r.Logger.Error(fmt.Sprintf("failed to process %s message", chName), - "ch_id", ch.ID, "envelope", envelope, "err", err) + r.Logger.Error("failed to process message", + "err", err, + "channel", chName, + "ch_id", ch.ID, + "envelope", envelope) ch.Error <- p2p.PeerError{ NodeID: envelope.From, Err: err, @@ -857,7 +834,7 @@ func (r *Reactor) processCh(ch *p2p.Channel, chName string) { } case <-r.closeCh: - r.Logger.Debug(fmt.Sprintf("stopped listening on %s channel; closing...", chName)) + r.Logger.Debug("channel closed", "channel", chName) return } } @@ -923,7 +900,7 @@ func (r *Reactor) processPeerUpdates() { // recentSnapshots fetches the n most recent snapshots from the app func (r *Reactor) recentSnapshots(n uint32) ([]*snapshot, error) { - resp, err := r.conn.ListSnapshotsSync(context.Background(), abci.RequestListSnapshots{}) + resp, err := r.conn.ListSnapshotsSync(context.TODO(), abci.RequestListSnapshots{}) if err != nil { return nil, err } diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index b90e5fd78..8dc2d6038 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -69,6 +69,7 @@ type reactorTestSuite struct { } func setup( + ctx context.Context, t *testing.T, conn *proxymocks.AppConnSnapshot, connQuery *proxymocks.AppConnQuery, @@ -176,11 +177,11 @@ func setup( rts.reactor.metrics, ) - require.NoError(t, rts.reactor.Start()) + require.NoError(t, rts.reactor.Start(ctx)) require.True(t, rts.reactor.IsRunning()) t.Cleanup(func() { - require.NoError(t, rts.reactor.Stop()) + rts.reactor.Wait() require.False(t, rts.reactor.IsRunning()) }) @@ -188,8 +189,11 @@ func setup( } func TestReactor_Sync(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + const snapshotHeight = 7 - rts := setup(t, nil, nil, nil, 2) + rts := setup(ctx, t, nil, nil, nil, 2) chain := buildLightBlockChain(t, 1, 10, time.Now()) // app accepts any snapshot rts.conn.On("OfferSnapshotSync", ctx, mock.AnythingOfType("types.RequestOfferSnapshot")). 
@@ -200,7 +204,7 @@ func TestReactor_Sync(t *testing.T) { Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil) // app query returns valid state app hash - rts.connQuery.On("InfoSync", ctx, proxy.RequestInfo).Return(&abci.ResponseInfo{ + rts.connQuery.On("InfoSync", mock.Anything, proxy.RequestInfo).Return(&abci.ResponseInfo{ AppVersion: 9, LastBlockHeight: snapshotHeight, LastBlockAppHash: chain[snapshotHeight+1].AppHash, @@ -213,7 +217,7 @@ func TestReactor_Sync(t *testing.T) { closeCh := make(chan struct{}) defer close(closeCh) - go handleLightBlockRequests(t, chain, rts.blockOutCh, + go handleLightBlockRequests(ctx, t, chain, rts.blockOutCh, rts.blockInCh, closeCh, 0) go graduallyAddPeers(rts.peerUpdateCh, closeCh, 1*time.Second) go handleSnapshotRequests(t, rts.snapshotOutCh, rts.snapshotInCh, closeCh, []snapshot{ @@ -226,7 +230,7 @@ func TestReactor_Sync(t *testing.T) { go handleChunkRequests(t, rts.chunkOutCh, rts.chunkInCh, closeCh, []byte("abc")) - go handleConsensusParamsRequest(t, rts.paramsOutCh, rts.paramsInCh, closeCh) + go handleConsensusParamsRequest(ctx, t, rts.paramsOutCh, rts.paramsInCh, closeCh) // update the config to use the p2p provider rts.reactor.cfg.UseP2P = true @@ -235,12 +239,15 @@ func TestReactor_Sync(t *testing.T) { rts.reactor.cfg.DiscoveryTime = 1 * time.Second // Run state sync - _, err := rts.reactor.Sync(context.Background()) + _, err := rts.reactor.Sync(ctx) require.NoError(t, err) } func TestReactor_ChunkRequest_InvalidRequest(t *testing.T) { - rts := setup(t, nil, nil, nil, 2) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, nil, nil, nil, 2) rts.chunkInCh <- p2p.Envelope{ From: types.NodeID("aa"), @@ -282,19 +289,23 @@ func TestReactor_ChunkRequest(t *testing.T) { }, } - for name, tc := range testcases { - tc := tc + bctx, bcancel := context.WithCancel(context.Background()) + defer bcancel() + for name, tc := range testcases { t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(bctx) + defer cancel() + // mock ABCI connection to return local snapshots conn := &proxymocks.AppConnSnapshot{} - conn.On("LoadSnapshotChunkSync", context.Background(), abci.RequestLoadSnapshotChunk{ + conn.On("LoadSnapshotChunkSync", mock.Anything, abci.RequestLoadSnapshotChunk{ Height: tc.request.Height, Format: tc.request.Format, Chunk: tc.request.Index, }).Return(&abci.ResponseLoadSnapshotChunk{Chunk: tc.chunk}, nil) - rts := setup(t, conn, nil, nil, 2) + rts := setup(ctx, t, conn, nil, nil, 2) rts.chunkInCh <- p2p.Envelope{ From: types.NodeID("aa"), @@ -311,7 +322,10 @@ func TestReactor_ChunkRequest(t *testing.T) { } func TestReactor_SnapshotsRequest_InvalidRequest(t *testing.T) { - rts := setup(t, nil, nil, nil, 2) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, nil, nil, nil, 2) rts.snapshotInCh <- p2p.Envelope{ From: types.NodeID("aa"), @@ -360,18 +374,23 @@ func TestReactor_SnapshotsRequest(t *testing.T) { }, }, } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() for name, tc := range testcases { tc := tc t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + // mock ABCI connection to return local snapshots conn := &proxymocks.AppConnSnapshot{} - conn.On("ListSnapshotsSync", context.Background(), abci.RequestListSnapshots{}).Return(&abci.ResponseListSnapshots{ + conn.On("ListSnapshotsSync", mock.Anything, 
abci.RequestListSnapshots{}).Return(&abci.ResponseListSnapshots{ Snapshots: tc.snapshots, }, nil) - rts := setup(t, conn, nil, nil, 100) + rts := setup(ctx, t, conn, nil, nil, 100) rts.snapshotInCh <- p2p.Envelope{ From: types.NodeID("aa"), @@ -379,7 +398,7 @@ func TestReactor_SnapshotsRequest(t *testing.T) { } if len(tc.expectResponses) > 0 { - retryUntil(t, func() bool { return len(rts.snapshotOutCh) == len(tc.expectResponses) }, time.Second) + retryUntil(ctx, t, func() bool { return len(rts.snapshotOutCh) == len(tc.expectResponses) }, time.Second) } responses := make([]*ssproto.SnapshotsResponse, len(tc.expectResponses)) @@ -395,7 +414,10 @@ func TestReactor_SnapshotsRequest(t *testing.T) { } func TestReactor_LightBlockResponse(t *testing.T) { - rts := setup(t, nil, nil, nil, 2) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, nil, nil, nil, 2) var height int64 = 10 h := factory.MakeRandomHeader() @@ -448,7 +470,10 @@ func TestReactor_LightBlockResponse(t *testing.T) { } func TestReactor_BlockProviders(t *testing.T) { - rts := setup(t, nil, nil, nil, 2) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, nil, nil, nil, 2) rts.peerUpdateCh <- p2p.PeerUpdate{ NodeID: types.NodeID("aa"), Status: p2p.PeerStatusUp, @@ -462,7 +487,7 @@ func TestReactor_BlockProviders(t *testing.T) { defer close(closeCh) chain := buildLightBlockChain(t, 1, 10, time.Now()) - go handleLightBlockRequests(t, chain, rts.blockOutCh, rts.blockInCh, closeCh, 0) + go handleLightBlockRequests(ctx, t, chain, rts.blockOutCh, rts.blockInCh, closeCh, 0) peers := rts.reactor.peers.All() require.Len(t, peers, 2) @@ -479,7 +504,7 @@ func TestReactor_BlockProviders(t *testing.T) { go func(t *testing.T, p provider.Provider) { defer wg.Done() for height := 2; height < 10; height++ { - lb, err := p.LightBlock(context.Background(), int64(height)) + lb, err := p.LightBlock(ctx, int64(height)) require.NoError(t, err) require.NotNil(t, lb) require.Equal(t, height, int(lb.Height)) @@ -487,7 +512,6 @@ func TestReactor_BlockProviders(t *testing.T) { }(t, p) } - ctx, cancel := context.WithCancel(context.Background()) go func() { wg.Wait(); cancel() }() select { @@ -501,7 +525,10 @@ func TestReactor_BlockProviders(t *testing.T) { } func TestReactor_StateProviderP2P(t *testing.T) { - rts := setup(t, nil, nil, nil, 2) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, nil, nil, nil, 2) // make syncer non nil else test won't think we are state syncing rts.reactor.syncer = rts.syncer peerA := types.NodeID(strings.Repeat("a", 2*types.NodeIDByteLength)) @@ -519,8 +546,8 @@ func TestReactor_StateProviderP2P(t *testing.T) { defer close(closeCh) chain := buildLightBlockChain(t, 1, 10, time.Now()) - go handleLightBlockRequests(t, chain, rts.blockOutCh, rts.blockInCh, closeCh, 0) - go handleConsensusParamsRequest(t, rts.paramsOutCh, rts.paramsInCh, closeCh) + go handleLightBlockRequests(ctx, t, chain, rts.blockOutCh, rts.blockInCh, closeCh, 0) + go handleConsensusParamsRequest(ctx, t, rts.paramsOutCh, rts.paramsInCh, closeCh) rts.reactor.cfg.UseP2P = true rts.reactor.cfg.TrustHeight = 1 @@ -533,10 +560,7 @@ func TestReactor_StateProviderP2P(t *testing.T) { } require.True(t, rts.reactor.peers.Len() >= 2, "peer network not configured") - bctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ictx, cancel := context.WithTimeout(bctx, time.Second) + ictx, cancel := 
context.WithTimeout(ctx, time.Second) defer cancel() rts.reactor.mtx.Lock() @@ -545,7 +569,7 @@ func TestReactor_StateProviderP2P(t *testing.T) { require.NoError(t, err) rts.reactor.syncer.stateProvider = rts.reactor.stateProvider - actx, cancel := context.WithTimeout(bctx, 10*time.Second) + actx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() appHash, err := rts.reactor.stateProvider.AppHash(actx, 5) @@ -569,13 +593,19 @@ func TestReactor_StateProviderP2P(t *testing.T) { } func TestReactor_Backfill(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // test backfill algorithm with varying failure rates [0, 10] failureRates := []int{0, 2, 9} for _, failureRate := range failureRates { failureRate := failureRate t.Run(fmt.Sprintf("failure rate: %d", failureRate), func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + t.Cleanup(leaktest.CheckTimeout(t, 1*time.Minute)) - rts := setup(t, nil, nil, nil, 21) + rts := setup(ctx, t, nil, nil, nil, 21) var ( startHeight int64 = 20 @@ -605,11 +635,11 @@ func TestReactor_Backfill(t *testing.T) { closeCh := make(chan struct{}) defer close(closeCh) - go handleLightBlockRequests(t, chain, rts.blockOutCh, + go handleLightBlockRequests(ctx, t, chain, rts.blockOutCh, rts.blockInCh, closeCh, failureRate) err := rts.reactor.backfill( - context.Background(), + ctx, factory.DefaultTestChainID, startHeight, stopHeight, @@ -644,8 +674,8 @@ func TestReactor_Backfill(t *testing.T) { // retryUntil will continue to evaluate fn and will return successfully when true // or fail when the timeout is reached. -func retryUntil(t *testing.T, fn func() bool, timeout time.Duration) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) +func retryUntil(ctx context.Context, t *testing.T, fn func() bool, timeout time.Duration) { + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() for { @@ -656,7 +686,9 @@ func retryUntil(t *testing.T, fn func() bool, timeout time.Duration) { } } -func handleLightBlockRequests(t *testing.T, +func handleLightBlockRequests( + ctx context.Context, + t *testing.T, chain map[int64]*types.LightBlock, receiving chan p2p.Envelope, sending chan p2p.Envelope, @@ -666,6 +698,8 @@ func handleLightBlockRequests(t *testing.T, errorCount := 0 for { select { + case <-ctx.Done(): + return case envelope := <-receiving: if msg, ok := envelope.Message.(*ssproto.LightBlockRequest); ok { if requests%10 >= failureRate { @@ -709,13 +743,24 @@ func handleLightBlockRequests(t *testing.T, } } -func handleConsensusParamsRequest(t *testing.T, receiving, sending chan p2p.Envelope, closeCh chan struct{}) { +func handleConsensusParamsRequest( + ctx context.Context, + t *testing.T, + receiving, sending chan p2p.Envelope, + closeCh chan struct{}, +) { t.Helper() params := types.DefaultConsensusParams() paramsProto := params.ToProto() for { select { + case <-ctx.Done(): + return case envelope := <-receiving: + if ctx.Err() != nil { + return + } + t.Log("received consensus params request") msg, ok := envelope.Message.(*ssproto.ParamsRequest) require.True(t, ok) diff --git a/internal/statesync/syncer.go b/internal/statesync/syncer.go index b4212961a..f266017dd 100644 --- a/internal/statesync/syncer.go +++ b/internal/statesync/syncer.go @@ -565,7 +565,7 @@ func (s *syncer) requestChunk(snapshot *snapshot, chunk uint32) { // verifyApp verifies the sync, checking the app hash and last block height. 
It returns the // app version, which should be returned as part of the initial state. func (s *syncer) verifyApp(snapshot *snapshot) (uint64, error) { - resp, err := s.connQuery.InfoSync(context.Background(), proxy.RequestInfo) + resp, err := s.connQuery.InfoSync(context.TODO(), proxy.RequestInfo) if err != nil { return 0, fmt.Errorf("failed to query ABCI app for appHash: %w", err) } diff --git a/internal/statesync/syncer_test.go b/internal/statesync/syncer_test.go index ad902a54c..4c240830f 100644 --- a/internal/statesync/syncer_test.go +++ b/internal/statesync/syncer_test.go @@ -22,9 +22,10 @@ import ( "github.com/tendermint/tendermint/version" ) -var ctx = context.Background() - func TestSyncer_SyncAny(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + state := sm.State{ ChainID: "chain", Version: sm.Version{ @@ -68,7 +69,7 @@ func TestSyncer_SyncAny(t *testing.T) { peerAID := types.NodeID("aa") peerBID := types.NodeID("bb") peerCID := types.NodeID("cc") - rts := setup(t, connSnapshot, connQuery, stateProvider, 3) + rts := setup(ctx, t, connSnapshot, connQuery, stateProvider, 3) rts.reactor.syncer = rts.syncer @@ -110,7 +111,7 @@ func TestSyncer_SyncAny(t *testing.T) { // We start a sync, with peers sending back chunks when requested. We first reject the snapshot // with height 2 format 2, and accept the snapshot at height 1. - connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + connSnapshot.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: &abci.Snapshot{ Height: 2, Format: 2, @@ -119,7 +120,7 @@ func TestSyncer_SyncAny(t *testing.T) { }, AppHash: []byte("app_hash_2"), }).Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT_FORMAT}, nil) - connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + connSnapshot.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: &abci.Snapshot{ Height: s.Height, Format: s.Format, @@ -160,7 +161,7 @@ func TestSyncer_SyncAny(t *testing.T) { // The first time we're applying chunk 2 we tell it to retry the snapshot and discard chunk 1, // which should cause it to keep the existing chunk 0 and 2, and restart restoration from // beginning. We also wait for a little while, to exercise the retry logic in fetchChunks(). 
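Editor's sketch: as the comment above notes, the hunk that follows attaches a deliberate delay to the first ApplySnapshotChunkSync expectation via testify's Run hook so the fetchChunks retry path is actually exercised. A standalone sketch of that idiom with a made-up mock:

package example

import (
	"testing"
	"time"

	"github.com/stretchr/testify/mock"
)

// chunkMock is invented for illustration; the point is chaining Run before
// Return so the first matching call sleeps and the caller has to retry.
type chunkMock struct{ mock.Mock }

func (m *chunkMock) Apply(index int) error {
	return m.Called(index).Error(0)
}

func TestSlowFirstApply(t *testing.T) {
	m := &chunkMock{}
	m.On("Apply", 2).
		Once().
		Run(func(mock.Arguments) { time.Sleep(50 * time.Millisecond) }).
		Return(nil) // first call is deliberately slow
	m.On("Apply", 2).Once().Return(nil) // the retry succeeds promptly

	_ = m.Apply(2)
	_ = m.Apply(2)
	m.AssertExpectations(t)
}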
- connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + connSnapshot.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 2, Chunk: []byte{1, 1, 2}, }).Once().Run(func(args mock.Arguments) { time.Sleep(2 * time.Second) }).Return( &abci.ResponseApplySnapshotChunk{ @@ -168,16 +169,16 @@ func TestSyncer_SyncAny(t *testing.T) { RefetchChunks: []uint32{1}, }, nil) - connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + connSnapshot.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 0, Chunk: []byte{1, 1, 0}, }).Times(2).Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil) - connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + connSnapshot.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 1, Chunk: []byte{1, 1, 1}, }).Times(2).Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil) - connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + connSnapshot.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 2, Chunk: []byte{1, 1, 2}, }).Once().Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil) - connQuery.On("InfoSync", ctx, proxy.RequestInfo).Return(&abci.ResponseInfo{ + connQuery.On("InfoSync", mock.Anything, proxy.RequestInfo).Return(&abci.ResponseInfo{ AppVersion: 9, LastBlockHeight: 1, LastBlockAppHash: []byte("app_hash"), @@ -217,7 +218,10 @@ func TestSyncer_SyncAny_noSnapshots(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - rts := setup(t, nil, nil, stateProvider, 2) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, nil, nil, stateProvider, 2) _, _, err := rts.syncer.SyncAny(ctx, 0, func() {}) require.Equal(t, errNoSnapshots, err) @@ -227,7 +231,10 @@ func TestSyncer_SyncAny_abort(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - rts := setup(t, nil, nil, stateProvider, 2) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, nil, nil, stateProvider, 2) s := &snapshot{Height: 1, Format: 1, Chunks: 3, Hash: []byte{1, 2, 3}} peerID := types.NodeID("aa") @@ -235,7 +242,7 @@ func TestSyncer_SyncAny_abort(t *testing.T) { _, err := rts.syncer.AddSnapshot(peerID, s) require.NoError(t, err) - rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: toABCI(s), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_ABORT}, nil) @@ -248,7 +255,10 @@ func TestSyncer_SyncAny_reject(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - rts := setup(t, nil, nil, stateProvider, 2) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, nil, nil, stateProvider, 2) // s22 is tried first, then s12, then s11, then errNoSnapshots s22 := &snapshot{Height: 2, Format: 2, Chunks: 3, Hash: []byte{1, 2, 3}} @@ -266,15 +276,15 @@ func TestSyncer_SyncAny_reject(t *testing.T) { _, err = 
rts.syncer.AddSnapshot(peerID, s11) require.NoError(t, err) - rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: toABCI(s22), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil) - rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: toABCI(s12), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil) - rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: toABCI(s11), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil) @@ -287,7 +297,10 @@ func TestSyncer_SyncAny_reject_format(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - rts := setup(t, nil, nil, stateProvider, 2) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, nil, nil, stateProvider, 2) // s22 is tried first, which reject s22 and s12, then s11 will abort. s22 := &snapshot{Height: 2, Format: 2, Chunks: 3, Hash: []byte{1, 2, 3}} @@ -305,11 +318,11 @@ func TestSyncer_SyncAny_reject_format(t *testing.T) { _, err = rts.syncer.AddSnapshot(peerID, s11) require.NoError(t, err) - rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: toABCI(s22), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT_FORMAT}, nil) - rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: toABCI(s11), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_ABORT}, nil) @@ -322,7 +335,10 @@ func TestSyncer_SyncAny_reject_sender(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - rts := setup(t, nil, nil, stateProvider, 2) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, nil, nil, stateProvider, 2) peerAID := types.NodeID("aa") peerBID := types.NodeID("bb") @@ -351,11 +367,11 @@ func TestSyncer_SyncAny_reject_sender(t *testing.T) { _, err = rts.syncer.AddSnapshot(peerCID, sbc) require.NoError(t, err) - rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: toABCI(sbc), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT_SENDER}, nil) - rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: toABCI(sa), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil) @@ -368,7 +384,10 @@ func TestSyncer_SyncAny_abciError(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) 
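
For reference on the mock changes in these statesync tests: the shared package-level ctx these expectations used to match against is gone, and every test now derives its own cancellable context, so the testify expectations match the context argument with mock.Anything rather than a concrete value. A minimal, self-contained sketch of that expectation style (the Doer type and its method are illustrative only, not from this repository):

package example

import (
	"context"
	"testing"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// Doer is a hypothetical mocked dependency used only for illustration.
type Doer struct{ mock.Mock }

func (d *Doer) Do(ctx context.Context, in string) (string, error) {
	args := d.Called(ctx, in)
	return args.String(0), args.Error(1)
}

func TestDoer(t *testing.T) {
	// Each test owns its context, so an expectation cannot pin a specific
	// context value; mock.Anything matches whatever context is passed.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	d := new(Doer)
	d.On("Do", mock.Anything, "ping").Return("pong", nil)

	out, err := d.Do(ctx, "ping")
	require.NoError(t, err)
	require.Equal(t, "pong", out)
	d.AssertExpectations(t)
}
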
- rts := setup(t, nil, nil, stateProvider, 2) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rts := setup(ctx, t, nil, nil, stateProvider, 2) errBoom := errors.New("boom") s := &snapshot{Height: 1, Format: 1, Chunks: 3, Hash: []byte{1, 2, 3}} @@ -378,7 +397,7 @@ func TestSyncer_SyncAny_abciError(t *testing.T) { _, err := rts.syncer.AddSnapshot(peerID, s) require.NoError(t, err) - rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: toABCI(s), AppHash: []byte("app_hash"), }).Once().Return(nil, errBoom) @@ -405,16 +424,23 @@ func TestSyncer_offerSnapshot(t *testing.T) { "error": {0, boom, boom}, "unknown non-zero": {9, nil, unknownErr}, } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for name, tc := range testcases { tc := tc t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - rts := setup(t, nil, nil, stateProvider, 2) + rts := setup(ctx, t, nil, nil, stateProvider, 2) s := &snapshot{Height: 1, Format: 1, Chunks: 3, Hash: []byte{1, 2, 3}, trustedAppHash: []byte("app_hash")} - rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", mock.Anything, abci.RequestOfferSnapshot{ Snapshot: toABCI(s), AppHash: []byte("app_hash"), }).Return(&abci.ResponseOfferSnapshot{Result: tc.result}, tc.err) @@ -451,13 +477,20 @@ func TestSyncer_applyChunks_Results(t *testing.T) { "error": {0, boom, boom}, "unknown non-zero": {9, nil, unknownErr}, } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for name, tc := range testcases { tc := tc t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - rts := setup(t, nil, nil, stateProvider, 2) + rts := setup(ctx, t, nil, nil, stateProvider, 2) body := []byte{1, 2, 3} chunks, err := newChunkQueue(&snapshot{Height: 1, Format: 1, Chunks: 1}, "") @@ -468,11 +501,11 @@ func TestSyncer_applyChunks_Results(t *testing.T) { _, err = chunks.Add(&chunk{Height: 1, Format: 1, Index: 0, Chunk: body}) require.NoError(t, err) - rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + rts.conn.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 0, Chunk: body, }).Once().Return(&abci.ResponseApplySnapshotChunk{Result: tc.result}, tc.err) if tc.result == abci.ResponseApplySnapshotChunk_RETRY { - rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + rts.conn.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 0, Chunk: body, }).Once().Return(&abci.ResponseApplySnapshotChunk{ Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil) @@ -505,13 +538,19 @@ func TestSyncer_applyChunks_RefetchChunks(t *testing.T) { "retry_snapshot": {abci.ResponseApplySnapshotChunk_RETRY_SNAPSHOT}, "reject_snapshot": {abci.ResponseApplySnapshotChunk_REJECT_SNAPSHOT}, } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for name, tc := range testcases { tc := tc t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + stateProvider := &mocks.StateProvider{} 
stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - rts := setup(t, nil, nil, stateProvider, 2) + rts := setup(ctx, t, nil, nil, stateProvider, 2) chunks, err := newChunkQueue(&snapshot{Height: 1, Format: 1, Chunks: 3}, "") require.NoError(t, err) @@ -529,13 +568,13 @@ func TestSyncer_applyChunks_RefetchChunks(t *testing.T) { require.NoError(t, err) // The first two chunks are accepted, before the last one asks for 1 to be refetched - rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + rts.conn.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 0, Chunk: []byte{0}, }).Once().Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil) - rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + rts.conn.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 1, Chunk: []byte{1}, }).Once().Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil) - rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + rts.conn.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 2, Chunk: []byte{2}, }).Once().Return(&abci.ResponseApplySnapshotChunk{ Result: tc.result, @@ -570,13 +609,19 @@ func TestSyncer_applyChunks_RejectSenders(t *testing.T) { "retry_snapshot": {abci.ResponseApplySnapshotChunk_RETRY_SNAPSHOT}, "reject_snapshot": {abci.ResponseApplySnapshotChunk_REJECT_SNAPSHOT}, } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for name, tc := range testcases { tc := tc t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - rts := setup(t, nil, nil, stateProvider, 2) + rts := setup(ctx, t, nil, nil, stateProvider, 2) // Set up three peers across two snapshots, and ask for one of them to be banned. // It should be banned from all snapshots. @@ -623,13 +668,13 @@ func TestSyncer_applyChunks_RejectSenders(t *testing.T) { require.NoError(t, err) // The first two chunks are accepted, before the last one asks for b sender to be rejected - rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + rts.conn.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 0, Chunk: []byte{0}, Sender: "aa", }).Once().Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil) - rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + rts.conn.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 1, Chunk: []byte{1}, Sender: "bb", }).Once().Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil) - rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + rts.conn.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 2, Chunk: []byte{2}, Sender: "cc", }).Once().Return(&abci.ResponseApplySnapshotChunk{ Result: tc.result, @@ -638,7 +683,7 @@ func TestSyncer_applyChunks_RejectSenders(t *testing.T) { // On retry, the last chunk will be tried again, so we just accept it then. 
if tc.result == abci.ResponseApplySnapshotChunk_RETRY { - rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{ + rts.conn.On("ApplySnapshotChunkSync", mock.Anything, abci.RequestApplySnapshotChunk{ Index: 2, Chunk: []byte{2}, Sender: "cc", }).Once().Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil) } @@ -693,12 +738,18 @@ func TestSyncer_verifyApp(t *testing.T) { }, nil, errVerifyFailed}, "error": {nil, boom, boom}, } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for name, tc := range testcases { tc := tc t.Run(name, func(t *testing.T) { - rts := setup(t, nil, nil, nil, 2) + ctx, cancel := context.WithCancel(ctx) + defer cancel() - rts.connQuery.On("InfoSync", ctx, proxy.RequestInfo).Return(tc.response, tc.err) + rts := setup(ctx, t, nil, nil, nil, 2) + + rts.connQuery.On("InfoSync", mock.Anything, proxy.RequestInfo).Return(tc.response, tc.err) version, err := rts.syncer.verifyApp(s) unwrapped := errors.Unwrap(err) if unwrapped != nil { diff --git a/libs/events/event_cache_test.go b/libs/events/event_cache_test.go index d6199bc80..a5bb975c9 100644 --- a/libs/events/event_cache_test.go +++ b/libs/events/event_cache_test.go @@ -1,6 +1,7 @@ package events import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -8,8 +9,11 @@ import ( ) func TestEventCache_Flush(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + evsw := NewEventSwitch() - err := evsw.Start() + err := evsw.Start(ctx) require.NoError(t, err) err = evsw.AddListenerForEvent("nothingness", "", func(data EventData) { diff --git a/libs/events/events.go b/libs/events/events.go index 146a9cfa7..f6151e734 100644 --- a/libs/events/events.go +++ b/libs/events/events.go @@ -2,6 +2,7 @@ package events import ( + "context" "fmt" tmsync "github.com/tendermint/tendermint/internal/libs/sync" @@ -45,6 +46,7 @@ type Fireable interface { type EventSwitch interface { service.Service Fireable + Stop() error AddListenerForEvent(listenerID, eventValue string, cb EventCallback) error RemoveListenerForEvent(event string, listenerID string) @@ -68,7 +70,7 @@ func NewEventSwitch() EventSwitch { return evsw } -func (evsw *eventSwitch) OnStart() error { +func (evsw *eventSwitch) OnStart(ctx context.Context) error { return nil } diff --git a/libs/events/events_test.go b/libs/events/events_test.go index 9e21e0235..0e8667908 100644 --- a/libs/events/events_test.go +++ b/libs/events/events_test.go @@ -1,6 +1,7 @@ package events import ( + "context" "fmt" "testing" "time" @@ -14,23 +15,20 @@ import ( // TestAddListenerForEventFireOnce sets up an EventSwitch, subscribes a single // listener to an event, and sends a string "data". 
func TestAddListenerForEventFireOnce(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + evsw := NewEventSwitch() - err := evsw.Start() - require.NoError(t, err) - t.Cleanup(func() { - if err := evsw.Stop(); err != nil { - t.Error(err) - } - }) + require.NoError(t, evsw.Start(ctx)) + t.Cleanup(evsw.Wait) messages := make(chan EventData) - err = evsw.AddListenerForEvent("listener", "event", + require.NoError(t, evsw.AddListenerForEvent("listener", "event", func(data EventData) { // test there's no deadlock if we remove the listener inside a callback evsw.RemoveListener("listener") messages <- data - }) - require.NoError(t, err) + })) go evsw.FireEvent("event", "data") received := <-messages if received != "data" { @@ -41,24 +39,21 @@ func TestAddListenerForEventFireOnce(t *testing.T) { // TestAddListenerForEventFireMany sets up an EventSwitch, subscribes a single // listener to an event, and sends a thousand integers. func TestAddListenerForEventFireMany(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + evsw := NewEventSwitch() - err := evsw.Start() - require.NoError(t, err) - t.Cleanup(func() { - if err := evsw.Stop(); err != nil { - t.Error(err) - } - }) + require.NoError(t, evsw.Start(ctx)) + t.Cleanup(evsw.Wait) doneSum := make(chan uint64) doneSending := make(chan uint64) numbers := make(chan uint64, 4) // subscribe one listener for one event - err = evsw.AddListenerForEvent("listener", "event", + require.NoError(t, evsw.AddListenerForEvent("listener", "event", func(data EventData) { numbers <- data.(uint64) - }) - require.NoError(t, err) + })) // collect received events go sumReceivedNumbers(numbers, doneSum) // go fire events @@ -75,14 +70,12 @@ func TestAddListenerForEventFireMany(t *testing.T) { // listener to three different events and sends a thousand integers for each // of the three events. func TestAddListenerForDifferentEvents(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + evsw := NewEventSwitch() - err := evsw.Start() - require.NoError(t, err) - t.Cleanup(func() { - if err := evsw.Stop(); err != nil { - t.Error(err) - } - }) + require.NoError(t, evsw.Start(ctx)) + t.Cleanup(evsw.Wait) doneSum := make(chan uint64) doneSending1 := make(chan uint64) @@ -90,21 +83,18 @@ func TestAddListenerForDifferentEvents(t *testing.T) { doneSending3 := make(chan uint64) numbers := make(chan uint64, 4) // subscribe one listener to three events - err = evsw.AddListenerForEvent("listener", "event1", + require.NoError(t, evsw.AddListenerForEvent("listener", "event1", func(data EventData) { numbers <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener", "event2", + })) + require.NoError(t, evsw.AddListenerForEvent("listener", "event2", func(data EventData) { numbers <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener", "event3", + })) + require.NoError(t, evsw.AddListenerForEvent("listener", "event3", func(data EventData) { numbers <- data.(uint64) - }) - require.NoError(t, err) + })) // collect received events go sumReceivedNumbers(numbers, doneSum) // go fire events @@ -127,15 +117,13 @@ func TestAddListenerForDifferentEvents(t *testing.T) { // listener to two of those three events, and then sends a thousand integers // for each of the three events. 
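
The remaining libs/events tests below follow the same conversion as the ones above: the switch is started with a per-test context, the error-checked Stop cleanup becomes cancel-plus-Wait, and the AddListenerForEvent error handling folds into require.NoError. Roughly, each converted test body now opens with a shape like this sketch (written as if inside package events, with the usual testify require import; the listener and event names are placeholders):

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

evsw := NewEventSwitch()
require.NoError(t, evsw.Start(ctx))
t.Cleanup(evsw.Wait) // shutdown is driven by cancelling ctx, not by calling Stop()

messages := make(chan EventData)
require.NoError(t, evsw.AddListenerForEvent("listener", "event",
	func(data EventData) { messages <- data }))

go evsw.FireEvent("event", "data")
require.Equal(t, "data", <-messages)
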
func TestAddDifferentListenerForDifferentEvents(t *testing.T) { - evsw := NewEventSwitch() - err := evsw.Start() - require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - t.Cleanup(func() { - if err := evsw.Stop(); err != nil { - t.Error(err) - } - }) + evsw := NewEventSwitch() + require.NoError(t, evsw.Start(ctx)) + + t.Cleanup(evsw.Wait) doneSum1 := make(chan uint64) doneSum2 := make(chan uint64) @@ -145,31 +133,26 @@ func TestAddDifferentListenerForDifferentEvents(t *testing.T) { numbers1 := make(chan uint64, 4) numbers2 := make(chan uint64, 4) // subscribe two listener to three events - err = evsw.AddListenerForEvent("listener1", "event1", + require.NoError(t, evsw.AddListenerForEvent("listener1", "event1", func(data EventData) { numbers1 <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener1", "event2", + })) + require.NoError(t, evsw.AddListenerForEvent("listener1", "event2", func(data EventData) { numbers1 <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener1", "event3", + })) + require.NoError(t, evsw.AddListenerForEvent("listener1", "event3", func(data EventData) { numbers1 <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener2", "event2", + })) + require.NoError(t, evsw.AddListenerForEvent("listener2", "event2", func(data EventData) { numbers2 <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener2", "event3", + })) + require.NoError(t, evsw.AddListenerForEvent("listener2", "event3", func(data EventData) { numbers2 <- data.(uint64) - }) - require.NoError(t, err) + })) // collect received events for listener1 go sumReceivedNumbers(numbers1, doneSum1) // collect received events for listener2 @@ -199,14 +182,12 @@ func TestAddAndRemoveListenerConcurrency(t *testing.T) { roundCount = 2000 ) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + evsw := NewEventSwitch() - err := evsw.Start() - require.NoError(t, err) - t.Cleanup(func() { - if err := evsw.Stop(); err != nil { - t.Error(err) - } - }) + require.NoError(t, evsw.Start(ctx)) + t.Cleanup(evsw.Wait) done1 := make(chan struct{}) done2 := make(chan struct{}) @@ -249,14 +230,12 @@ func TestAddAndRemoveListenerConcurrency(t *testing.T) { // two events, fires a thousand integers for the first event, then unsubscribes // the listener and fires a thousand integers for the second event. 
func TestAddAndRemoveListener(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + evsw := NewEventSwitch() - err := evsw.Start() - require.NoError(t, err) - t.Cleanup(func() { - if err := evsw.Stop(); err != nil { - t.Error(err) - } - }) + require.NoError(t, evsw.Start(ctx)) + t.Cleanup(evsw.Wait) doneSum1 := make(chan uint64) doneSum2 := make(chan uint64) @@ -265,16 +244,14 @@ func TestAddAndRemoveListener(t *testing.T) { numbers1 := make(chan uint64, 4) numbers2 := make(chan uint64, 4) // subscribe two listener to three events - err = evsw.AddListenerForEvent("listener", "event1", + require.NoError(t, evsw.AddListenerForEvent("listener", "event1", func(data EventData) { numbers1 <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener", "event2", + })) + require.NoError(t, evsw.AddListenerForEvent("listener", "event2", func(data EventData) { numbers2 <- data.(uint64) - }) - require.NoError(t, err) + })) // collect received events for event1 go sumReceivedNumbers(numbers1, doneSum1) // collect received events for event2 @@ -300,29 +277,23 @@ func TestAddAndRemoveListener(t *testing.T) { // TestRemoveListener does basic tests on adding and removing func TestRemoveListener(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() evsw := NewEventSwitch() - err := evsw.Start() - require.NoError(t, err) - t.Cleanup(func() { - if err := evsw.Stop(); err != nil { - t.Error(err) - } - }) + require.NoError(t, evsw.Start(ctx)) + t.Cleanup(evsw.Wait) count := 10 sum1, sum2 := 0, 0 // add some listeners and make sure they work - err = evsw.AddListenerForEvent("listener", "event1", + require.NoError(t, evsw.AddListenerForEvent("listener", "event1", func(data EventData) { sum1++ - }) - require.NoError(t, err) - - err = evsw.AddListenerForEvent("listener", "event2", + })) + require.NoError(t, evsw.AddListenerForEvent("listener", "event2", func(data EventData) { sum2++ - }) - require.NoError(t, err) + })) for i := 0; i < count; i++ { evsw.FireEvent("event1", true) @@ -361,14 +332,11 @@ func TestRemoveListener(t *testing.T) { // NOTE: it is important to run this test with race conditions tracking on, // `go test -race`, to examine for possible race conditions. 
func TestRemoveListenersAsync(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() evsw := NewEventSwitch() - err := evsw.Start() - require.NoError(t, err) - t.Cleanup(func() { - if err := evsw.Stop(); err != nil { - t.Error(err) - } - }) + require.NoError(t, evsw.Start(ctx)) + t.Cleanup(evsw.Wait) doneSum1 := make(chan uint64) doneSum2 := make(chan uint64) @@ -378,36 +346,30 @@ func TestRemoveListenersAsync(t *testing.T) { numbers1 := make(chan uint64, 4) numbers2 := make(chan uint64, 4) // subscribe two listener to three events - err = evsw.AddListenerForEvent("listener1", "event1", + require.NoError(t, evsw.AddListenerForEvent("listener1", "event1", func(data EventData) { numbers1 <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener1", "event2", + })) + require.NoError(t, evsw.AddListenerForEvent("listener1", "event2", func(data EventData) { numbers1 <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener1", "event3", + })) + require.NoError(t, evsw.AddListenerForEvent("listener1", "event3", func(data EventData) { numbers1 <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener2", "event1", + })) + require.NoError(t, evsw.AddListenerForEvent("listener2", "event1", func(data EventData) { numbers2 <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener2", "event2", + })) + require.NoError(t, evsw.AddListenerForEvent("listener2", "event2", func(data EventData) { numbers2 <- data.(uint64) - }) - require.NoError(t, err) - err = evsw.AddListenerForEvent("listener2", "event3", + })) + require.NoError(t, evsw.AddListenerForEvent("listener2", "event3", func(data EventData) { numbers2 <- data.(uint64) - }) - require.NoError(t, err) + })) // collect received events for event1 go sumReceivedNumbers(numbers1, doneSum1) // collect received events for event2 diff --git a/libs/pubsub/example_test.go b/libs/pubsub/example_test.go index cae644f7b..4d317215f 100644 --- a/libs/pubsub/example_test.go +++ b/libs/pubsub/example_test.go @@ -12,8 +12,9 @@ import ( ) func TestExample(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := newTestServer(ctx, t) sub := newTestSub(t).must(s.SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ ClientID: "example-client", diff --git a/libs/pubsub/pubsub.go b/libs/pubsub/pubsub.go index c1224c642..930dd47bc 100644 --- a/libs/pubsub/pubsub.go +++ b/libs/pubsub/pubsub.go @@ -341,7 +341,7 @@ func (s *Server) OnStop() { s.stop() } func (s *Server) Wait() { <-s.exited; s.BaseService.Wait() } // OnStart implements Service.OnStart by starting the server. -func (s *Server) OnStart() error { s.run(); return nil } +func (s *Server) OnStart(ctx context.Context) error { s.run(); return nil } // OnReset implements Service.OnReset. It has no effect for this service. 
func (s *Server) OnReset() error { return nil } diff --git a/libs/pubsub/pubsub_test.go b/libs/pubsub/pubsub_test.go index 8dcf8b3d9..be7f3e6e0 100644 --- a/libs/pubsub/pubsub_test.go +++ b/libs/pubsub/pubsub_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" abci "github.com/tendermint/tendermint/abci/types" "github.com/tendermint/tendermint/libs/log" @@ -20,8 +19,10 @@ const ( ) func TestSubscribeWithArgs(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) t.Run("DefaultLimit", func(t *testing.T) { sub := newTestSub(t).must(s.SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ @@ -47,8 +48,10 @@ func TestSubscribeWithArgs(t *testing.T) { } func TestObserver(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) done := make(chan struct{}) var got interface{} @@ -65,8 +68,10 @@ func TestObserver(t *testing.T) { } func TestObserverErrors(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) require.Error(t, s.Observe(ctx, nil, query.Empty{})) require.NoError(t, s.Observe(ctx, func(pubsub.Message) error { return nil })) @@ -74,8 +79,10 @@ func TestObserverErrors(t *testing.T) { } func TestPublishDoesNotBlock(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) sub := newTestSub(t).must(s.SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ ClientID: clientID, @@ -100,8 +107,10 @@ func TestPublishDoesNotBlock(t *testing.T) { } func TestSubscribeErrors(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) t.Run("EmptyQueryErr", func(t *testing.T) { _, err := s.SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ClientID: clientID}) @@ -118,8 +127,10 @@ func TestSubscribeErrors(t *testing.T) { } func TestSlowSubscriber(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) sub := newTestSub(t).must(s.SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ ClientID: clientID, @@ -137,8 +148,10 @@ func TestSlowSubscriber(t *testing.T) { } func TestDifferentClients(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) sub1 := newTestSub(t).must(s.SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ ClientID: "client-1", @@ -188,8 +201,10 @@ func TestDifferentClients(t *testing.T) { } func TestSubscribeDuplicateKeys(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) testCases := []struct { query string @@ -241,8 +256,10 @@ func TestSubscribeDuplicateKeys(t *testing.T) { } func TestClientSubscribesTwice(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) q := 
query.MustParse("tm.events.type='NewBlock'") events := []abci.Event{{ @@ -274,8 +291,10 @@ func TestClientSubscribesTwice(t *testing.T) { } func TestUnsubscribe(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) sub := newTestSub(t).must(s.SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ ClientID: clientID, @@ -296,8 +315,10 @@ func TestUnsubscribe(t *testing.T) { } func TestClientUnsubscribesTwice(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) newTestSub(t).must(s.SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ ClientID: clientID, @@ -315,8 +336,10 @@ func TestClientUnsubscribesTwice(t *testing.T) { } func TestResubscribe(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) args := pubsub.SubscribeArgs{ ClientID: clientID, @@ -336,8 +359,10 @@ func TestResubscribe(t *testing.T) { } func TestUnsubscribeAll(t *testing.T) { - s := newTestServer(t) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newTestServer(ctx, t) sub1 := newTestSub(t).must(s.SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ ClientID: clientID, @@ -364,28 +389,27 @@ func TestBufferCapacity(t *testing.T) { require.Equal(t, 2, s.BufferCapacity()) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() require.NoError(t, s.Publish(ctx, "Nighthawk")) require.NoError(t, s.Publish(ctx, "Sage")) - ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() require.ErrorIs(t, s.Publish(ctx, "Ironclad"), context.DeadlineExceeded) } -func newTestServer(t testing.TB) *pubsub.Server { +func newTestServer(ctx context.Context, t testing.TB) *pubsub.Server { t.Helper() s := pubsub.NewServer(func(s *pubsub.Server) { s.Logger = log.TestingLogger() }) - require.NoError(t, s.Start()) - t.Cleanup(func() { - assert.NoError(t, s.Stop()) - }) + require.NoError(t, s.Start(ctx)) + t.Cleanup(s.Wait) return s } diff --git a/libs/service/service.go b/libs/service/service.go index 88c25d804..0fc26dbdb 100644 --- a/libs/service/service.go +++ b/libs/service/service.go @@ -1,8 +1,8 @@ package service import ( + "context" "errors" - "fmt" "sync/atomic" "github.com/tendermint/tendermint/libs/log" @@ -22,22 +22,10 @@ var ( // Service defines a service that can be started, stopped, and reset. type Service interface { - // Start the service. - // If it's already started or stopped, will return an error. - // If OnStart() returns an error, it's returned by Start() - Start() error - OnStart() error - - // Stop the service. - // If it's already stopped, will return an error. - // OnStop must never error. - Stop() error - OnStop() - - // Reset the service. - // Panics by default - must be overwritten to enable reset. - Reset() error - OnReset() error + // Start is called to start the service, which should run until + // the context terminates. If the service is already running, Start + // must report an error. 
+ Start(context.Context) error // Return true if the service is running IsRunning() bool @@ -52,6 +40,18 @@ type Service interface { Wait() } +// Implementation describes the implementation that the +// BaseService implementation wraps. +type Implementation interface { + Service + + // Called by the Services Start Method + OnStart(context.Context) error + + // Called when the service's context is canceled. + OnStop() +} + /* Classical-inheritance-style service declarations. Services can be started, then stopped, then optionally restarted. @@ -82,7 +82,7 @@ Typical usage: return fs } - func (fs *FooService) OnStart() error { + func (fs *FooService) OnStart(ctx context.Context) error { fs.BaseService.OnStart() // Always call the overridden method. // initialize private fields // start subroutines, etc. @@ -102,11 +102,11 @@ type BaseService struct { quit chan struct{} // The "subclass" of BaseService - impl Service + impl Implementation } // NewBaseService creates a new BaseService. -func NewBaseService(logger log.Logger, name string, impl Service) *BaseService { +func NewBaseService(logger log.Logger, name string, impl Implementation) *BaseService { if logger == nil { logger = log.NewNopLogger() } @@ -119,10 +119,10 @@ func NewBaseService(logger log.Logger, name string, impl Service) *BaseService { } } -// Start implements Service by calling OnStart (if defined). An error will be -// returned if the service is already running or stopped. Not to start the -// stopped service, you need to call Reset. -func (bs *BaseService) Start() error { +// Start starts the Service and calls its OnStart method. An error will be +// returned if the service is already running or stopped. To restart a +// stopped service, call Reset. +func (bs *BaseService) Start(ctx context.Context) error { if atomic.CompareAndSwapUint32(&bs.started, 0, 1) { if atomic.LoadUint32(&bs.stopped) == 1 { bs.Logger.Error("not starting service; already stopped", "service", bs.name, "impl", bs.impl.String()) @@ -132,11 +132,26 @@ func (bs *BaseService) Start() error { bs.Logger.Info("starting service", "service", bs.name, "impl", bs.impl.String()) - if err := bs.impl.OnStart(); err != nil { + if err := bs.impl.OnStart(ctx); err != nil { // revert flag atomic.StoreUint32(&bs.started, 0) return err } + + go func(ctx context.Context) { + <-ctx.Done() + if err := bs.Stop(); err != nil { + bs.Logger.Error("stopped service", + "err", err.Error(), + "service", bs.name, + "impl", bs.impl.String()) + } + + bs.Logger.Info("stopped service", + "service", bs.name, + "impl", bs.impl.String()) + }(ctx) + return nil } @@ -147,7 +162,7 @@ func (bs *BaseService) Start() error { // OnStart implements Service by doing nothing. // NOTE: Do not put anything in here, // that way users don't need to call BaseService.OnStart() -func (bs *BaseService) OnStart() error { return nil } +func (bs *BaseService) OnStart(ctx context.Context) error { return nil } // Stop implements Service by calling OnStop (if defined) and closing quit // channel. An error will be returned if the service is already stopped. @@ -175,26 +190,6 @@ func (bs *BaseService) Stop() error { // that way users don't need to call BaseService.OnStop() func (bs *BaseService) OnStop() {} -// Reset implements Service by calling OnReset callback (if defined). An error -// will be returned if the service is running. 
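
For orientation: Reset and OnReset (deleted just below) get no replacement, since a stopped service is now expected to be rebuilt rather than reset, and shutdown is driven by cancelling the context passed to Start. A minimal sketch of a service written against the new Implementation contract; the echoService type, its ticker loop, and main are illustrative assumptions, not code from this change:

package main

import (
	"context"
	"time"

	"github.com/tendermint/tendermint/libs/log"
	"github.com/tendermint/tendermint/libs/service"
)

// echoService is a hypothetical example type; it logs a tick until the
// context given to Start is cancelled.
type echoService struct {
	service.BaseService
}

func newEchoService(logger log.Logger) *echoService {
	s := &echoService{}
	s.BaseService = *service.NewBaseService(logger, "echo", s)
	return s
}

// OnStart receives the context passed to Start and runs until it is cancelled.
func (s *echoService) OnStart(ctx context.Context) error {
	go func() {
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				s.Logger.Info("tick")
			}
		}
	}()
	return nil
}

// OnStop has nothing extra to tear down; the worker exits via ctx.
func (s *echoService) OnStop() {}

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	svc := newEchoService(log.NewNopLogger())
	if err := svc.Start(ctx); err != nil {
		panic(err)
	}

	time.Sleep(3 * time.Second)
	cancel()   // cancelling the context makes BaseService call Stop for us
	svc.Wait() // blocks until the service has fully stopped
}
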
-func (bs *BaseService) Reset() error { - if !atomic.CompareAndSwapUint32(&bs.stopped, 1, 0) { - bs.Logger.Debug("cannot reset service; not stopped", "service", bs.name, "impl", bs.impl.String()) - return fmt.Errorf("can't reset running %s", bs.name) - } - - // whether or not we've started, we can reset - atomic.CompareAndSwapUint32(&bs.started, 1, 0) - - bs.quit = make(chan struct{}) - return bs.impl.OnReset() -} - -// OnReset implements Service by panicking. -func (bs *BaseService) OnReset() error { - panic("The service cannot be reset") -} - // IsRunning implements Service by returning true or false depending on the // service's state. func (bs *BaseService) IsRunning() bool { @@ -202,16 +197,10 @@ func (bs *BaseService) IsRunning() bool { } // Wait blocks until the service is stopped. -func (bs *BaseService) Wait() { - <-bs.quit -} +func (bs *BaseService) Wait() { <-bs.quit } // String implements Service by returning a string representation of the service. -func (bs *BaseService) String() string { - return bs.name -} +func (bs *BaseService) String() string { return bs.name } // Quit Implements Service by returning a quit channel. -func (bs *BaseService) Quit() <-chan struct{} { - return bs.quit -} +func (bs *BaseService) Quit() <-chan struct{} { return bs.quit } diff --git a/libs/service/service_test.go b/libs/service/service_test.go index 7abc6f4fb..dc5d0ccb1 100644 --- a/libs/service/service_test.go +++ b/libs/service/service_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "testing" "time" @@ -16,9 +17,12 @@ func (testService) OnReset() error { } func TestBaseServiceWait(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ts := &testService{} ts.BaseService = *NewBaseService(nil, "TestService", ts) - err := ts.Start() + err := ts.Start(ctx) require.NoError(t, err) waitFinished := make(chan struct{}) @@ -36,22 +40,3 @@ func TestBaseServiceWait(t *testing.T) { t.Fatal("expected Wait() to finish within 100 ms.") } } - -func TestBaseServiceReset(t *testing.T) { - ts := &testService{} - ts.BaseService = *NewBaseService(nil, "TestService", ts) - err := ts.Start() - require.NoError(t, err) - - err = ts.Reset() - require.Error(t, err, "expected cant reset service error") - - err = ts.Stop() - require.NoError(t, err) - - err = ts.Reset() - require.NoError(t, err) - - err = ts.Start() - require.NoError(t, err) -} diff --git a/libs/strings/string.go b/libs/strings/string.go index b09c00063..6cc0b18ee 100644 --- a/libs/strings/string.go +++ b/libs/strings/string.go @@ -55,6 +55,10 @@ func SplitAndTrim(s, sep, cutset string) []string { return spl } +// TrimSpace removes all leading and trailing whitespace from the +// string. +func TrimSpace(s string) string { return strings.TrimSpace(s) } + // Returns true if s is a non-empty printable non-tab ascii character. func IsASCIIText(s string) bool { if len(s) == 0 { diff --git a/light/proxy/proxy.go b/light/proxy/proxy.go index 6f2622588..f8c183308 100644 --- a/light/proxy/proxy.go +++ b/light/proxy/proxy.go @@ -49,8 +49,8 @@ func NewProxy( // routes to proxy via Client, and starts up an HTTP server on the TCP network // address p.Addr. // See http#Server#ListenAndServe. 
-func (p *Proxy) ListenAndServe() error { - listener, mux, err := p.listen() +func (p *Proxy) ListenAndServe(ctx context.Context) error { + listener, mux, err := p.listen(ctx) if err != nil { return err } @@ -67,8 +67,8 @@ func (p *Proxy) ListenAndServe() error { // ListenAndServeTLS acts identically to ListenAndServe, except that it expects // HTTPS connections. // See http#Server#ListenAndServeTLS. -func (p *Proxy) ListenAndServeTLS(certFile, keyFile string) error { - listener, mux, err := p.listen() +func (p *Proxy) ListenAndServeTLS(ctx context.Context, certFile, keyFile string) error { + listener, mux, err := p.listen(ctx) if err != nil { return err } @@ -84,7 +84,7 @@ func (p *Proxy) ListenAndServeTLS(certFile, keyFile string) error { ) } -func (p *Proxy) listen() (net.Listener, *http.ServeMux, error) { +func (p *Proxy) listen(ctx context.Context) (net.Listener, *http.ServeMux, error) { mux := http.NewServeMux() // 1) Register regular routes. @@ -107,7 +107,7 @@ func (p *Proxy) listen() (net.Listener, *http.ServeMux, error) { // 3) Start a client. if !p.Client.IsRunning() { - if err := p.Client.Start(); err != nil { + if err := p.Client.Start(ctx); err != nil { return nil, mux, fmt.Errorf("can't start client: %w", err) } } diff --git a/light/rpc/client.go b/light/rpc/client.go index dc745542e..6143338f4 100644 --- a/light/rpc/client.go +++ b/light/rpc/client.go @@ -98,9 +98,9 @@ func NewClient(next rpcclient.Client, lc LightClient, opts ...Option) *Client { return c } -func (c *Client) OnStart() error { +func (c *Client) OnStart(ctx context.Context) error { if !c.next.IsRunning() { - return c.next.Start() + return c.next.Start(ctx) } return nil } diff --git a/node/node.go b/node/node.go index a0b77823b..7d3b56b47 100644 --- a/node/node.go +++ b/node/node.go @@ -82,7 +82,11 @@ type nodeImpl struct { // newDefaultNode returns a Tendermint node with default settings for the // PrivValidator, ClientCreator, GenesisDoc, and DBProvider. // It implements NodeProvider. -func newDefaultNode(cfg *config.Config, logger log.Logger) (service.Service, error) { +func newDefaultNode( + ctx context.Context, + cfg *config.Config, + logger log.Logger, +) (service.Service, error) { nodeKey, err := types.LoadOrGenNodeKey(cfg.NodeKeyFile()) if err != nil { return nil, fmt.Errorf("failed to load or gen node key %s: %w", cfg.NodeKeyFile(), err) @@ -108,7 +112,9 @@ func newDefaultNode(cfg *config.Config, logger log.Logger) (service.Service, err appClient, _ := proxy.DefaultClientCreator(logger, cfg.ProxyApp, cfg.ABCI, cfg.DBDir()) - return makeNode(cfg, + return makeNode( + ctx, + cfg, pval, nodeKey, appClient, @@ -119,7 +125,9 @@ func newDefaultNode(cfg *config.Config, logger log.Logger) (service.Service, err } // makeNode returns a new, ready to go, Tendermint Node. 
-func makeNode(cfg *config.Config, +func makeNode( + ctx context.Context, + cfg *config.Config, privValidator types.PrivValidator, nodeKey types.NodeKey, clientCreator abciclient.Creator, @@ -127,7 +135,10 @@ func makeNode(cfg *config.Config, dbProvider config.DBProvider, logger log.Logger, ) (service.Service, error) { - closers := []closer{} + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + + closers := []closer{convertCancelCloser(cancel)} blockStore, stateDB, dbCloser, err := initDBs(cfg, dbProvider) if err != nil { @@ -157,7 +168,7 @@ func makeNode(cfg *config.Config, nodeMetrics := defaultMetricsProvider(cfg.Instrumentation)(genDoc.ChainID) // Create the proxyApp and establish connections to the ABCI app (consensus, mempool, query). - proxyApp, err := createAndStartProxyAppConns(clientCreator, logger, nodeMetrics.proxy) + proxyApp, err := createAndStartProxyAppConns(ctx, clientCreator, logger, nodeMetrics.proxy) if err != nil { return nil, combineCloseError(err, makeCloser(closers)) } @@ -166,12 +177,13 @@ func makeNode(cfg *config.Config, // we might need to index the txs of the replayed block as this might not have happened // when the node stopped last time (i.e. the node stopped after it saved the block // but before it indexed the txs, or, endblocker panicked) - eventBus, err := createAndStartEventBus(logger) + eventBus, err := createAndStartEventBus(ctx, logger) if err != nil { return nil, combineCloseError(err, makeCloser(closers)) } - indexerService, eventSinks, err := createAndStartIndexerService(cfg, dbProvider, eventBus, + indexerService, eventSinks, err := createAndStartIndexerService( + ctx, cfg, dbProvider, eventBus, logger, genDoc.ChainID, nodeMetrics.indexer) if err != nil { return nil, combineCloseError(err, makeCloser(closers)) @@ -184,14 +196,19 @@ func makeNode(cfg *config.Config, // FIXME: we should start services inside OnStart switch protocol { case "grpc": - privValidator, err = createAndStartPrivValidatorGRPCClient(cfg, genDoc.ChainID, logger) + privValidator, err = createAndStartPrivValidatorGRPCClient(ctx, cfg, genDoc.ChainID, logger) if err != nil { return nil, combineCloseError( fmt.Errorf("error with private validator grpc client: %w", err), makeCloser(closers)) } default: - privValidator, err = createAndStartPrivValidatorSocketClient(cfg.PrivValidator.ListenAddr, genDoc.ChainID, logger) + privValidator, err = createAndStartPrivValidatorSocketClient( + ctx, + cfg.PrivValidator.ListenAddr, + genDoc.ChainID, + logger, + ) if err != nil { return nil, combineCloseError( fmt.Errorf("error with private validator socket client: %w", err), @@ -201,7 +218,7 @@ func makeNode(cfg *config.Config, } var pubKey crypto.PubKey if cfg.Mode == config.ModeValidator { - pubKey, err = privValidator.GetPubKey(context.TODO()) + pubKey, err = privValidator.GetPubKey(ctx) if err != nil { return nil, combineCloseError(fmt.Errorf("can't get pubkey: %w", err), makeCloser(closers)) @@ -227,7 +244,7 @@ func makeNode(cfg *config.Config, if err := consensus.NewHandshaker( logger.With("module", "handshaker"), stateStore, state, blockStore, eventBus, genDoc, - ).Handshake(proxyApp); err != nil { + ).Handshake(ctx, proxyApp); err != nil { return nil, combineCloseError(err, makeCloser(closers)) } @@ -253,7 +270,6 @@ func makeNode(cfg *config.Config, nodeInfo, err := makeNodeInfo(cfg, nodeKey, eventSinks, genDoc, state) if err != nil { return nil, combineCloseError(err, makeCloser(closers)) - } peerManager, peerCloser, err := createPeerManager(cfg, dbProvider, 
nodeKey.ID) @@ -492,7 +508,7 @@ func makeSeedNode(cfg *config.Config, } // OnStart starts the Node. It implements service.Service. -func (n *nodeImpl) OnStart() error { +func (n *nodeImpl) OnStart(ctx context.Context) error { if n.config.RPC.PprofListenAddress != "" { // this service is not cleaned up (I believe that we'd // need to have another thread and a potentially a @@ -513,7 +529,7 @@ func (n *nodeImpl) OnStart() error { // Start the RPC server before the P2P server // so we can eg. receive txs for the first block if n.config.RPC.ListenAddress != "" && n.config.Mode != config.ModeSeed { - listeners, err := n.startRPC() + listeners, err := n.startRPC(ctx) if err != nil { return err } @@ -526,39 +542,39 @@ func (n *nodeImpl) OnStart() error { } // Start the transport. - if err := n.router.Start(); err != nil { + if err := n.router.Start(ctx); err != nil { return err } n.isListening = true if n.config.Mode != config.ModeSeed { - if err := n.bcReactor.Start(); err != nil { + if err := n.bcReactor.Start(ctx); err != nil { return err } // Start the real consensus reactor separately since the switch uses the shim. - if err := n.consensusReactor.Start(); err != nil { + if err := n.consensusReactor.Start(ctx); err != nil { return err } // Start the real state sync reactor separately since the switch uses the shim. - if err := n.stateSyncReactor.Start(); err != nil { + if err := n.stateSyncReactor.Start(ctx); err != nil { return err } // Start the real mempool reactor separately since the switch uses the shim. - if err := n.mempoolReactor.Start(); err != nil { + if err := n.mempoolReactor.Start(ctx); err != nil { return err } // Start the real evidence reactor separately since the switch uses the shim. - if err := n.evidenceReactor.Start(); err != nil { + if err := n.evidenceReactor.Start(ctx); err != nil { return err } } if n.config.P2P.PexReactor { - if err := n.pexReactor.Start(); err != nil { + if err := n.pexReactor.Start(ctx); err != nil { return err } } @@ -591,7 +607,7 @@ func (n *nodeImpl) OnStart() error { // bubbling up the error and gracefully shutting down the rest of the node go func() { n.Logger.Info("starting state sync") - state, err := n.stateSyncReactor.Sync(context.TODO()) + state, err := n.stateSyncReactor.Sync(ctx) if err != nil { n.Logger.Error("state sync failed; shutting down this node", "err", err) // stop the node @@ -617,7 +633,7 @@ func (n *nodeImpl) OnStart() error { // is running // FIXME Very ugly to have these metrics bleed through here. n.consensusReactor.SetBlockSyncingMetrics(1) - if err := bcR.SwitchToBlockSync(state); err != nil { + if err := bcR.SwitchToBlockSync(ctx, state); err != nil { n.Logger.Error("failed to switch to block sync", "err", err) return } @@ -638,19 +654,13 @@ func (n *nodeImpl) OnStart() error { // OnStop stops the Node. It implements service.Service. func (n *nodeImpl) OnStop() { - n.Logger.Info("Stopping Node") if n.eventBus != nil { - // first stop the non-reactor services - if err := n.eventBus.Stop(); err != nil { - n.Logger.Error("Error closing eventBus", "err", err) - } + n.eventBus.Wait() } if n.indexerService != nil { - if err := n.indexerService.Stop(); err != nil { - n.Logger.Error("Error closing indexerService", "err", err) - } + n.indexerService.Wait() } for _, es := range n.eventSinks { @@ -660,41 +670,14 @@ func (n *nodeImpl) OnStop() { } if n.config.Mode != config.ModeSeed { - // now stop the reactors - - // Stop the real blockchain reactor separately since the switch uses the shim. 
- if err := n.bcReactor.Stop(); err != nil { - n.Logger.Error("failed to stop the blockchain reactor", "err", err) - } - - // Stop the real consensus reactor separately since the switch uses the shim. - if err := n.consensusReactor.Stop(); err != nil { - n.Logger.Error("failed to stop the consensus reactor", "err", err) - } - - // Stop the real state sync reactor separately since the switch uses the shim. - if err := n.stateSyncReactor.Stop(); err != nil { - n.Logger.Error("failed to stop the state sync reactor", "err", err) - } - - // Stop the real mempool reactor separately since the switch uses the shim. - if err := n.mempoolReactor.Stop(); err != nil { - n.Logger.Error("failed to stop the mempool reactor", "err", err) - } - - // Stop the real evidence reactor separately since the switch uses the shim. - if err := n.evidenceReactor.Stop(); err != nil { - n.Logger.Error("failed to stop the evidence reactor", "err", err) - } - } - - if err := n.pexReactor.Stop(); err != nil { - n.Logger.Error("failed to stop the PEX v2 reactor", "err", err) - } - - if err := n.router.Stop(); err != nil { - n.Logger.Error("failed to stop router", "err", err) + n.bcReactor.Wait() + n.consensusReactor.Wait() + n.stateSyncReactor.Wait() + n.mempoolReactor.Wait() + n.evidenceReactor.Wait() } + n.pexReactor.Wait() + n.router.Wait() n.isListening = false // finally stop the listeners / external services @@ -706,9 +689,7 @@ func (n *nodeImpl) OnStop() { } if pvsc, ok := n.privValidator.(service.Service); ok { - if err := pvsc.Stop(); err != nil { - n.Logger.Error("Error closing private validator", "err", err) - } + pvsc.Wait() } if n.prometheusSrv != nil { @@ -719,13 +700,15 @@ func (n *nodeImpl) OnStop() { } if err := n.shutdownOps(); err != nil { - n.Logger.Error("problem shutting down additional services", "err", err) + if strings.TrimSpace(err.Error()) != "" { + n.Logger.Error("problem shutting down additional services", "err", err) + } } } -func (n *nodeImpl) startRPC() ([]net.Listener, error) { +func (n *nodeImpl) startRPC(ctx context.Context) ([]net.Listener, error) { if n.config.Mode == config.ModeValidator { - pubKey, err := n.privValidator.GetPubKey(context.TODO()) + pubKey, err := n.privValidator.GetPubKey(ctx) if pubKey == nil || err != nil { return nil, fmt.Errorf("can't get pubkey: %w", err) } @@ -970,8 +953,8 @@ func loadStateFromDBOrGenesisDocProvider( } func createAndStartPrivValidatorSocketClient( - listenAddr, - chainID string, + ctx context.Context, + listenAddr, chainID string, logger log.Logger, ) (types.PrivValidator, error) { @@ -980,13 +963,13 @@ func createAndStartPrivValidatorSocketClient( return nil, fmt.Errorf("failed to start private validator: %w", err) } - pvsc, err := privval.NewSignerClient(pve, chainID) + pvsc, err := privval.NewSignerClient(ctx, pve, chainID) if err != nil { return nil, fmt.Errorf("failed to start private validator: %w", err) } // try to get a pubkey from private validate first time - _, err = pvsc.GetPubKey(context.TODO()) + _, err = pvsc.GetPubKey(ctx) if err != nil { return nil, fmt.Errorf("can't get pubkey: %w", err) } @@ -1001,6 +984,7 @@ func createAndStartPrivValidatorSocketClient( } func createAndStartPrivValidatorGRPCClient( + ctx context.Context, cfg *config.Config, chainID string, logger log.Logger, @@ -1016,7 +1000,7 @@ func createAndStartPrivValidatorGRPCClient( } // try to get a pubkey from private validate first time - _, err = pvsc.GetPubKey(context.TODO()) + _, err = pvsc.GetPubKey(ctx) if err != nil { return nil, fmt.Errorf("can't get pubkey: 
%w", err) } @@ -1031,7 +1015,7 @@ func getRouterConfig(conf *config.Config, proxyApp proxy.AppConns) p2p.RouterOpt if conf.FilterPeers && proxyApp != nil { opts.FilterPeerByID = func(ctx context.Context, id types.NodeID) error { - res, err := proxyApp.Query().QuerySync(context.Background(), abci.RequestQuery{ + res, err := proxyApp.Query().QuerySync(ctx, abci.RequestQuery{ Path: fmt.Sprintf("/p2p/filter/id/%s", id), }) if err != nil { diff --git a/node/node_test.go b/node/node_test.go index e9c2159ed..407bf93ea 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -7,7 +7,6 @@ import ( "math" "net" "os" - "syscall" "testing" "time" @@ -43,14 +42,17 @@ func TestNodeStartStop(t *testing.T) { defer os.RemoveAll(cfg.RootDir) + ctx, bcancel := context.WithCancel(context.Background()) + defer bcancel() + // create & start node - ns, err := newDefaultNode(cfg, log.TestingLogger()) + ns, err := newDefaultNode(ctx, cfg, log.TestingLogger()) require.NoError(t, err) - require.NoError(t, ns.Start()) + require.NoError(t, ns.Start(ctx)) t.Cleanup(func() { if ns.IsRunning() { - assert.NoError(t, ns.Stop()) + bcancel() ns.Wait() } }) @@ -58,9 +60,6 @@ func TestNodeStartStop(t *testing.T) { n, ok := ns.(*nodeImpl) require.True(t, ok) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // wait for the node to produce a block blocksSub, err := n.EventBus().SubscribeWithArgs(ctx, pubsub.SubscribeArgs{ ClientID: "node_test", @@ -75,35 +74,35 @@ func TestNodeStartStop(t *testing.T) { // stop the node go func() { - err = n.Stop() - require.NoError(t, err) + bcancel() + n.Wait() }() select { case <-n.Quit(): - case <-time.After(5 * time.Second): - pid := os.Getpid() - p, err := os.FindProcess(pid) - if err != nil { - panic(err) + return + case <-time.After(10 * time.Second): + if n.IsRunning() { + t.Fatal("timed out waiting for shutdown") } - err = p.Signal(syscall.SIGABRT) - fmt.Println(err) - t.Fatal("timed out waiting for shutdown") + } } -func getTestNode(t *testing.T, conf *config.Config, logger log.Logger) *nodeImpl { +func getTestNode(ctx context.Context, t *testing.T, conf *config.Config, logger log.Logger) *nodeImpl { t.Helper() - ns, err := newDefaultNode(conf, logger) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + ns, err := newDefaultNode(ctx, conf, logger) require.NoError(t, err) n, ok := ns.(*nodeImpl) require.True(t, ok) t.Cleanup(func() { - if ns.IsRunning() { - assert.NoError(t, ns.Stop()) + cancel() + if n.IsRunning() { ns.Wait() } }) @@ -118,11 +117,14 @@ func TestNodeDelayedStart(t *testing.T) { defer os.RemoveAll(cfg.RootDir) now := tmtime.Now() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // create & start node - n := getTestNode(t, cfg, log.TestingLogger()) + n := getTestNode(ctx, t, cfg, log.TestingLogger()) n.GenesisDoc().GenesisTime = now.Add(2 * time.Second) - require.NoError(t, n.Start()) + require.NoError(t, n.Start(ctx)) startTime := tmtime.Now() assert.Equal(t, true, startTime.After(n.GenesisDoc().GenesisTime)) @@ -133,8 +135,11 @@ func TestNodeSetAppVersion(t *testing.T) { require.NoError(t, err) defer os.RemoveAll(cfg.RootDir) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // create node - n := getTestNode(t, cfg, log.TestingLogger()) + n := getTestNode(ctx, t, cfg, log.TestingLogger()) // default config uses the kvstore app appVersion := kvstore.ProtocolVersion @@ -151,6 +156,9 @@ func TestNodeSetAppVersion(t *testing.T) { func TestNodeSetPrivValTCP(t *testing.T) { addr := 
"tcp://" + testFreeAddr(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg, err := config.ResetTestRoot("node_priv_val_tcp_test") require.NoError(t, err) defer os.RemoveAll(cfg.RootDir) @@ -170,31 +178,34 @@ func TestNodeSetPrivValTCP(t *testing.T) { ) go func() { - err := signerServer.Start() + err := signerServer.Start(ctx) if err != nil { panic(err) } }() defer signerServer.Stop() //nolint:errcheck // ignore for tests - n := getTestNode(t, cfg, log.TestingLogger()) + n := getTestNode(ctx, t, cfg, log.TestingLogger()) assert.IsType(t, &privval.RetrySignerClient{}, n.PrivValidator()) } // address without a protocol must result in error func TestPrivValidatorListenAddrNoProtocol(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + addrNoPrefix := testFreeAddr(t) cfg, err := config.ResetTestRoot("node_priv_val_tcp_test") require.NoError(t, err) defer os.RemoveAll(cfg.RootDir) cfg.PrivValidator.ListenAddr = addrNoPrefix + n, err := newDefaultNode(ctx, cfg, log.TestingLogger()) - n, err := newDefaultNode(cfg, log.TestingLogger()) assert.Error(t, err) if n != nil && n.IsRunning() { - assert.NoError(t, n.Stop()) + cancel() n.Wait() } } @@ -203,6 +214,9 @@ func TestNodeSetPrivValIPC(t *testing.T) { tmpfile := "/tmp/kms." + tmrand.Str(6) + ".sock" defer os.Remove(tmpfile) // clean up + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg, err := config.ResetTestRoot("node_priv_val_tcp_test") require.NoError(t, err) defer os.RemoveAll(cfg.RootDir) @@ -222,11 +236,11 @@ func TestNodeSetPrivValIPC(t *testing.T) { ) go func() { - err := pvsc.Start() + err := pvsc.Start(ctx) require.NoError(t, err) }() defer pvsc.Stop() //nolint:errcheck // ignore for tests - n := getTestNode(t, cfg, log.TestingLogger()) + n := getTestNode(ctx, t, cfg, log.TestingLogger()) assert.IsType(t, &privval.RetrySignerClient{}, n.PrivValidator()) } @@ -248,11 +262,11 @@ func TestCreateProposalBlock(t *testing.T) { cfg, err := config.ResetTestRoot("node_create_proposal") require.NoError(t, err) defer os.RemoveAll(cfg.RootDir) + cc := abciclient.NewLocalCreator(kvstore.NewApplication()) proxyApp := proxy.NewAppConns(cc, log.TestingLogger(), proxy.NopMetrics()) - err = proxyApp.Start() + err = proxyApp.Start(ctx) require.Nil(t, err) - defer proxyApp.Stop() //nolint:errcheck // ignore for tests logger := log.TestingLogger() @@ -344,9 +358,8 @@ func TestMaxTxsProposalBlockSize(t *testing.T) { defer os.RemoveAll(cfg.RootDir) cc := abciclient.NewLocalCreator(kvstore.NewApplication()) proxyApp := proxy.NewAppConns(cc, log.TestingLogger(), proxy.NopMetrics()) - err = proxyApp.Start() + err = proxyApp.Start(ctx) require.Nil(t, err) - defer proxyApp.Stop() //nolint:errcheck // ignore for tests logger := log.TestingLogger() @@ -408,9 +421,8 @@ func TestMaxProposalBlockSize(t *testing.T) { defer os.RemoveAll(cfg.RootDir) cc := abciclient.NewLocalCreator(kvstore.NewApplication()) proxyApp := proxy.NewAppConns(cc, log.TestingLogger(), proxy.NopMetrics()) - err = proxyApp.Start() + err = proxyApp.Start(ctx) require.Nil(t, err) - defer proxyApp.Stop() //nolint:errcheck // ignore for tests logger := log.TestingLogger() @@ -432,7 +444,7 @@ func TestMaxProposalBlockSize(t *testing.T) { // fill the mempool with one txs just below the maximum size txLength := int(types.MaxDataBytesNoEvidence(maxBytes, types.MaxVotesCount)) tx := tmrand.Bytes(txLength - 6) // to account for the varint - err = mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}) 
+ err = mp.CheckTx(ctx, tx, nil, mempool.TxInfo{}) assert.NoError(t, err) // now produce more txs than what a normal block can hold with 10 smaller txs // At the end of the test, only the single big tx should be added @@ -521,6 +533,9 @@ func TestNodeNewSeedNode(t *testing.T) { cfg.Mode = config.ModeSeed defer os.RemoveAll(cfg.RootDir) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + nodeKey, err := types.LoadOrGenNodeKey(cfg.NodeKeyFile()) require.NoError(t, err) @@ -530,17 +545,20 @@ func TestNodeNewSeedNode(t *testing.T) { defaultGenesisDocProviderFunc(cfg), log.TestingLogger(), ) + t.Cleanup(ns.Wait) require.NoError(t, err) n, ok := ns.(*nodeImpl) require.True(t, ok) - err = n.Start() + err = n.Start(ctx) require.NoError(t, err) assert.True(t, n.pexReactor.IsRunning()) - require.NoError(t, n.Stop()) + cancel() + n.Wait() + assert.False(t, n.pexReactor.IsRunning()) } func TestNodeSetEventSink(t *testing.T) { @@ -549,19 +567,22 @@ func TestNodeSetEventSink(t *testing.T) { defer os.RemoveAll(cfg.RootDir) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := log.TestingLogger() setupTest := func(t *testing.T, conf *config.Config) []indexer.EventSink { - eventBus, err := createAndStartEventBus(logger) + eventBus, err := createAndStartEventBus(ctx, logger) require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, eventBus.Stop()) }) + t.Cleanup(eventBus.Wait) genDoc, err := types.GenesisDocFromFile(cfg.GenesisFile()) require.NoError(t, err) - indexService, eventSinks, err := createAndStartIndexerService(cfg, + indexService, eventSinks, err := createAndStartIndexerService(ctx, cfg, config.DefaultDBProvider, eventBus, logger, genDoc.ChainID, indexer.NopMetrics()) require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, indexService.Stop()) }) + t.Cleanup(indexService.Wait) return eventSinks } cleanup := func(ns service.Service) func() { @@ -576,7 +597,7 @@ func TestNodeSetEventSink(t *testing.T) { if !n.IsRunning() { return } - assert.NoError(t, n.Stop()) + cancel() n.Wait() } } @@ -598,7 +619,7 @@ func TestNodeSetEventSink(t *testing.T) { assert.Equal(t, indexer.NULL, eventSinks[0].Type()) cfg.TxIndex.Indexer = []string{"kvv"} - ns, err := newDefaultNode(cfg, logger) + ns, err := newDefaultNode(ctx, cfg, logger) assert.Nil(t, ns) assert.Contains(t, err.Error(), "unsupported event sink type") t.Cleanup(cleanup(ns)) @@ -610,7 +631,7 @@ func TestNodeSetEventSink(t *testing.T) { assert.Equal(t, indexer.NULL, eventSinks[0].Type()) cfg.TxIndex.Indexer = []string{"psql"} - ns, err = newDefaultNode(cfg, logger) + ns, err = newDefaultNode(ctx, cfg, logger) assert.Nil(t, ns) assert.Contains(t, err.Error(), "the psql connection settings cannot be empty") t.Cleanup(cleanup(ns)) @@ -652,14 +673,14 @@ func TestNodeSetEventSink(t *testing.T) { var e = errors.New("found duplicated sinks, please check the tx-index section in the config.toml") cfg.TxIndex.Indexer = []string{"psql", "kv", "Kv"} cfg.TxIndex.PsqlConn = psqlConn - ns, err = newDefaultNode(cfg, logger) + ns, err = newDefaultNode(ctx, cfg, logger) require.Error(t, err) assert.Contains(t, err.Error(), e.Error()) t.Cleanup(cleanup(ns)) cfg.TxIndex.Indexer = []string{"Psql", "kV", "kv", "pSql"} cfg.TxIndex.PsqlConn = psqlConn - ns, err = newDefaultNode(cfg, logger) + ns, err = newDefaultNode(ctx, cfg, logger) require.Error(t, err) assert.Contains(t, err.Error(), e.Error()) t.Cleanup(cleanup(ns)) diff --git a/node/public.go b/node/public.go index c616eebac..87007bdfc 100644 --- 
a/node/public.go +++ b/node/public.go @@ -2,6 +2,7 @@ package node import ( + "context" "fmt" abciclient "github.com/tendermint/tendermint/abci/client" @@ -16,8 +17,12 @@ import ( // process that host their own process-local tendermint node. This is // equivalent to running tendermint in it's own process communicating // to an external ABCI application. -func NewDefault(conf *config.Config, logger log.Logger) (service.Service, error) { - return newDefaultNode(conf, logger) +func NewDefault( + ctx context.Context, + conf *config.Config, + logger log.Logger, +) (service.Service, error) { + return newDefaultNode(ctx, conf, logger) } // New constructs a tendermint node. The ClientCreator makes it @@ -26,7 +31,9 @@ func NewDefault(conf *config.Config, logger log.Logger) (service.Service, error) // Genesis document: if the value is nil, the genesis document is read // from the file specified in the config, and otherwise the node uses // value of the final argument. -func New(conf *config.Config, +func New( + ctx context.Context, + conf *config.Config, logger log.Logger, cf abciclient.Creator, gen *types.GenesisDoc, @@ -51,7 +58,9 @@ func New(conf *config.Config, return nil, err } - return makeNode(conf, + return makeNode( + ctx, + conf, pval, nodeKey, cf, diff --git a/node/setup.go b/node/setup.go index 297ed0265..6ca991484 100644 --- a/node/setup.go +++ b/node/setup.go @@ -2,6 +2,7 @@ package node import ( "bytes" + "context" "errors" "fmt" "strings" @@ -52,6 +53,10 @@ func makeCloser(cs []closer) closer { } } +func convertCancelCloser(cancel context.CancelFunc) closer { + return func() error { cancel(); return nil } +} + func combineCloseError(err error, cl closer) error { if err == nil { return cl() @@ -88,26 +93,31 @@ func initDBs( return blockStore, stateDB, makeCloser(closers), nil } -// nolint:lll -func createAndStartProxyAppConns(clientCreator abciclient.Creator, logger log.Logger, metrics *proxy.Metrics) (proxy.AppConns, error) { +func createAndStartProxyAppConns( + ctx context.Context, + clientCreator abciclient.Creator, + logger log.Logger, + metrics *proxy.Metrics, +) (proxy.AppConns, error) { proxyApp := proxy.NewAppConns(clientCreator, logger.With("module", "proxy"), metrics) - if err := proxyApp.Start(); err != nil { + if err := proxyApp.Start(ctx); err != nil { return nil, fmt.Errorf("error starting proxy app connections: %v", err) } return proxyApp, nil } -func createAndStartEventBus(logger log.Logger) (*eventbus.EventBus, error) { +func createAndStartEventBus(ctx context.Context, logger log.Logger) (*eventbus.EventBus, error) { eventBus := eventbus.NewDefault(logger.With("module", "events")) - if err := eventBus.Start(); err != nil { + if err := eventBus.Start(ctx); err != nil { return nil, err } return eventBus, nil } func createAndStartIndexerService( + ctx context.Context, cfg *config.Config, dbProvider config.DBProvider, eventBus *eventbus.EventBus, @@ -127,7 +137,7 @@ func createAndStartIndexerService( Metrics: metrics, }) - if err := indexerService.Start(); err != nil { + if err := indexerService.Start(ctx); err != nil { return nil, nil, err } diff --git a/privval/grpc/util.go b/privval/grpc/util.go index 75ad04d42..7e0483f9c 100644 --- a/privval/grpc/util.go +++ b/privval/grpc/util.go @@ -109,7 +109,7 @@ func DialRemoteSigner( dialOptions = append(dialOptions, transportSecurity) - ctx := context.Background() + ctx := context.TODO() _, address := tmnet.ProtocolAndAddress(cfg.ListenAddr) conn, err := grpc.DialContext(ctx, address, dialOptions...) 
if err != nil { diff --git a/privval/signer_client.go b/privval/signer_client.go index 5e5b32a92..ec6d95ca6 100644 --- a/privval/signer_client.go +++ b/privval/signer_client.go @@ -23,9 +23,9 @@ var _ types.PrivValidator = (*SignerClient)(nil) // NewSignerClient returns an instance of SignerClient. // it will start the endpoint (if not already started) -func NewSignerClient(endpoint *SignerListenerEndpoint, chainID string) (*SignerClient, error) { +func NewSignerClient(ctx context.Context, endpoint *SignerListenerEndpoint, chainID string) (*SignerClient, error) { if !endpoint.IsRunning() { - if err := endpoint.Start(); err != nil { + if err := endpoint.Start(ctx); err != nil { return nil, fmt.Errorf("failed to start listener endpoint: %w", err) } } diff --git a/privval/signer_client_test.go b/privval/signer_client_test.go index 9aa49e709..f9272b004 100644 --- a/privval/signer_client_test.go +++ b/privval/signer_client_test.go @@ -23,370 +23,336 @@ type signerTestCase struct { mockPV types.PrivValidator signerClient *SignerClient signerServer *SignerServer + name string + closer context.CancelFunc } -func getSignerTestCases(t *testing.T) []signerTestCase { +func getSignerTestCases(ctx context.Context, t *testing.T) []signerTestCase { + t.Helper() + testCases := make([]signerTestCase, 0) // Get test cases for each possible dialer (DialTCP / DialUnix / etc) - for _, dtc := range getDialerTestCases(t) { + for idx, dtc := range getDialerTestCases(t) { chainID := tmrand.Str(12) mockPV := types.NewMockPV() + cctx, ccancel := context.WithCancel(ctx) // get a pair of signer listener, signer dialer endpoints - sl, sd := getMockEndpoints(t, dtc.addr, dtc.dialer) - sc, err := NewSignerClient(sl, chainID) + sl, sd := getMockEndpoints(cctx, t, dtc.addr, dtc.dialer) + sc, err := NewSignerClient(cctx, sl, chainID) require.NoError(t, err) ss := NewSignerServer(sd, chainID, mockPV) - err = ss.Start() - require.NoError(t, err) + require.NoError(t, ss.Start(cctx)) - tc := signerTestCase{ + testCases = append(testCases, signerTestCase{ + name: fmt.Sprintf("Case%d%T_%s", idx, dtc.dialer, chainID), + closer: ccancel, chainID: chainID, mockPV: mockPV, signerClient: sc, signerServer: ss, - } - - testCases = append(testCases, tc) + }) } return testCases } func TestSignerClose(t *testing.T) { - for _, tc := range getSignerTestCases(t) { - err := tc.signerClient.Close() - assert.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - err = tc.signerServer.Stop() - assert.NoError(t, err) + for _, tc := range getSignerTestCases(ctx, t) { + t.Run(tc.name, func(t *testing.T) { + defer tc.closer() + + assert.NoError(t, tc.signerClient.Close()) + assert.NoError(t, tc.signerServer.Stop()) + }) } } func TestSignerPing(t *testing.T) { - for _, tc := range getSignerTestCases(t) { - tc := tc - t.Cleanup(func() { - if err := tc.signerServer.Stop(); err != nil { - t.Error(err) - } - }) - t.Cleanup(func() { - if err := tc.signerClient.Close(); err != nil { - t.Error(err) - } - }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for _, tc := range getSignerTestCases(ctx, t) { err := tc.signerClient.Ping() assert.NoError(t, err) } } func TestSignerGetPubKey(t *testing.T) { - for _, tc := range getSignerTestCases(t) { - tc := tc - t.Cleanup(func() { - if err := tc.signerServer.Stop(); err != nil { - t.Error(err) - } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for _, tc := range getSignerTestCases(ctx, t) { + t.Run(tc.name, func(t 
*testing.T) { + defer tc.closer() + + pubKey, err := tc.signerClient.GetPubKey(ctx) + require.NoError(t, err) + expectedPubKey, err := tc.mockPV.GetPubKey(ctx) + require.NoError(t, err) + + assert.Equal(t, expectedPubKey, pubKey) + + pubKey, err = tc.signerClient.GetPubKey(ctx) + require.NoError(t, err) + expectedpk, err := tc.mockPV.GetPubKey(ctx) + require.NoError(t, err) + expectedAddr := expectedpk.Address() + + assert.Equal(t, expectedAddr, pubKey.Address()) }) - t.Cleanup(func() { - if err := tc.signerClient.Close(); err != nil { - t.Error(err) - } - }) - - pubKey, err := tc.signerClient.GetPubKey(context.Background()) - require.NoError(t, err) - expectedPubKey, err := tc.mockPV.GetPubKey(context.Background()) - require.NoError(t, err) - - assert.Equal(t, expectedPubKey, pubKey) - - pubKey, err = tc.signerClient.GetPubKey(context.Background()) - require.NoError(t, err) - expectedpk, err := tc.mockPV.GetPubKey(context.Background()) - require.NoError(t, err) - expectedAddr := expectedpk.Address() - - assert.Equal(t, expectedAddr, pubKey.Address()) } } func TestSignerProposal(t *testing.T) { - for _, tc := range getSignerTestCases(t) { - ts := time.Now() - hash := tmrand.Bytes(tmhash.Size) - have := &types.Proposal{ - Type: tmproto.ProposalType, - Height: 1, - Round: 2, - POLRound: 2, - BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, - Timestamp: ts, - } - want := &types.Proposal{ - Type: tmproto.ProposalType, - Height: 1, - Round: 2, - POLRound: 2, - BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, - Timestamp: ts, - } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - tc := tc - t.Cleanup(func() { - if err := tc.signerServer.Stop(); err != nil { - t.Error(err) + for _, tc := range getSignerTestCases(ctx, t) { + t.Run(tc.name, func(t *testing.T) { + defer tc.closer() + + ts := time.Now() + hash := tmrand.Bytes(tmhash.Size) + have := &types.Proposal{ + Type: tmproto.ProposalType, + Height: 1, + Round: 2, + POLRound: 2, + BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, + Timestamp: ts, } - }) - t.Cleanup(func() { - if err := tc.signerClient.Close(); err != nil { - t.Error(err) + want := &types.Proposal{ + Type: tmproto.ProposalType, + Height: 1, + Round: 2, + POLRound: 2, + BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, + Timestamp: ts, } + + require.NoError(t, tc.mockPV.SignProposal(ctx, tc.chainID, want.ToProto())) + require.NoError(t, tc.signerClient.SignProposal(ctx, tc.chainID, have.ToProto())) + + assert.Equal(t, want.Signature, have.Signature) }) - require.NoError(t, tc.mockPV.SignProposal(context.Background(), tc.chainID, want.ToProto())) - require.NoError(t, tc.signerClient.SignProposal(context.Background(), tc.chainID, have.ToProto())) - - assert.Equal(t, want.Signature, have.Signature) } } func TestSignerVote(t *testing.T) { - for _, tc := range getSignerTestCases(t) { - ts := time.Now() - hash := tmrand.Bytes(tmhash.Size) - valAddr := tmrand.Bytes(crypto.AddressSize) - want := &types.Vote{ - Type: tmproto.PrecommitType, - Height: 1, - Round: 2, - BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, - Timestamp: ts, - ValidatorAddress: valAddr, - ValidatorIndex: 1, - } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - have := &types.Vote{ - Type: tmproto.PrecommitType, - Height: 1, - Round: 2, - BlockID: 
types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, - Timestamp: ts, - ValidatorAddress: valAddr, - ValidatorIndex: 1, - } + for _, tc := range getSignerTestCases(ctx, t) { + t.Run(tc.name, func(t *testing.T) { + defer tc.closer() - tc := tc - t.Cleanup(func() { - if err := tc.signerServer.Stop(); err != nil { - t.Error(err) + ts := time.Now() + hash := tmrand.Bytes(tmhash.Size) + valAddr := tmrand.Bytes(crypto.AddressSize) + want := &types.Vote{ + Type: tmproto.PrecommitType, + Height: 1, + Round: 2, + BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, + Timestamp: ts, + ValidatorAddress: valAddr, + ValidatorIndex: 1, } - }) - t.Cleanup(func() { - if err := tc.signerClient.Close(); err != nil { - t.Error(err) + + have := &types.Vote{ + Type: tmproto.PrecommitType, + Height: 1, + Round: 2, + BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, + Timestamp: ts, + ValidatorAddress: valAddr, + ValidatorIndex: 1, } + + require.NoError(t, tc.mockPV.SignVote(ctx, tc.chainID, want.ToProto())) + require.NoError(t, tc.signerClient.SignVote(ctx, tc.chainID, have.ToProto())) + + assert.Equal(t, want.Signature, have.Signature) }) - - require.NoError(t, tc.mockPV.SignVote(context.Background(), tc.chainID, want.ToProto())) - require.NoError(t, tc.signerClient.SignVote(context.Background(), tc.chainID, have.ToProto())) - - assert.Equal(t, want.Signature, have.Signature) } } func TestSignerVoteResetDeadline(t *testing.T) { - for _, tc := range getSignerTestCases(t) { - ts := time.Now() - hash := tmrand.Bytes(tmhash.Size) - valAddr := tmrand.Bytes(crypto.AddressSize) - want := &types.Vote{ - Type: tmproto.PrecommitType, - Height: 1, - Round: 2, - BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, - Timestamp: ts, - ValidatorAddress: valAddr, - ValidatorIndex: 1, - } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - have := &types.Vote{ - Type: tmproto.PrecommitType, - Height: 1, - Round: 2, - BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, - Timestamp: ts, - ValidatorAddress: valAddr, - ValidatorIndex: 1, - } - - tc := tc - t.Cleanup(func() { - if err := tc.signerServer.Stop(); err != nil { - t.Error(err) + for _, tc := range getSignerTestCases(ctx, t) { + t.Run(tc.name, func(t *testing.T) { + ts := time.Now() + hash := tmrand.Bytes(tmhash.Size) + valAddr := tmrand.Bytes(crypto.AddressSize) + want := &types.Vote{ + Type: tmproto.PrecommitType, + Height: 1, + Round: 2, + BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, + Timestamp: ts, + ValidatorAddress: valAddr, + ValidatorIndex: 1, } - }) - t.Cleanup(func() { - if err := tc.signerClient.Close(); err != nil { - t.Error(err) + + have := &types.Vote{ + Type: tmproto.PrecommitType, + Height: 1, + Round: 2, + BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, + Timestamp: ts, + ValidatorAddress: valAddr, + ValidatorIndex: 1, } + + time.Sleep(testTimeoutReadWrite2o3) + + require.NoError(t, tc.mockPV.SignVote(ctx, tc.chainID, want.ToProto())) + require.NoError(t, tc.signerClient.SignVote(ctx, tc.chainID, have.ToProto())) + assert.Equal(t, want.Signature, have.Signature) + + // TODO(jleni): Clarify what is actually being tested + + // This would exceed the deadline if it was not extended by the previous message + 
time.Sleep(testTimeoutReadWrite2o3) + + require.NoError(t, tc.mockPV.SignVote(ctx, tc.chainID, want.ToProto())) + require.NoError(t, tc.signerClient.SignVote(ctx, tc.chainID, have.ToProto())) + assert.Equal(t, want.Signature, have.Signature) }) - - time.Sleep(testTimeoutReadWrite2o3) - - require.NoError(t, tc.mockPV.SignVote(context.Background(), tc.chainID, want.ToProto())) - require.NoError(t, tc.signerClient.SignVote(context.Background(), tc.chainID, have.ToProto())) - assert.Equal(t, want.Signature, have.Signature) - - // TODO(jleni): Clarify what is actually being tested - - // This would exceed the deadline if it was not extended by the previous message - time.Sleep(testTimeoutReadWrite2o3) - - require.NoError(t, tc.mockPV.SignVote(context.Background(), tc.chainID, want.ToProto())) - require.NoError(t, tc.signerClient.SignVote(context.Background(), tc.chainID, have.ToProto())) - assert.Equal(t, want.Signature, have.Signature) } } func TestSignerVoteKeepAlive(t *testing.T) { - for _, tc := range getSignerTestCases(t) { - ts := time.Now() - hash := tmrand.Bytes(tmhash.Size) - valAddr := tmrand.Bytes(crypto.AddressSize) - want := &types.Vote{ - Type: tmproto.PrecommitType, - Height: 1, - Round: 2, - BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, - Timestamp: ts, - ValidatorAddress: valAddr, - ValidatorIndex: 1, - } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - have := &types.Vote{ - Type: tmproto.PrecommitType, - Height: 1, - Round: 2, - BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, - Timestamp: ts, - ValidatorAddress: valAddr, - ValidatorIndex: 1, - } + for _, tc := range getSignerTestCases(ctx, t) { + t.Run(tc.name, func(t *testing.T) { + defer tc.closer() - tc := tc - t.Cleanup(func() { - if err := tc.signerServer.Stop(); err != nil { - t.Error(err) + ts := time.Now() + hash := tmrand.Bytes(tmhash.Size) + valAddr := tmrand.Bytes(crypto.AddressSize) + want := &types.Vote{ + Type: tmproto.PrecommitType, + Height: 1, + Round: 2, + BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, + Timestamp: ts, + ValidatorAddress: valAddr, + ValidatorIndex: 1, } - }) - t.Cleanup(func() { - if err := tc.signerClient.Close(); err != nil { - t.Error(err) + + have := &types.Vote{ + Type: tmproto.PrecommitType, + Height: 1, + Round: 2, + BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, + Timestamp: ts, + ValidatorAddress: valAddr, + ValidatorIndex: 1, } + + // Check that even if the client does not request a + // signature for a long time. The service is still available + + // in this particular case, we use the dialer logger to ensure that + // test messages are properly interleaved in the test logs + tc.signerServer.Logger.Debug("TEST: Forced Wait -------------------------------------------------") + time.Sleep(testTimeoutReadWrite * 3) + tc.signerServer.Logger.Debug("TEST: Forced Wait DONE---------------------------------------------") + + require.NoError(t, tc.mockPV.SignVote(ctx, tc.chainID, want.ToProto())) + require.NoError(t, tc.signerClient.SignVote(ctx, tc.chainID, have.ToProto())) + + assert.Equal(t, want.Signature, have.Signature) }) - - // Check that even if the client does not request a - // signature for a long time. 
The service is still available - - // in this particular case, we use the dialer logger to ensure that - // test messages are properly interleaved in the test logs - tc.signerServer.Logger.Debug("TEST: Forced Wait -------------------------------------------------") - time.Sleep(testTimeoutReadWrite * 3) - tc.signerServer.Logger.Debug("TEST: Forced Wait DONE---------------------------------------------") - - require.NoError(t, tc.mockPV.SignVote(context.Background(), tc.chainID, want.ToProto())) - require.NoError(t, tc.signerClient.SignVote(context.Background(), tc.chainID, have.ToProto())) - - assert.Equal(t, want.Signature, have.Signature) } } func TestSignerSignProposalErrors(t *testing.T) { - for _, tc := range getSignerTestCases(t) { - // Replace service with a mock that always fails - tc.signerServer.privVal = types.NewErroringMockPV() - tc.mockPV = types.NewErroringMockPV() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - tc := tc - t.Cleanup(func() { - if err := tc.signerServer.Stop(); err != nil { - t.Error(err) + for _, tc := range getSignerTestCases(ctx, t) { + t.Run(tc.name, func(t *testing.T) { + defer tc.closer() + // Replace service with a mock that always fails + tc.signerServer.privVal = types.NewErroringMockPV() + tc.mockPV = types.NewErroringMockPV() + + ts := time.Now() + hash := tmrand.Bytes(tmhash.Size) + proposal := &types.Proposal{ + Type: tmproto.ProposalType, + Height: 1, + Round: 2, + POLRound: 2, + BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, + Timestamp: ts, + Signature: []byte("signature"), } + + err := tc.signerClient.SignProposal(ctx, tc.chainID, proposal.ToProto()) + rserr, ok := err.(*RemoteSignerError) + require.True(t, ok, "%T", err) + require.Contains(t, rserr.Error(), types.ErroringMockPVErr.Error()) + + err = tc.mockPV.SignProposal(ctx, tc.chainID, proposal.ToProto()) + require.Error(t, err) + + err = tc.signerClient.SignProposal(ctx, tc.chainID, proposal.ToProto()) + require.Error(t, err) }) - t.Cleanup(func() { - if err := tc.signerClient.Close(); err != nil { - t.Error(err) - } - }) - - ts := time.Now() - hash := tmrand.Bytes(tmhash.Size) - proposal := &types.Proposal{ - Type: tmproto.ProposalType, - Height: 1, - Round: 2, - POLRound: 2, - BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, - Timestamp: ts, - Signature: []byte("signature"), - } - - err := tc.signerClient.SignProposal(context.Background(), tc.chainID, proposal.ToProto()) - require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) - - err = tc.mockPV.SignProposal(context.Background(), tc.chainID, proposal.ToProto()) - require.Error(t, err) - - err = tc.signerClient.SignProposal(context.Background(), tc.chainID, proposal.ToProto()) - require.Error(t, err) } } func TestSignerSignVoteErrors(t *testing.T) { - for _, tc := range getSignerTestCases(t) { - ts := time.Now() - hash := tmrand.Bytes(tmhash.Size) - valAddr := tmrand.Bytes(crypto.AddressSize) - vote := &types.Vote{ - Type: tmproto.PrecommitType, - Height: 1, - Round: 2, - BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, - Timestamp: ts, - ValidatorAddress: valAddr, - ValidatorIndex: 1, - Signature: []byte("signature"), - } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - // Replace signer service privval with one that always fails - tc.signerServer.privVal = types.NewErroringMockPV() - tc.mockPV = 
types.NewErroringMockPV() + for _, tc := range getSignerTestCases(ctx, t) { + t.Run(tc.name, func(t *testing.T) { + defer tc.closer() - tc := tc - t.Cleanup(func() { - if err := tc.signerServer.Stop(); err != nil { - t.Error(err) + ts := time.Now() + hash := tmrand.Bytes(tmhash.Size) + valAddr := tmrand.Bytes(crypto.AddressSize) + vote := &types.Vote{ + Type: tmproto.PrecommitType, + Height: 1, + Round: 2, + BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, + Timestamp: ts, + ValidatorAddress: valAddr, + ValidatorIndex: 1, + Signature: []byte("signature"), } + + // Replace signer service privval with one that always fails + tc.signerServer.privVal = types.NewErroringMockPV() + tc.mockPV = types.NewErroringMockPV() + + err := tc.signerClient.SignVote(ctx, tc.chainID, vote.ToProto()) + rserr, ok := err.(*RemoteSignerError) + require.True(t, ok, "%T", err) + require.Contains(t, rserr.Error(), types.ErroringMockPVErr.Error()) + + err = tc.mockPV.SignVote(ctx, tc.chainID, vote.ToProto()) + require.Error(t, err) + + err = tc.signerClient.SignVote(ctx, tc.chainID, vote.ToProto()) + require.Error(t, err) }) - t.Cleanup(func() { - if err := tc.signerClient.Close(); err != nil { - t.Error(err) - } - }) - - err := tc.signerClient.SignVote(context.Background(), tc.chainID, vote.ToProto()) - require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) - - err = tc.mockPV.SignVote(context.Background(), tc.chainID, vote.ToProto()) - require.Error(t, err) - - err = tc.signerClient.SignVote(context.Background(), tc.chainID, vote.ToProto()) - require.Error(t, err) } } @@ -413,28 +379,23 @@ func brokenHandler(ctx context.Context, privVal types.PrivValidator, request pri } func TestSignerUnexpectedResponse(t *testing.T) { - for _, tc := range getSignerTestCases(t) { - tc.signerServer.privVal = types.NewMockPV() - tc.mockPV = types.NewMockPV() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - tc.signerServer.SetRequestHandler(brokenHandler) + for _, tc := range getSignerTestCases(ctx, t) { + t.Run(tc.name, func(t *testing.T) { + defer tc.closer() - tc := tc - t.Cleanup(func() { - if err := tc.signerServer.Stop(); err != nil { - t.Error(err) - } + tc.signerServer.privVal = types.NewMockPV() + tc.mockPV = types.NewMockPV() + + tc.signerServer.SetRequestHandler(brokenHandler) + + ts := time.Now() + want := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType} + + e := tc.signerClient.SignVote(ctx, tc.chainID, want.ToProto()) + assert.EqualError(t, e, "empty response") }) - t.Cleanup(func() { - if err := tc.signerClient.Close(); err != nil { - t.Error(err) - } - }) - - ts := time.Now() - want := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType} - - e := tc.signerClient.SignVote(context.Background(), tc.chainID, want.ToProto()) - assert.EqualError(t, e, "empty response") } } diff --git a/privval/signer_listener_endpoint.go b/privval/signer_listener_endpoint.go index 292e7a476..e2287c630 100644 --- a/privval/signer_listener_endpoint.go +++ b/privval/signer_listener_endpoint.go @@ -1,6 +1,7 @@ package privval import ( + "context" "fmt" "net" "time" @@ -63,7 +64,7 @@ func NewSignerListenerEndpoint( } // OnStart implements service.Service. 
-func (sl *SignerListenerEndpoint) OnStart() error { +func (sl *SignerListenerEndpoint) OnStart(ctx context.Context) error { sl.connectRequestCh = make(chan struct{}) sl.connectionAvailableCh = make(chan net.Conn) diff --git a/privval/signer_listener_endpoint_test.go b/privval/signer_listener_endpoint_test.go index cbd45e6ce..b92e0abe5 100644 --- a/privval/signer_listener_endpoint_test.go +++ b/privval/signer_listener_endpoint_test.go @@ -1,6 +1,7 @@ package privval import ( + "context" "net" "testing" "time" @@ -38,6 +39,9 @@ func TestSignerRemoteRetryTCPOnly(t *testing.T) { retries = 10 ) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -71,7 +75,7 @@ func TestSignerRemoteRetryTCPOnly(t *testing.T) { mockPV := types.NewMockPV() signerServer := NewSignerServer(dialerEndpoint, chainID, mockPV) - err = signerServer.Start() + err = signerServer.Start(ctx) require.NoError(t, err) t.Cleanup(func() { if err := signerServer.Stop(); err != nil { @@ -88,6 +92,9 @@ func TestSignerRemoteRetryTCPOnly(t *testing.T) { } func TestRetryConnToRemoteSigner(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for _, tc := range getDialerTestCases(t) { var ( logger = log.TestingLogger() @@ -107,14 +114,9 @@ func TestRetryConnToRemoteSigner(t *testing.T) { signerServer := NewSignerServer(dialerEndpoint, chainID, mockPV) - startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh) - t.Cleanup(func() { - if err := listenerEndpoint.Stop(); err != nil { - t.Error(err) - } - }) + startListenerEndpointAsync(ctx, t, listenerEndpoint, endpointIsOpenCh) - require.NoError(t, signerServer.Start()) + require.NoError(t, signerServer.Start(ctx)) assert.True(t, signerServer.IsRunning()) <-endpointIsOpenCh if err := signerServer.Stop(); err != nil { @@ -128,13 +130,8 @@ func TestRetryConnToRemoteSigner(t *testing.T) { signerServer2 := NewSignerServer(dialerEndpoint2, chainID, mockPV) // let some pings pass - require.NoError(t, signerServer2.Start()) + require.NoError(t, signerServer2.Start(ctx)) assert.True(t, signerServer2.IsRunning()) - t.Cleanup(func() { - if err := signerServer2.Stop(); err != nil { - t.Error(err) - } - }) // give the client some time to re-establish the conn to the remote signer // should see sth like this in the logs: @@ -175,15 +172,23 @@ func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite ) } -func startListenerEndpointAsync(t *testing.T, sle *SignerListenerEndpoint, endpointIsOpenCh chan struct{}) { +func startListenerEndpointAsync( + ctx context.Context, + t *testing.T, + sle *SignerListenerEndpoint, + endpointIsOpenCh chan struct{}, +) { + t.Helper() + go func(sle *SignerListenerEndpoint) { - require.NoError(t, sle.Start()) + require.NoError(t, sle.Start(ctx)) assert.True(t, sle.IsRunning()) close(endpointIsOpenCh) }(sle) } func getMockEndpoints( + ctx context.Context, t *testing.T, addr string, socketDialer SocketDialer, @@ -204,9 +209,9 @@ func getMockEndpoints( SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint) SignerDialerEndpointConnRetries(1e6)(dialerEndpoint) - startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh) + startListenerEndpointAsync(ctx, t, listenerEndpoint, endpointIsOpenCh) - require.NoError(t, dialerEndpoint.Start()) + require.NoError(t, dialerEndpoint.Start(ctx)) assert.True(t, dialerEndpoint.IsRunning()) <-endpointIsOpenCh diff --git a/privval/signer_server.go 
b/privval/signer_server.go index 24bf67cc5..e31d3bdb4 100644 --- a/privval/signer_server.go +++ b/privval/signer_server.go @@ -42,8 +42,8 @@ func NewSignerServer(endpoint *SignerDialerEndpoint, chainID string, privVal typ } // OnStart implements service.Service. -func (ss *SignerServer) OnStart() error { - go ss.serviceLoop() +func (ss *SignerServer) OnStart(ctx context.Context) error { + go ss.serviceLoop(ctx) return nil } @@ -91,18 +91,18 @@ func (ss *SignerServer) servicePendingRequest() { } } -func (ss *SignerServer) serviceLoop() { +func (ss *SignerServer) serviceLoop(ctx context.Context) { for { select { + case <-ss.Quit(): + return + case <-ctx.Done(): + return default: - err := ss.endpoint.ensureConnection() - if err != nil { + if err := ss.endpoint.ensureConnection(); err != nil { return } ss.servicePendingRequest() - - case <-ss.Quit(): - return } } } diff --git a/rpc/client/http/ws.go b/rpc/client/http/ws.go index 0f908e271..e4c2a14ed 100644 --- a/rpc/client/http/ws.go +++ b/rpc/client/http/ws.go @@ -92,7 +92,7 @@ func newWsEvents(remote string, wso WSOptions) (*wsEvents, error) { } // Start starts the websocket client and the event loop. -func (w *wsEvents) Start() error { +func (w *wsEvents) Start(ctx context.Context) error { if err := w.ws.Start(); err != nil { return err } diff --git a/rpc/client/interface.go b/rpc/client/interface.go index 474eb9937..8d160b799 100644 --- a/rpc/client/interface.go +++ b/rpc/client/interface.go @@ -35,7 +35,7 @@ type Client interface { // These methods define the operational structure of the client. // Start the client. Start must report an error if the client is running. - Start() error + Start(context.Context) error // Stop the client. Stop must report an error if the client is not running. Stop() error diff --git a/rpc/client/main_test.go b/rpc/client/main_test.go index c242f01c4..5ae9b951c 100644 --- a/rpc/client/main_test.go +++ b/rpc/client/main_test.go @@ -32,7 +32,6 @@ func NodeSuite(t *testing.T) (service.Service, *config.Config) { require.NoError(t, err) t.Cleanup(func() { cancel() - assert.NoError(t, node.Stop()) assert.NoError(t, closer(ctx)) assert.NoError(t, app.Close()) node.Wait() diff --git a/rpc/client/mocks/client.go b/rpc/client/mocks/client.go new file mode 100644 index 000000000..7012e1c2d --- /dev/null +++ b/rpc/client/mocks/client.go @@ -0,0 +1,756 @@ +// Code generated by mockery. DO NOT EDIT. 
+ +package mocks + +import ( + bytes "github.com/tendermint/tendermint/libs/bytes" + client "github.com/tendermint/tendermint/rpc/client" + + context "context" + + coretypes "github.com/tendermint/tendermint/rpc/coretypes" + + mock "github.com/stretchr/testify/mock" + + types "github.com/tendermint/tendermint/types" +) + +// Client is an autogenerated mock type for the Client type +type Client struct { + mock.Mock +} + +// ABCIInfo provides a mock function with given fields: _a0 +func (_m *Client) ABCIInfo(_a0 context.Context) (*coretypes.ResultABCIInfo, error) { + ret := _m.Called(_a0) + + var r0 *coretypes.ResultABCIInfo + if rf, ok := ret.Get(0).(func(context.Context) *coretypes.ResultABCIInfo); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultABCIInfo) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ABCIQuery provides a mock function with given fields: ctx, path, data +func (_m *Client) ABCIQuery(ctx context.Context, path string, data bytes.HexBytes) (*coretypes.ResultABCIQuery, error) { + ret := _m.Called(ctx, path, data) + + var r0 *coretypes.ResultABCIQuery + if rf, ok := ret.Get(0).(func(context.Context, string, bytes.HexBytes) *coretypes.ResultABCIQuery); ok { + r0 = rf(ctx, path, data) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultABCIQuery) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, bytes.HexBytes) error); ok { + r1 = rf(ctx, path, data) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ABCIQueryWithOptions provides a mock function with given fields: ctx, path, data, opts +func (_m *Client) ABCIQueryWithOptions(ctx context.Context, path string, data bytes.HexBytes, opts client.ABCIQueryOptions) (*coretypes.ResultABCIQuery, error) { + ret := _m.Called(ctx, path, data, opts) + + var r0 *coretypes.ResultABCIQuery + if rf, ok := ret.Get(0).(func(context.Context, string, bytes.HexBytes, client.ABCIQueryOptions) *coretypes.ResultABCIQuery); ok { + r0 = rf(ctx, path, data, opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultABCIQuery) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, bytes.HexBytes, client.ABCIQueryOptions) error); ok { + r1 = rf(ctx, path, data, opts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Block provides a mock function with given fields: ctx, height +func (_m *Client) Block(ctx context.Context, height *int64) (*coretypes.ResultBlock, error) { + ret := _m.Called(ctx, height) + + var r0 *coretypes.ResultBlock + if rf, ok := ret.Get(0).(func(context.Context, *int64) *coretypes.ResultBlock); ok { + r0 = rf(ctx, height) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultBlock) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *int64) error); ok { + r1 = rf(ctx, height) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// BlockByHash provides a mock function with given fields: ctx, hash +func (_m *Client) BlockByHash(ctx context.Context, hash bytes.HexBytes) (*coretypes.ResultBlock, error) { + ret := _m.Called(ctx, hash) + + var r0 *coretypes.ResultBlock + if rf, ok := ret.Get(0).(func(context.Context, bytes.HexBytes) *coretypes.ResultBlock); ok { + r0 = rf(ctx, hash) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultBlock) + } + } + + var r1 error + if rf, ok := 
ret.Get(1).(func(context.Context, bytes.HexBytes) error); ok { + r1 = rf(ctx, hash) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// BlockResults provides a mock function with given fields: ctx, height +func (_m *Client) BlockResults(ctx context.Context, height *int64) (*coretypes.ResultBlockResults, error) { + ret := _m.Called(ctx, height) + + var r0 *coretypes.ResultBlockResults + if rf, ok := ret.Get(0).(func(context.Context, *int64) *coretypes.ResultBlockResults); ok { + r0 = rf(ctx, height) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultBlockResults) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *int64) error); ok { + r1 = rf(ctx, height) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// BlockSearch provides a mock function with given fields: ctx, query, page, perPage, orderBy +func (_m *Client) BlockSearch(ctx context.Context, query string, page *int, perPage *int, orderBy string) (*coretypes.ResultBlockSearch, error) { + ret := _m.Called(ctx, query, page, perPage, orderBy) + + var r0 *coretypes.ResultBlockSearch + if rf, ok := ret.Get(0).(func(context.Context, string, *int, *int, string) *coretypes.ResultBlockSearch); ok { + r0 = rf(ctx, query, page, perPage, orderBy) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultBlockSearch) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, *int, *int, string) error); ok { + r1 = rf(ctx, query, page, perPage, orderBy) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// BlockchainInfo provides a mock function with given fields: ctx, minHeight, maxHeight +func (_m *Client) BlockchainInfo(ctx context.Context, minHeight int64, maxHeight int64) (*coretypes.ResultBlockchainInfo, error) { + ret := _m.Called(ctx, minHeight, maxHeight) + + var r0 *coretypes.ResultBlockchainInfo + if rf, ok := ret.Get(0).(func(context.Context, int64, int64) *coretypes.ResultBlockchainInfo); ok { + r0 = rf(ctx, minHeight, maxHeight) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultBlockchainInfo) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int64, int64) error); ok { + r1 = rf(ctx, minHeight, maxHeight) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// BroadcastEvidence provides a mock function with given fields: _a0, _a1 +func (_m *Client) BroadcastEvidence(_a0 context.Context, _a1 types.Evidence) (*coretypes.ResultBroadcastEvidence, error) { + ret := _m.Called(_a0, _a1) + + var r0 *coretypes.ResultBroadcastEvidence + if rf, ok := ret.Get(0).(func(context.Context, types.Evidence) *coretypes.ResultBroadcastEvidence); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultBroadcastEvidence) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, types.Evidence) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// BroadcastTxAsync provides a mock function with given fields: _a0, _a1 +func (_m *Client) BroadcastTxAsync(_a0 context.Context, _a1 types.Tx) (*coretypes.ResultBroadcastTx, error) { + ret := _m.Called(_a0, _a1) + + var r0 *coretypes.ResultBroadcastTx + if rf, ok := ret.Get(0).(func(context.Context, types.Tx) *coretypes.ResultBroadcastTx); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultBroadcastTx) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, types.Tx) error); ok { + r1 = rf(_a0, _a1) + 
} else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// BroadcastTxCommit provides a mock function with given fields: _a0, _a1 +func (_m *Client) BroadcastTxCommit(_a0 context.Context, _a1 types.Tx) (*coretypes.ResultBroadcastTxCommit, error) { + ret := _m.Called(_a0, _a1) + + var r0 *coretypes.ResultBroadcastTxCommit + if rf, ok := ret.Get(0).(func(context.Context, types.Tx) *coretypes.ResultBroadcastTxCommit); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultBroadcastTxCommit) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, types.Tx) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// BroadcastTxSync provides a mock function with given fields: _a0, _a1 +func (_m *Client) BroadcastTxSync(_a0 context.Context, _a1 types.Tx) (*coretypes.ResultBroadcastTx, error) { + ret := _m.Called(_a0, _a1) + + var r0 *coretypes.ResultBroadcastTx + if rf, ok := ret.Get(0).(func(context.Context, types.Tx) *coretypes.ResultBroadcastTx); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultBroadcastTx) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, types.Tx) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CheckTx provides a mock function with given fields: _a0, _a1 +func (_m *Client) CheckTx(_a0 context.Context, _a1 types.Tx) (*coretypes.ResultCheckTx, error) { + ret := _m.Called(_a0, _a1) + + var r0 *coretypes.ResultCheckTx + if rf, ok := ret.Get(0).(func(context.Context, types.Tx) *coretypes.ResultCheckTx); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultCheckTx) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, types.Tx) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Commit provides a mock function with given fields: ctx, height +func (_m *Client) Commit(ctx context.Context, height *int64) (*coretypes.ResultCommit, error) { + ret := _m.Called(ctx, height) + + var r0 *coretypes.ResultCommit + if rf, ok := ret.Get(0).(func(context.Context, *int64) *coretypes.ResultCommit); ok { + r0 = rf(ctx, height) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultCommit) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *int64) error); ok { + r1 = rf(ctx, height) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ConsensusParams provides a mock function with given fields: ctx, height +func (_m *Client) ConsensusParams(ctx context.Context, height *int64) (*coretypes.ResultConsensusParams, error) { + ret := _m.Called(ctx, height) + + var r0 *coretypes.ResultConsensusParams + if rf, ok := ret.Get(0).(func(context.Context, *int64) *coretypes.ResultConsensusParams); ok { + r0 = rf(ctx, height) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultConsensusParams) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *int64) error); ok { + r1 = rf(ctx, height) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ConsensusState provides a mock function with given fields: _a0 +func (_m *Client) ConsensusState(_a0 context.Context) (*coretypes.ResultConsensusState, error) { + ret := _m.Called(_a0) + + var r0 *coretypes.ResultConsensusState + if rf, ok := ret.Get(0).(func(context.Context) *coretypes.ResultConsensusState); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { 
+ r0 = ret.Get(0).(*coretypes.ResultConsensusState) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DumpConsensusState provides a mock function with given fields: _a0 +func (_m *Client) DumpConsensusState(_a0 context.Context) (*coretypes.ResultDumpConsensusState, error) { + ret := _m.Called(_a0) + + var r0 *coretypes.ResultDumpConsensusState + if rf, ok := ret.Get(0).(func(context.Context) *coretypes.ResultDumpConsensusState); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultDumpConsensusState) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Genesis provides a mock function with given fields: _a0 +func (_m *Client) Genesis(_a0 context.Context) (*coretypes.ResultGenesis, error) { + ret := _m.Called(_a0) + + var r0 *coretypes.ResultGenesis + if rf, ok := ret.Get(0).(func(context.Context) *coretypes.ResultGenesis); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultGenesis) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GenesisChunked provides a mock function with given fields: _a0, _a1 +func (_m *Client) GenesisChunked(_a0 context.Context, _a1 uint) (*coretypes.ResultGenesisChunk, error) { + ret := _m.Called(_a0, _a1) + + var r0 *coretypes.ResultGenesisChunk + if rf, ok := ret.Get(0).(func(context.Context, uint) *coretypes.ResultGenesisChunk); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultGenesisChunk) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, uint) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Health provides a mock function with given fields: _a0 +func (_m *Client) Health(_a0 context.Context) (*coretypes.ResultHealth, error) { + ret := _m.Called(_a0) + + var r0 *coretypes.ResultHealth + if rf, ok := ret.Get(0).(func(context.Context) *coretypes.ResultHealth); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultHealth) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// IsRunning provides a mock function with given fields: +func (_m *Client) IsRunning() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// NetInfo provides a mock function with given fields: _a0 +func (_m *Client) NetInfo(_a0 context.Context) (*coretypes.ResultNetInfo, error) { + ret := _m.Called(_a0) + + var r0 *coretypes.ResultNetInfo + if rf, ok := ret.Get(0).(func(context.Context) *coretypes.ResultNetInfo); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultNetInfo) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NumUnconfirmedTxs provides a mock function with given fields: _a0 +func (_m *Client) NumUnconfirmedTxs(_a0 context.Context) (*coretypes.ResultUnconfirmedTxs, error) { + ret := _m.Called(_a0) + + var r0 *coretypes.ResultUnconfirmedTxs + if rf, ok 
:= ret.Get(0).(func(context.Context) *coretypes.ResultUnconfirmedTxs); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultUnconfirmedTxs) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RemoveTx provides a mock function with given fields: _a0, _a1 +func (_m *Client) RemoveTx(_a0 context.Context, _a1 types.TxKey) error { + ret := _m.Called(_a0, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.TxKey) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Start provides a mock function with given fields: _a0 +func (_m *Client) Start(_a0 context.Context) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Status provides a mock function with given fields: _a0 +func (_m *Client) Status(_a0 context.Context) (*coretypes.ResultStatus, error) { + ret := _m.Called(_a0) + + var r0 *coretypes.ResultStatus + if rf, ok := ret.Get(0).(func(context.Context) *coretypes.ResultStatus); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultStatus) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Stop provides a mock function with given fields: +func (_m *Client) Stop() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Subscribe provides a mock function with given fields: ctx, subscriber, query, outCapacity +func (_m *Client) Subscribe(ctx context.Context, subscriber string, query string, outCapacity ...int) (<-chan coretypes.ResultEvent, error) { + _va := make([]interface{}, len(outCapacity)) + for _i := range outCapacity { + _va[_i] = outCapacity[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, subscriber, query) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 <-chan coretypes.ResultEvent + if rf, ok := ret.Get(0).(func(context.Context, string, string, ...int) <-chan coretypes.ResultEvent); ok { + r0 = rf(ctx, subscriber, query, outCapacity...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan coretypes.ResultEvent) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, string, ...int) error); ok { + r1 = rf(ctx, subscriber, query, outCapacity...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Tx provides a mock function with given fields: ctx, hash, prove +func (_m *Client) Tx(ctx context.Context, hash bytes.HexBytes, prove bool) (*coretypes.ResultTx, error) { + ret := _m.Called(ctx, hash, prove) + + var r0 *coretypes.ResultTx + if rf, ok := ret.Get(0).(func(context.Context, bytes.HexBytes, bool) *coretypes.ResultTx); ok { + r0 = rf(ctx, hash, prove) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultTx) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, bytes.HexBytes, bool) error); ok { + r1 = rf(ctx, hash, prove) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// TxSearch provides a mock function with given fields: ctx, query, prove, page, perPage, orderBy +func (_m *Client) TxSearch(ctx context.Context, query string, prove bool, page *int, perPage *int, orderBy string) (*coretypes.ResultTxSearch, error) { + ret := _m.Called(ctx, query, prove, page, perPage, orderBy) + + var r0 *coretypes.ResultTxSearch + if rf, ok := ret.Get(0).(func(context.Context, string, bool, *int, *int, string) *coretypes.ResultTxSearch); ok { + r0 = rf(ctx, query, prove, page, perPage, orderBy) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultTxSearch) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, bool, *int, *int, string) error); ok { + r1 = rf(ctx, query, prove, page, perPage, orderBy) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UnconfirmedTxs provides a mock function with given fields: ctx, limit +func (_m *Client) UnconfirmedTxs(ctx context.Context, limit *int) (*coretypes.ResultUnconfirmedTxs, error) { + ret := _m.Called(ctx, limit) + + var r0 *coretypes.ResultUnconfirmedTxs + if rf, ok := ret.Get(0).(func(context.Context, *int) *coretypes.ResultUnconfirmedTxs); ok { + r0 = rf(ctx, limit) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultUnconfirmedTxs) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *int) error); ok { + r1 = rf(ctx, limit) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Unsubscribe provides a mock function with given fields: ctx, subscriber, query +func (_m *Client) Unsubscribe(ctx context.Context, subscriber string, query string) error { + ret := _m.Called(ctx, subscriber, query) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, subscriber, query) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UnsubscribeAll provides a mock function with given fields: ctx, subscriber +func (_m *Client) UnsubscribeAll(ctx context.Context, subscriber string) error { + ret := _m.Called(ctx, subscriber) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, subscriber) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Validators provides a mock function with given fields: ctx, height, page, perPage +func (_m *Client) Validators(ctx context.Context, height *int64, page *int, perPage *int) (*coretypes.ResultValidators, error) { + ret := _m.Called(ctx, height, page, perPage) + + var r0 *coretypes.ResultValidators + if rf, ok := ret.Get(0).(func(context.Context, *int64, *int, *int) *coretypes.ResultValidators); ok { + r0 = rf(ctx, height, page, perPage) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*coretypes.ResultValidators) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *int64, *int, 
*int) error); ok { + r1 = rf(ctx, height, page, perPage) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/rpc/client/rpc_test.go b/rpc/client/rpc_test.go index 38766e047..12c13d686 100644 --- a/rpc/client/rpc_test.go +++ b/rpc/client/rpc_test.go @@ -463,7 +463,7 @@ func TestClientMethodCalls(t *testing.T) { // start for this test it if it wasn't already running if !c.IsRunning() { // if so, then we start it, listen, and stop it. - err := c.Start() + err := c.Start(ctx) require.Nil(t, err) t.Cleanup(func() { if err := c.Stop(); err != nil { diff --git a/rpc/test/helpers.go b/rpc/test/helpers.go index 7bdc1bb2e..90c3b2e49 100644 --- a/rpc/test/helpers.go +++ b/rpc/test/helpers.go @@ -73,10 +73,13 @@ func CreateConfig(testName string) (*config.Config, error) { type ServiceCloser func(context.Context) error -func StartTendermint(ctx context.Context, +func StartTendermint( + ctx context.Context, conf *config.Config, app abci.Application, - opts ...func(*Options)) (service.Service, ServiceCloser, error) { + opts ...func(*Options), +) (service.Service, ServiceCloser, error) { + ctx, cancel := context.WithCancel(ctx) nodeOpts := &Options{} for _, opt := range opts { @@ -89,14 +92,14 @@ func StartTendermint(ctx context.Context, logger = log.MustNewDefaultLogger(log.LogFormatPlain, log.LogLevelInfo, false) } papp := abciclient.NewLocalCreator(app) - tmNode, err := node.New(conf, logger, papp, nil) + tmNode, err := node.New(ctx, conf, logger, papp, nil) if err != nil { - return nil, func(_ context.Context) error { return nil }, err + return nil, func(_ context.Context) error { cancel(); return nil }, err } - err = tmNode.Start() + err = tmNode.Start(ctx) if err != nil { - return nil, func(_ context.Context) error { return nil }, err + return nil, func(_ context.Context) error { cancel(); return nil }, err } waitForRPC(ctx, conf) @@ -106,9 +109,7 @@ func StartTendermint(ctx context.Context, } return tmNode, func(ctx context.Context) error { - if err := tmNode.Stop(); err != nil { - logger.Error("Error when trying to stop node", "err", err) - } + cancel() tmNode.Wait() os.RemoveAll(conf.RootDir) return nil diff --git a/test/e2e/node/main.go b/test/e2e/node/main.go index 7a9e67915..5d25b0195 100644 --- a/test/e2e/node/main.go +++ b/test/e2e/node/main.go @@ -38,6 +38,9 @@ var logger = log.MustNewDefaultLogger(log.LogFormatPlain, log.LogLevelInfo, fals // main is the binary entrypoint. func main() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if len(os.Args) != 2 { fmt.Printf("Usage: %v ", os.Args[0]) return @@ -47,14 +50,14 @@ func main() { configFile = os.Args[1] } - if err := run(configFile); err != nil { + if err := run(ctx, configFile); err != nil { logger.Error(err.Error()) os.Exit(1) } } // run runs the application - basically like main() with error handling. -func run(configFile string) error { +func run(ctx context.Context, configFile string) error { cfg, err := LoadConfig(configFile) if err != nil { return err @@ -62,7 +65,7 @@ func run(configFile string) error { // Start remote signer (must start before node if running builtin). if cfg.PrivValServer != "" { - if err = startSigner(cfg); err != nil { + if err = startSigner(ctx, cfg); err != nil { return err } if cfg.Protocol == "builtin" { @@ -73,15 +76,15 @@ func run(configFile string) error { // Start app server. 
diff --git a/rpc/client/rpc_test.go b/rpc/client/rpc_test.go
index 38766e047..12c13d686 100644
--- a/rpc/client/rpc_test.go
+++ b/rpc/client/rpc_test.go
@@ -463,7 +463,7 @@ func TestClientMethodCalls(t *testing.T) {
 			// start for this test it if it wasn't already running
 			if !c.IsRunning() {
 				// if so, then we start it, listen, and stop it.
-				err := c.Start()
+				err := c.Start(ctx)
 				require.Nil(t, err)
 				t.Cleanup(func() {
 					if err := c.Stop(); err != nil {
diff --git a/rpc/test/helpers.go b/rpc/test/helpers.go
index 7bdc1bb2e..90c3b2e49 100644
--- a/rpc/test/helpers.go
+++ b/rpc/test/helpers.go
@@ -73,10 +73,13 @@ func CreateConfig(testName string) (*config.Config, error) {
 
 type ServiceCloser func(context.Context) error
 
-func StartTendermint(ctx context.Context,
+func StartTendermint(
+	ctx context.Context,
 	conf *config.Config,
 	app abci.Application,
-	opts ...func(*Options)) (service.Service, ServiceCloser, error) {
+	opts ...func(*Options),
+) (service.Service, ServiceCloser, error) {
+	ctx, cancel := context.WithCancel(ctx)
 
 	nodeOpts := &Options{}
 	for _, opt := range opts {
@@ -89,14 +92,14 @@ func StartTendermint(ctx context.Context,
 		logger = log.MustNewDefaultLogger(log.LogFormatPlain, log.LogLevelInfo, false)
 	}
 	papp := abciclient.NewLocalCreator(app)
-	tmNode, err := node.New(conf, logger, papp, nil)
+	tmNode, err := node.New(ctx, conf, logger, papp, nil)
 	if err != nil {
-		return nil, func(_ context.Context) error { return nil }, err
+		return nil, func(_ context.Context) error { cancel(); return nil }, err
 	}
 
-	err = tmNode.Start()
+	err = tmNode.Start(ctx)
 	if err != nil {
-		return nil, func(_ context.Context) error { return nil }, err
+		return nil, func(_ context.Context) error { cancel(); return nil }, err
 	}
 
 	waitForRPC(ctx, conf)
@@ -106,9 +109,7 @@
 	}
 
 	return tmNode, func(ctx context.Context) error {
-		if err := tmNode.Stop(); err != nil {
-			logger.Error("Error when trying to stop node", "err", err)
-		}
+		cancel()
 		tmNode.Wait()
 		os.RemoveAll(conf.RootDir)
 		return nil
diff --git a/test/e2e/node/main.go b/test/e2e/node/main.go
index 7a9e67915..5d25b0195 100644
--- a/test/e2e/node/main.go
+++ b/test/e2e/node/main.go
@@ -38,6 +38,9 @@ var logger = log.MustNewDefaultLogger(log.LogFormatPlain, log.LogLevelInfo, fals
 
 // main is the binary entrypoint.
 func main() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	if len(os.Args) != 2 {
 		fmt.Printf("Usage: %v <configfile>", os.Args[0])
 		return
@@ -47,14 +50,14 @@
 		configFile = os.Args[1]
 	}
 
-	if err := run(configFile); err != nil {
+	if err := run(ctx, configFile); err != nil {
 		logger.Error(err.Error())
 		os.Exit(1)
 	}
 }
 
 // run runs the application - basically like main() with error handling.
-func run(configFile string) error {
+func run(ctx context.Context, configFile string) error {
 	cfg, err := LoadConfig(configFile)
 	if err != nil {
 		return err
@@ -62,7 +65,7 @@
 
 	// Start remote signer (must start before node if running builtin).
 	if cfg.PrivValServer != "" {
-		if err = startSigner(cfg); err != nil {
+		if err = startSigner(ctx, cfg); err != nil {
 			return err
 		}
 		if cfg.Protocol == "builtin" {
@@ -73,15 +76,15 @@
 	// Start app server.
 	switch cfg.Protocol {
 	case "socket", "grpc":
-		err = startApp(cfg)
+		err = startApp(ctx, cfg)
 	case "builtin":
 		switch cfg.Mode {
 		case string(e2e.ModeLight):
-			err = startLightNode(cfg)
+			err = startLightNode(ctx, cfg)
 		case string(e2e.ModeSeed):
-			err = startSeedNode()
+			err = startSeedNode(ctx)
 		default:
-			err = startNode(cfg)
+			err = startNode(ctx, cfg)
 		}
 	default:
 		err = fmt.Errorf("invalid protocol %q", cfg.Protocol)
@@ -97,7 +100,7 @@
 }
 
 // startApp starts the application server, listening for connections from Tendermint.
-func startApp(cfg *Config) error {
+func startApp(ctx context.Context, cfg *Config) error {
 	app, err := app.NewApplication(cfg.App())
 	if err != nil {
 		return err
@@ -106,7 +109,7 @@
 	if err != nil {
 		return err
 	}
-	err = server.Start()
+	err = server.Start(ctx)
 	if err != nil {
 		return err
 	}
@@ -118,7 +121,7 @@
 // configuration is in $TMHOME/config/tendermint.toml.
 //
 // FIXME There is no way to simply load the configuration from a file, so we need to pull in Viper.
-func startNode(cfg *Config) error {
+func startNode(ctx context.Context, cfg *Config) error {
 	app, err := app.NewApplication(cfg.App())
 	if err != nil {
 		return err
@@ -129,7 +132,9 @@
 		return fmt.Errorf("failed to setup config: %w", err)
 	}
 
-	n, err := node.New(tmcfg,
+	n, err := node.New(
+		ctx,
+		tmcfg,
 		nodeLogger,
 		abciclient.NewLocalCreator(app),
 		nil,
@@ -137,10 +142,10 @@
 	if err != nil {
 		return err
 	}
-	return n.Start()
+	return n.Start(ctx)
 }
 
-func startSeedNode() error {
+func startSeedNode(ctx context.Context) error {
 	tmcfg, nodeLogger, err := setupNode()
 	if err != nil {
 		return fmt.Errorf("failed to setup config: %w", err)
@@ -148,14 +153,14 @@
 
 	tmcfg.Mode = config.ModeSeed
 
-	n, err := node.New(tmcfg, nodeLogger, nil, nil)
+	n, err := node.New(ctx, tmcfg, nodeLogger, nil, nil)
 	if err != nil {
 		return err
 	}
-	return n.Start()
+	return n.Start(ctx)
 }
 
-func startLightNode(cfg *Config) error {
+func startLightNode(ctx context.Context, cfg *Config) error {
 	tmcfg, nodeLogger, err := setupNode()
 	if err != nil {
 		return err
@@ -204,7 +209,7 @@
 	}
 
 	logger.Info("Starting proxy...", "laddr", tmcfg.RPC.ListenAddress)
-	if err := p.ListenAndServe(); err != http.ErrServerClosed {
+	if err := p.ListenAndServe(ctx); err != http.ErrServerClosed {
 		// Error starting or closing listener:
 		logger.Error("proxy ListenAndServe", "err", err)
 	}
@@ -213,7 +218,7 @@
 }
 
 // startSigner starts a signer server connecting to the given endpoint.
-func startSigner(cfg *Config) error {
+func startSigner(ctx context.Context, cfg *Config) error {
 	filePV, err := privval.LoadFilePV(cfg.PrivValKey, cfg.PrivValState)
 	if err != nil {
 		return err
@@ -251,7 +256,8 @@
 	endpoint := privval.NewSignerDialerEndpoint(logger, dialFn,
 		privval.SignerDialerEndpointRetryWaitInterval(1*time.Second),
 		privval.SignerDialerEndpointConnRetries(100))
-	err = privval.NewSignerServer(endpoint, cfg.ChainID, filePV).Start()
+
+	err = privval.NewSignerServer(endpoint, cfg.ChainID, filePV).Start(ctx)
 	if err != nil {
 		return err
 	}
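
Every start helper above now receives the context created in `main()`, so cancelling that one context becomes the single shutdown path for the app server, the node, the light proxy, and the signer. The patch itself only derives that context from `context.Background()`; one optional refinement, shown here purely as an illustration and not part of this diff, would be to tie the root context to OS signals:

```go
package main

import (
	"context"
	"os"
	"os/signal"
	"syscall"
)

// runWithSignalContext is a hypothetical helper (not in the patch): it shows
// how a root context of the kind created in main() above could be tied to
// SIGINT/SIGTERM, so that an operator signal cancels everything that was
// started via Start(ctx).
func runWithSignalContext(run func(ctx context.Context) error) error {
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop() // release the signal handler once run returns

	return run(ctx)
}
```

`signal.NotifyContext` is standard library (Go 1.16+), so wiring this up would not add any dependency.
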
diff --git a/test/fuzz/mempool/checktx.go b/test/fuzz/mempool/checktx.go
index 406e062fd..ba60d72cc 100644
--- a/test/fuzz/mempool/checktx.go
+++ b/test/fuzz/mempool/checktx.go
@@ -17,7 +17,7 @@ func init() {
 	app := kvstore.NewApplication()
 	cc := abciclient.NewLocalCreator(app)
 	appConnMem, _ := cc(log.NewNopLogger())
-	err := appConnMem.Start()
+	err := appConnMem.Start(context.TODO())
 	if err != nil {
 		panic(err)
 	}
diff --git a/tools/tm-signer-harness/internal/test_harness.go b/tools/tm-signer-harness/internal/test_harness.go
index 96eaaaff0..c8b5cd81d 100644
--- a/tools/tm-signer-harness/internal/test_harness.go
+++ b/tools/tm-signer-harness/internal/test_harness.go
@@ -89,7 +89,7 @@ type timeoutError interface {
 // NewTestHarness will load Tendermint data from the given files (including
 // validator public/private keypairs and chain details) and create a new
 // harness.
-func NewTestHarness(logger log.Logger, cfg TestHarnessConfig) (*TestHarness, error) {
+func NewTestHarness(ctx context.Context, logger log.Logger, cfg TestHarnessConfig) (*TestHarness, error) {
 	keyFile := ExpandPath(cfg.KeyFile)
 	stateFile := ExpandPath(cfg.StateFile)
 	logger.Info("Loading private validator configuration", "keyFile", keyFile, "stateFile", stateFile)
@@ -113,7 +113,7 @@ func NewTestHarness(logger log.Logger, cfg TestHarnessConfig) (*TestHarness, err
 		return nil, newTestHarnessError(ErrFailedToCreateListener, err, "")
 	}
 
-	signerClient, err := privval.NewSignerClient(spv, st.ChainID)
+	signerClient, err := privval.NewSignerClient(ctx, spv, st.ChainID)
 	if err != nil {
 		return nil, newTestHarnessError(ErrFailedToCreateListener, err, "")
 	}
diff --git a/tools/tm-signer-harness/internal/test_harness_test.go b/tools/tm-signer-harness/internal/test_harness_test.go
index 85a589185..2ef630555 100644
--- a/tools/tm-signer-harness/internal/test_harness_test.go
+++ b/tools/tm-signer-harness/internal/test_harness_test.go
@@ -1,6 +1,7 @@
 package internal
 
 import (
+	"context"
 	"fmt"
 	"os"
 	"testing"
@@ -73,17 +74,24 @@ const (
 )
 
 func TestRemoteSignerTestHarnessMaxAcceptRetriesReached(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	cfg := makeConfig(t, 1, 2)
 	defer cleanup(cfg)
 
-	th, err := NewTestHarness(log.TestingLogger(), cfg)
+	th, err := NewTestHarness(ctx, log.TestingLogger(), cfg)
 	require.NoError(t, err)
 	th.Run()
 	assert.Equal(t, ErrMaxAcceptRetriesReached, th.exitCode)
 }
 
 func TestRemoteSignerTestHarnessSuccessfulRun(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	harnessTest(
+		ctx,
 		t,
 		func(th *TestHarness) *privval.SignerServer {
 			return newMockSignerServer(t, th, th.fpv.Key.PrivKey, false, false)
@@ -93,7 +101,11 @@
 }
 
 func TestRemoteSignerPublicKeyCheckFailed(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	harnessTest(
+		ctx,
 		t,
 		func(th *TestHarness) *privval.SignerServer {
 			return newMockSignerServer(t, th, ed25519.GenPrivKey(), false, false)
@@ -103,7 +115,11 @@
 }
 
 func TestRemoteSignerProposalSigningFailed(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	harnessTest(
+		ctx,
 		t,
 		func(th *TestHarness) *privval.SignerServer {
 			return newMockSignerServer(t, th, th.fpv.Key.PrivKey, true, false)
@@ -113,7 +129,11 @@
 }
 
 func TestRemoteSignerVoteSigningFailed(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	harnessTest(
+		ctx,
 		t,
 		func(th *TestHarness) *privval.SignerServer {
 			return newMockSignerServer(t, th, th.fpv.Key.PrivKey, false, true)
@@ -144,11 +164,16 @@
 }
 
 // For running relatively standard tests.
-func harnessTest(t *testing.T, signerServerMaker func(th *TestHarness) *privval.SignerServer, expectedExitCode int) {
+func harnessTest(
+	ctx context.Context,
+	t *testing.T,
+	signerServerMaker func(th *TestHarness) *privval.SignerServer,
+	expectedExitCode int,
+) {
 	cfg := makeConfig(t, 100, 3)
 	defer cleanup(cfg)
 
-	th, err := NewTestHarness(log.TestingLogger(), cfg)
+	th, err := NewTestHarness(ctx, log.TestingLogger(), cfg)
 	require.NoError(t, err)
 	donec := make(chan struct{})
 	go func() {
@@ -157,7 +182,7 @@
 	}()
 
 	ss := signerServerMaker(th)
-	require.NoError(t, ss.Start())
+	require.NoError(t, ss.Start(ctx))
 	assert.True(t, ss.IsRunning())
 	defer ss.Stop() //nolint:errcheck // ignore for tests
diff --git a/tools/tm-signer-harness/main.go b/tools/tm-signer-harness/main.go
index 90afd7d1f..4bf1933e0 100644
--- a/tools/tm-signer-harness/main.go
+++ b/tools/tm-signer-harness/main.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"flag"
 	"fmt"
 	"os"
@@ -115,7 +116,7 @@ Usage:
 	}
 }
 
-func runTestHarness(acceptRetries int, bindAddr, tmhome string) {
+func runTestHarness(ctx context.Context, acceptRetries int, bindAddr, tmhome string) {
 	tmhome = internal.ExpandPath(tmhome)
 	cfg := internal.TestHarnessConfig{
 		BindAddr:         bindAddr,
@@ -128,7 +129,7 @@
 		SecretConnKey:    ed25519.GenPrivKey(),
 		ExitWhenComplete: true,
 	}
-	harness, err := internal.NewTestHarness(logger, cfg)
+	harness, err := internal.NewTestHarness(ctx, logger, cfg)
 	if err != nil {
 		logger.Error(err.Error())
 		if therr, ok := err.(*internal.TestHarnessError); ok {
@@ -156,6 +157,9 @@
 }
 func main() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	if err := rootCmd.Parse(os.Args[1:]); err != nil {
 		fmt.Printf("Error parsing flags: %v\n", err)
 		os.Exit(1)
 	}
@@ -183,7 +187,7 @@
 			fmt.Printf("Error parsing flags: %v\n", err)
 			os.Exit(1)
 		}
-		runTestHarness(flagAcceptRetries, flagBindAddr, flagTMHome)
+		runTestHarness(ctx, flagAcceptRetries, flagBindAddr, flagTMHome)
 	case "extract_key":
 		if err := extractKeyCmd.Parse(os.Args[2:]); err != nil {
 			fmt.Printf("Error parsing flags: %v\n", err)
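
Each harness test above now repeats the same `context.WithCancel` / `defer cancel()` preamble before calling `harnessTest(ctx, ...)`. A small helper along these lines, hypothetical and not part of the patch, could fold that boilerplate into `t.Cleanup`:

```go
package internal

import (
	"context"
	"testing"
)

// testContext is a hypothetical convenience helper, not part of this patch:
// it captures the ctx/cancel boilerplate that each test above now repeats,
// cancelling the context automatically when the test finishes.
func testContext(t *testing.T) context.Context {
	t.Helper()

	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)
	return ctx
}
```

Using `t.Cleanup(cancel)` keeps cancellation tied to the test's lifetime even when the test fails early, which is the same guarantee the explicit `defer cancel()` lines provide.
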
diff --git a/types/block_test.go b/types/block_test.go
index 1c762653b..e0c1ab3be 100644
--- a/types/block_test.go
+++ b/types/block_test.go
@@ -555,6 +555,9 @@ func TestCommitToVoteSetWithVotesForNilBlock(t *testing.T) {
 		round  = 0
 	)
 
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
 	type commitVoteTest struct {
 		blockIDs []BlockID
 		numVotes []int // must sum to numValidators
@@ -572,7 +575,7 @@
 		vi := int32(0)
 		for n := range tc.blockIDs {
 			for i := 0; i < tc.numVotes[n]; i++ {
-				pubKey, err := vals[vi].GetPubKey(context.Background())
+				pubKey, err := vals[vi].GetPubKey(ctx)
 				require.NoError(t, err)
 				vote := &Vote{
 					ValidatorAddress: pubKey.Address(),
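
The hunk above swaps the ad-hoc `context.Background()` for the test-scoped `ctx`, so a cancelled test can abandon the `GetPubKey` call. A standalone sketch of the same pattern, illustrative only and assuming `types.NewMockPV` as used elsewhere in this package's tests, with a per-call deadline layered on top of the test context:

```go
package types_test

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/tendermint/tendermint/types"
)

// TestGetPubKeyHonoursContext is an illustrative test, not part of the patch:
// it derives a short per-call deadline from a test-scoped context before
// calling the context-aware GetPubKey, mirroring the rewrite in the hunk above.
func TestGetPubKeyHonoursContext(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	pv := types.NewMockPV()

	// A stuck signer cannot hang the test for longer than the deadline.
	callCtx, callCancel := context.WithTimeout(ctx, time.Second)
	defer callCancel()

	pubKey, err := pv.GetPubKey(callCtx)
	require.NoError(t, err)
	require.NotNil(t, pubKey)
}
```
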