diff --git a/internal/p2p/mocks/connection.go b/internal/p2p/mocks/connection.go index 0a8606da7..e5ba9584a 100644 --- a/internal/p2p/mocks/connection.go +++ b/internal/p2p/mocks/connection.go @@ -13,6 +13,8 @@ import ( p2p "github.com/tendermint/tendermint/internal/p2p" + time "time" + types "github.com/tendermint/tendermint/types" ) @@ -49,20 +51,20 @@ func (_m *Connection) FlushClose() error { return r0 } -// Handshake provides a mock function with given fields: _a0, _a1, _a2 -func (_m *Connection) Handshake(_a0 context.Context, _a1 types.NodeInfo, _a2 crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error) { - ret := _m.Called(_a0, _a1, _a2) +// Handshake provides a mock function with given fields: _a0, _a1, _a2, _a3 +func (_m *Connection) Handshake(_a0 context.Context, _a1 time.Duration, _a2 types.NodeInfo, _a3 crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error) { + ret := _m.Called(_a0, _a1, _a2, _a3) var r0 types.NodeInfo - if rf, ok := ret.Get(0).(func(context.Context, types.NodeInfo, crypto.PrivKey) types.NodeInfo); ok { - r0 = rf(_a0, _a1, _a2) + if rf, ok := ret.Get(0).(func(context.Context, time.Duration, types.NodeInfo, crypto.PrivKey) types.NodeInfo); ok { + r0 = rf(_a0, _a1, _a2, _a3) } else { r0 = ret.Get(0).(types.NodeInfo) } var r1 crypto.PubKey - if rf, ok := ret.Get(1).(func(context.Context, types.NodeInfo, crypto.PrivKey) crypto.PubKey); ok { - r1 = rf(_a0, _a1, _a2) + if rf, ok := ret.Get(1).(func(context.Context, time.Duration, types.NodeInfo, crypto.PrivKey) crypto.PubKey); ok { + r1 = rf(_a0, _a1, _a2, _a3) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(crypto.PubKey) @@ -70,8 +72,8 @@ func (_m *Connection) Handshake(_a0 context.Context, _a1 types.NodeInfo, _a2 cry } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, types.NodeInfo, crypto.PrivKey) error); ok { - r2 = rf(_a0, _a1, _a2) + if rf, ok := ret.Get(2).(func(context.Context, time.Duration, types.NodeInfo, crypto.PrivKey) error); ok { + r2 = rf(_a0, _a1, _a2, _a3) } else { r2 = ret.Error(2) } diff --git a/internal/p2p/peer_test.go b/internal/p2p/peer_test.go index dfe7bc798..dad7b98b5 100644 --- a/internal/p2p/peer_test.go +++ b/internal/p2p/peer_test.go @@ -90,7 +90,7 @@ func createOutboundPeerAndPerformHandshake( if err != nil { return nil, err } - peerInfo, _, err := pc.conn.Handshake(context.Background(), ourNodeInfo, pk) + peerInfo, _, err := pc.conn.Handshake(context.Background(), 0, ourNodeInfo, pk) if err != nil { return nil, err } @@ -187,7 +187,7 @@ func (rp *remotePeer) Dial(addr *NetAddress) (net.Conn, error) { if err != nil { return nil, err } - _, _, err = pc.conn.Handshake(context.Background(), rp.nodeInfo(), rp.PrivKey) + _, _, err = pc.conn.Handshake(context.Background(), 0, rp.nodeInfo(), rp.PrivKey) if err != nil { return nil, err } @@ -213,7 +213,7 @@ func (rp *remotePeer) accept() { if err != nil { golog.Printf("Failed to create a peer: %+v", err) } - _, _, err = pc.conn.Handshake(context.Background(), rp.nodeInfo(), rp.PrivKey) + _, _, err = pc.conn.Handshake(context.Background(), 0, rp.nodeInfo(), rp.PrivKey) if err != nil { golog.Printf("Failed to handshake a peer: %+v", err) } diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 81a8928a4..dc717565f 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -829,13 +829,7 @@ func (r *Router) handshakePeer( expectID types.NodeID, ) (types.NodeInfo, crypto.PubKey, error) { - if r.options.HandshakeTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, r.options.HandshakeTimeout) - defer cancel() - } - - peerInfo, peerKey, err := conn.Handshake(ctx, r.nodeInfo, r.privKey) + peerInfo, peerKey, err := conn.Handshake(ctx, r.options.HandshakeTimeout, r.nodeInfo, r.privKey) if err != nil { return peerInfo, peerKey, err } diff --git a/internal/p2p/router_test.go b/internal/p2p/router_test.go index e8494fdf4..37ce95768 100644 --- a/internal/p2p/router_test.go +++ b/internal/p2p/router_test.go @@ -352,7 +352,7 @@ func TestRouter_AcceptPeers(t *testing.T) { closer := tmsync.NewCloser() mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). Return(tc.peerInfo, tc.peerKey, nil) mockConnection.On("Close").Run(func(_ mock.Arguments) { closer.Close() }).Return(nil) mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) @@ -462,7 +462,7 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). WaitUntil(closeCh).Return(types.NodeInfo{}, nil, io.EOF) mockConnection.On("Close").Return(nil) mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) @@ -543,7 +543,7 @@ func TestRouter_DialPeers(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") if tc.dialErr == nil { - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). Return(tc.peerInfo, tc.peerKey, nil) mockConnection.On("Close").Run(func(_ mock.Arguments) { closer.Close() }).Return(nil) } @@ -630,7 +630,7 @@ func TestRouter_DialPeers_Parallel(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). WaitUntil(closeCh).Return(types.NodeInfo{}, nil, io.EOF) mockConnection.On("Close").Return(nil) @@ -710,7 +710,7 @@ func TestRouter_EvictPeers(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). Return(peerInfo, peerKey.PubKey(), nil) mockConnection.On("ReceiveMessage").WaitUntil(closeCh).Return(chID, nil, io.EOF) mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) @@ -779,7 +779,7 @@ func TestRouter_ChannelCompatability(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). Return(incompatiblePeer, peerKey.PubKey(), nil) mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) mockConnection.On("Close").Return(nil) @@ -828,7 +828,7 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) { mockConnection := &mocks.Connection{} mockConnection.On("String").Maybe().Return("mock") - mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey). + mockConnection.On("Handshake", mock.Anything, mock.Anything, selfInfo, selfKey). Return(peer, peerKey.PubKey(), nil) mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{}) mockConnection.On("Close").Return(nil) diff --git a/internal/p2p/switch.go b/internal/p2p/switch.go index 60c0c7deb..9e1b0311b 100644 --- a/internal/p2p/switch.go +++ b/internal/p2p/switch.go @@ -865,11 +865,11 @@ func (sw *Switch) handshakePeer( c Connection, expectPeerID types.NodeID, ) (types.NodeInfo, crypto.PubKey, error) { - // Moved from transport and hardcoded until legacy P2P stack removal. - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - peerInfo, peerKey, err := c.Handshake(ctx, sw.nodeInfo, sw.nodeKey.PrivKey) + // Moved timeout from transport and hardcoded until legacy P2P stack removal. + peerInfo, peerKey, err := c.Handshake(ctx, 5*time.Second, sw.nodeInfo, sw.nodeKey.PrivKey) if err != nil { return peerInfo, peerKey, ErrRejected{ conn: c.(*mConnConnection).conn, diff --git a/internal/p2p/switch_test.go b/internal/p2p/switch_test.go index 8cb755c9f..c68cfceaf 100644 --- a/internal/p2p/switch_test.go +++ b/internal/p2p/switch_test.go @@ -267,7 +267,7 @@ func TestSwitchPeerFilter(t *testing.T) { if err != nil { t.Fatal(err) } - peerInfo, _, err := c.Handshake(ctx, sw.nodeInfo, sw.nodeKey.PrivKey) + peerInfo, _, err := c.Handshake(ctx, 0, sw.nodeInfo, sw.nodeKey.PrivKey) if err != nil { t.Fatal(err) } @@ -324,7 +324,7 @@ func TestSwitchPeerFilterTimeout(t *testing.T) { if err != nil { t.Fatal(err) } - peerInfo, _, err := c.Handshake(ctx, sw.nodeInfo, sw.nodeKey.PrivKey) + peerInfo, _, err := c.Handshake(ctx, 0, sw.nodeInfo, sw.nodeKey.PrivKey) if err != nil { t.Fatal(err) } @@ -360,7 +360,7 @@ func TestSwitchPeerFilterDuplicate(t *testing.T) { if err != nil { t.Fatal(err) } - peerInfo, _, err := c.Handshake(ctx, sw.nodeInfo, sw.nodeKey.PrivKey) + peerInfo, _, err := c.Handshake(ctx, 0, sw.nodeInfo, sw.nodeKey.PrivKey) if err != nil { t.Fatal(err) } @@ -415,7 +415,7 @@ func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) { if err != nil { t.Fatal(err) } - peerInfo, _, err := c.Handshake(ctx, sw.nodeInfo, sw.nodeKey.PrivKey) + peerInfo, _, err := c.Handshake(ctx, 0, sw.nodeInfo, sw.nodeKey.PrivKey) if err != nil { t.Fatal(err) } diff --git a/internal/p2p/test_util.go b/internal/p2p/test_util.go index b2851646d..ae21ba4d7 100644 --- a/internal/p2p/test_util.go +++ b/internal/p2p/test_util.go @@ -126,7 +126,7 @@ func (sw *Switch) addPeerWithConnection(conn net.Conn) error { } return err } - peerNodeInfo, _, err := pc.conn.Handshake(context.Background(), sw.nodeInfo, sw.nodeKey.PrivKey) + peerNodeInfo, _, err := pc.conn.Handshake(context.Background(), 0, sw.nodeInfo, sw.nodeKey.PrivKey) if err != nil { if err := conn.Close(); err != nil { sw.Logger.Error("Error closing connection", "err", err) diff --git a/internal/p2p/transport.go b/internal/p2p/transport.go index a3245dfc8..2e4d26abd 100644 --- a/internal/p2p/transport.go +++ b/internal/p2p/transport.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "time" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/internal/p2p/conn" @@ -84,7 +85,7 @@ type Connection interface { // FIXME: The handshake should really be the Router's responsibility, but // that requires the connection interface to be byte-oriented rather than // message-oriented (see comment above). - Handshake(context.Context, types.NodeInfo, crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error) + Handshake(context.Context, time.Duration, types.NodeInfo, crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error) // ReceiveMessage returns the next message received on the connection, // blocking until one is available. Returns io.EOF if closed. diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index eca261476..4d18d896b 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -9,6 +9,7 @@ import ( "net" "strconv" "sync" + "time" "golang.org/x/net/netutil" @@ -255,6 +256,7 @@ func newMConnConnection( // Handshake implements Connection. func (c *mConnConnection) Handshake( ctx context.Context, + timeout time.Duration, nodeInfo types.NodeInfo, privKey crypto.PrivKey, ) (types.NodeInfo, crypto.PubKey, error) { @@ -264,6 +266,12 @@ func (c *mConnConnection) Handshake( peerKey crypto.PubKey errCh = make(chan error, 1) ) + handshakeCtx := ctx + if timeout > 0 { + var cancel context.CancelFunc + handshakeCtx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } // To handle context cancellation, we need to do the handshake in a // goroutine and abort the blocking network calls by closing the connection // when the context is canceled. @@ -276,12 +284,17 @@ func (c *mConnConnection) Handshake( } }() var err error - mconn, peerInfo, peerKey, err = c.handshake(ctx, nodeInfo, privKey) - errCh <- err + mconn, peerInfo, peerKey, err = c.handshake(handshakeCtx, nodeInfo, privKey) + + select { + case errCh <- err: + case <-handshakeCtx.Done(): + } + }() select { - case <-ctx.Done(): + case <-handshakeCtx.Done(): _ = c.Close() return types.NodeInfo{}, nil, ctx.Err() diff --git a/internal/p2p/transport_memory.go b/internal/p2p/transport_memory.go index 09a387254..f2d1d0c72 100644 --- a/internal/p2p/transport_memory.go +++ b/internal/p2p/transport_memory.go @@ -7,6 +7,7 @@ import ( "io" "net" "sync" + "time" "github.com/tendermint/tendermint/crypto" tmsync "github.com/tendermint/tendermint/internal/libs/sync" @@ -270,9 +271,16 @@ func (c *MemoryConnection) Status() conn.ConnectionStatus { // Handshake implements Connection. func (c *MemoryConnection) Handshake( ctx context.Context, + timeout time.Duration, nodeInfo types.NodeInfo, privKey crypto.PrivKey, ) (types.NodeInfo, crypto.PubKey, error) { + if timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + select { case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}: c.logger.Debug("sent handshake", "nodeInfo", nodeInfo) diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index 1b8ab77f5..de7405177 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -265,7 +265,7 @@ func TestConnection_Handshake(t *testing.T) { errCh := make(chan error, 1) go func() { // Must use assert due to goroutine. - peerInfo, peerKey, err := ba.Handshake(ctx, bInfo, bKey) + peerInfo, peerKey, err := ba.Handshake(ctx, 0, bInfo, bKey) if err == nil { assert.Equal(t, aInfo, peerInfo) assert.Equal(t, aKey.PubKey(), peerKey) @@ -273,7 +273,7 @@ func TestConnection_Handshake(t *testing.T) { errCh <- err }() - peerInfo, peerKey, err := ab.Handshake(ctx, aInfo, aKey) + peerInfo, peerKey, err := ab.Handshake(ctx, 0, aInfo, aKey) require.NoError(t, err) require.Equal(t, bInfo, peerInfo) require.Equal(t, bKey.PubKey(), peerKey) @@ -291,7 +291,7 @@ func TestConnection_HandshakeCancel(t *testing.T) { ab, ba := dialAccept(t, a, b) timeoutCtx, cancel := context.WithTimeout(ctx, 1*time.Minute) cancel() - _, _, err := ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey()) + _, _, err := ab.Handshake(timeoutCtx, 0, types.NodeInfo{}, ed25519.GenPrivKey()) require.Error(t, err) require.Equal(t, context.Canceled, err) _ = ab.Close() @@ -301,7 +301,7 @@ func TestConnection_HandshakeCancel(t *testing.T) { ab, ba = dialAccept(t, a, b) timeoutCtx, cancel = context.WithTimeout(ctx, 200*time.Millisecond) defer cancel() - _, _, err = ab.Handshake(timeoutCtx, types.NodeInfo{}, ed25519.GenPrivKey()) + _, _, err = ab.Handshake(timeoutCtx, 0, types.NodeInfo{}, ed25519.GenPrivKey()) require.Error(t, err) require.Equal(t, context.DeadlineExceeded, err) _ = ab.Close() @@ -630,13 +630,13 @@ func dialAcceptHandshake(t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p. go func() { privKey := ed25519.GenPrivKey() nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())} - _, _, err := ba.Handshake(ctx, nodeInfo, privKey) + _, _, err := ba.Handshake(ctx, 0, nodeInfo, privKey) errCh <- err }() privKey := ed25519.GenPrivKey() nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())} - _, _, err := ab.Handshake(ctx, nodeInfo, privKey) + _, _, err := ab.Handshake(ctx, 0, nodeInfo, privKey) require.NoError(t, err) timer := time.NewTimer(2 * time.Second) diff --git a/node/node.go b/node/node.go index cd0f31396..e1de314d1 100644 --- a/node/node.go +++ b/node/node.go @@ -1246,7 +1246,9 @@ func createAndStartPrivValidatorGRPCClient( func getRouterConfig(conf *config.Config, proxyApp proxy.AppConns) p2p.RouterOptions { opts := p2p.RouterOptions{ - QueueType: conf.P2P.QueueType, + QueueType: conf.P2P.QueueType, + HandshakeTimeout: conf.P2P.HandshakeTimeout, + DialTimeout: conf.P2P.DialTimeout, } if conf.P2P.MaxNumInboundPeers > 0 {