diff --git a/internal/p2p/mocks/connection.go b/internal/p2p/mocks/connection.go index 766bbf657..20727f8a6 100644 --- a/internal/p2p/mocks/connection.go +++ b/internal/p2p/mocks/connection.go @@ -13,6 +13,8 @@ import ( p2p "github.com/tendermint/tendermint/internal/p2p" + time "time" + types "github.com/tendermint/tendermint/types" ) @@ -35,20 +37,20 @@ func (_m *Connection) Close() error { return r0 } -// Handshake provides a mock function with given fields: _a0, _a1, _a2 -func (_m *Connection) Handshake(_a0 context.Context, _a1 types.NodeInfo, _a2 crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error) { - ret := _m.Called(_a0, _a1, _a2) +// Handshake provides a mock function with given fields: _a0, _a1, _a2, _a3 +func (_m *Connection) Handshake(_a0 context.Context, _a1 time.Duration, _a2 types.NodeInfo, _a3 crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error) { + ret := _m.Called(_a0, _a1, _a2, _a3) var r0 types.NodeInfo - if rf, ok := ret.Get(0).(func(context.Context, types.NodeInfo, crypto.PrivKey) types.NodeInfo); ok { - r0 = rf(_a0, _a1, _a2) + if rf, ok := ret.Get(0).(func(context.Context, time.Duration, types.NodeInfo, crypto.PrivKey) types.NodeInfo); ok { + r0 = rf(_a0, _a1, _a2, _a3) } else { r0 = ret.Get(0).(types.NodeInfo) } var r1 crypto.PubKey - if rf, ok := ret.Get(1).(func(context.Context, types.NodeInfo, crypto.PrivKey) crypto.PubKey); ok { - r1 = rf(_a0, _a1, _a2) + if rf, ok := ret.Get(1).(func(context.Context, time.Duration, types.NodeInfo, crypto.PrivKey) crypto.PubKey); ok { + r1 = rf(_a0, _a1, _a2, _a3) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(crypto.PubKey) @@ -56,8 +58,8 @@ func (_m *Connection) Handshake(_a0 context.Context, _a1 types.NodeInfo, _a2 cry } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, types.NodeInfo, crypto.PrivKey) error); ok { - r2 = rf(_a0, _a1, _a2) + if rf, ok := ret.Get(2).(func(context.Context, time.Duration, types.NodeInfo, crypto.PrivKey) error); ok { + r2 = rf(_a0, _a1, _a2, _a3) } else { r2 = ret.Error(2) } diff --git a/internal/p2p/router.go b/internal/p2p/router.go index e3adc77ee..d5edd42be 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -710,14 +710,8 @@ func (r *Router) handshakePeer( expectID types.NodeID, ) (types.NodeInfo, error) { - if r.options.HandshakeTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, r.options.HandshakeTimeout) - defer cancel() - } - nodeInfo := r.nodeInfoProducer() - peerInfo, peerKey, err := conn.Handshake(ctx, *nodeInfo, r.privKey) + peerInfo, peerKey, err := conn.Handshake(ctx, r.options.HandshakeTimeout, *nodeInfo, r.privKey) if err != nil { return peerInfo, err } diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index 86af19385..8a58146fd 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -385,7 +385,7 @@ func TestRouter_AcceptPeers(t *testing.T) { connCtx, connCancel := context.WithCancel(context.Background()) mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). Return(tc.peerInfo, tc.peerKey, nil) mockConnection.On("Close").Run(func(_ mock.Arguments) { connCancel() }).Return(nil).Maybe() mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) @@ -500,7 +500,7 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). WaitUntil(closeCh).Return(types.NodeInfo{}, nil, io.EOF) mockConnection.On("Close").Return(nil) mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) @@ -588,7 +588,7 @@ func TestRouter_DialPeers(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") if tc.dialErr == nil { - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). Return(tc.peerInfo, tc.peerKey, nil) mockConnection.On("Close").Run(func(_ mock.Arguments) { connCancel() }).Return(nil).Maybe() } @@ -674,7 +674,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). WaitUntil(closeCh).Return(types.NodeInfo{}, nil, io.EOF) mockConnection.On("Close").Return(nil) @@ -757,7 +757,7 @@ func TestRouter_EvictPeers(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). Return(peerInfo, peerKey.PubKey(), nil) mockConnection.On("ReceiveMessage", mock.Anything).WaitUntil(closeCh).Return(chID, nil, io.EOF) mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) @@ -826,7 +826,7 @@ func TestRouter_ChannelCompatability(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). Return(incompatiblePeer, peerKey.PubKey(), nil) mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) mockConnection.On("Close").Return(nil) @@ -877,7 +877,7 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). Return(peer, peerKey.PubKey(), nil) mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) mockConnection.On("Close").Return(nil) diff --git a/internal/p2p/transport.go b/internal/p2p/transport.go index 7a965260a..e644a11ae 100644 --- a/internal/p2p/transport.go +++ b/internal/p2p/transport.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "time" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/types" @@ -81,7 +82,7 @@ type Connection interface { // FIXME: The handshake should really be the Router's responsibility, but // that requires the connection interface to be byte-oriented rather than // message-oriented (see comment above). - Handshake(context.Context, types.NodeInfo, crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error) + Handshake(context.Context, time.Duration, types.NodeInfo, crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error) // ReceiveMessage returns the next message received on the connection, // blocking until one is available. Returns io.EOF if closed. diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index 7bf17d1a0..13a65b973 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -9,6 +9,7 @@ import ( "net" "strconv" "sync" + "time" "golang.org/x/net/netutil" @@ -274,6 +275,7 @@ func newMConnConnection( // Handshake implements Connection. func (c *mConnConnection) Handshake( ctx context.Context, + timeout time.Duration, nodeInfo types.NodeInfo, privKey crypto.PrivKey, ) (types.NodeInfo, crypto.PubKey, error) { @@ -283,6 +285,12 @@ func (c *mConnConnection) Handshake( peerKey crypto.PubKey errCh = make(chan error, 1) ) + handshakeCtx := ctx + if timeout > 0 { + var cancel context.CancelFunc + handshakeCtx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } // To handle context cancellation, we need to do the handshake in a // goroutine and abort the blocking network calls by closing the connection // when the context is canceled. @@ -295,17 +303,17 @@ func (c *mConnConnection) Handshake( } }() var err error - mconn, peerInfo, peerKey, err = c.handshake(ctx, nodeInfo, privKey) + mconn, peerInfo, peerKey, err = c.handshake(handshakeCtx, nodeInfo, privKey) select { case errCh <- err: - case <-ctx.Done(): + case <-handshakeCtx.Done(): } }() select { - case <-ctx.Done(): + case <-handshakeCtx.Done(): _ = c.Close() return types.NodeInfo{}, nil, ctx.Err() @@ -314,6 +322,10 @@ func (c *mConnConnection) Handshake( return types.NodeInfo{}, nil, err } c.mconn = mconn + // Start must not use the handshakeCtx. The handshakeCtx may have a + // timeout set that is intended to terminate only the handshake procedure. + // The context passed to Start controls the entire lifecycle of the + // mconn. if err = c.mconn.Start(ctx); err != nil { return types.NodeInfo{}, nil, err } diff --git a/internal/p2p/transport_memory.go b/internal/p2p/transport_memory.go index 3eb4c5b51..c321bc174 100644 --- a/internal/p2p/transport_memory.go +++ b/internal/p2p/transport_memory.go @@ -7,6 +7,7 @@ import ( "io" "net" "sync" + "time" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/libs/log" @@ -273,9 +274,16 @@ func (c *MemoryConnection) RemoteEndpoint() Endpoint { // Handshake implements Connection. func (c *MemoryConnection) Handshake( ctx context.Context, + timeout time.Duration, nodeInfo types.NodeInfo, privKey crypto.PrivKey, ) (types.NodeInfo, crypto.PubKey, error) { + if timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + select { case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}: c.logger.Debug("sent handshake", "nodeInfo", nodeInfo) diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index b4edf9bc9..d58c23955 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -296,7 +296,7 @@ func TestConnection_Handshake(t *testing.T) { errCh := make(chan error, 1) go func() { // Must use assert due to goroutine. - peerInfo, peerKey, err := ba.Handshake(ctx, bInfo, bKey) + peerInfo, peerKey, err := ba.Handshake(ctx, 0, bInfo, bKey) if err == nil { assert.Equal(t, aInfo, peerInfo) assert.Equal(t, aKey.PubKey(), peerKey) @@ -307,7 +307,7 @@ func TestConnection_Handshake(t *testing.T) { } }() - peerInfo, peerKey, err := ab.Handshake(ctx, aInfo, aKey) + peerInfo, peerKey, err := ab.Handshake(ctx, 0, aInfo, aKey) require.NoError(t, err) require.Equal(t, bInfo, peerInfo) require.Equal(t, bKey.PubKey(), peerKey) @@ -328,7 +328,7 @@ func TestConnection_HandshakeCancel(t *testing.T) { ab, ba := dialAccept(ctx, t, a, b) timeoutCtx, cancel := context.WithTimeout(ctx, 1*time.Minute) cancel() - _, _, err := ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey()) + _, _, err := ab.Handshake(timeoutCtx, 0, types.NodeInfo{}, ed25519.GenPrivKey()) require.Error(t, err) require.Equal(t, context.Canceled, err) _ = ab.Close() @@ -338,7 +338,7 @@ func TestConnection_HandshakeCancel(t *testing.T) { ab, ba = dialAccept(ctx, t, a, b) timeoutCtx, cancel = context.WithTimeout(ctx, 200*time.Millisecond) defer cancel() - _, _, err = ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey()) + _, _, err = ab.Handshake(timeoutCtx, 0, types.NodeInfo{}, ed25519.GenPrivKey()) require.Error(t, err) require.Equal(t, context.DeadlineExceeded, err) _ = ab.Close() @@ -642,13 +642,13 @@ func dialAcceptHandshake(ctx context.Context, t *testing.T, a, b p2p.Transport) go func() { privKey := ed25519.GenPrivKey() nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())} - _, _, err := ba.Handshake(ctx, nodeInfo, privKey) + _, _, err := ba.Handshake(ctx, 0, nodeInfo, privKey) errCh <- err }() privKey := ed25519.GenPrivKey() nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())} - _, _, err := ab.Handshake(ctx, nodeInfo, privKey) + _, _, err := ab.Handshake(ctx, 0, nodeInfo, privKey) require.NoError(t, err) timer := time.NewTimer(2 * time.Second) diff --git a/node/node.go b/node/node.go index 77773044b..187a7a1d6 100644 --- a/node/node.go +++ b/node/node.go @@ -715,7 +715,9 @@ func loadStateFromDBOrGenesisDocProvider(stateStore sm.Store, genDoc *types.Gene func getRouterConfig(conf *config.Config, appClient abciclient.Client) p2p.RouterOptions { opts := p2p.RouterOptions{ - QueueType: conf.P2P.QueueType, + QueueType: conf.P2P.QueueType, + HandshakeTimeout: conf.P2P.HandshakeTimeout, + DialTimeout: conf.P2P.DialTimeout, } if conf.FilterPeers && appClient != nil {