diff --git a/internal/p2p/router.go b/internal/p2p/router.go index e3adc77ee..d5edd42be 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -710,14 +710,8 @@ func (r *Router) handshakePeer( expectID types.NodeID, ) (types.NodeInfo, error) { - if r.options.HandshakeTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, r.options.HandshakeTimeout) - defer cancel() - } - nodeInfo := r.nodeInfoProducer() - peerInfo, peerKey, err := conn.Handshake(ctx, *nodeInfo, r.privKey) + peerInfo, peerKey, err := conn.Handshake(ctx, r.options.HandshakeTimeout, *nodeInfo, r.privKey) if err != nil { return peerInfo, err } diff --git a/internal/p2p/transport.go b/internal/p2p/transport.go index 7a965260a..e644a11ae 100644 --- a/internal/p2p/transport.go +++ b/internal/p2p/transport.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "time" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/types" @@ -81,7 +82,7 @@ type Connection interface { // FIXME: The handshake should really be the Router's responsibility, but // that requires the connection interface to be byte-oriented rather than // message-oriented (see comment above). - Handshake(context.Context, types.NodeInfo, crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error) + Handshake(context.Context, time.Duration, types.NodeInfo, crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error) // ReceiveMessage returns the next message received on the connection, // blocking until one is available. Returns io.EOF if closed. diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index 7bf17d1a0..13a65b973 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -9,6 +9,7 @@ import ( "net" "strconv" "sync" + "time" "golang.org/x/net/netutil" @@ -274,6 +275,7 @@ func newMConnConnection( // Handshake implements Connection. 
func (c *mConnConnection) Handshake( ctx context.Context, + timeout time.Duration, nodeInfo types.NodeInfo, privKey crypto.PrivKey, ) (types.NodeInfo, crypto.PubKey, error) { @@ -283,6 +285,12 @@ func (c *mConnConnection) Handshake( peerKey crypto.PubKey errCh = make(chan error, 1) ) + handshakeCtx := ctx + if timeout > 0 { + var cancel context.CancelFunc + handshakeCtx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } // To handle context cancellation, we need to do the handshake in a // goroutine and abort the blocking network calls by closing the connection // when the context is canceled. @@ -295,17 +303,19 @@ func (c *mConnConnection) Handshake( } }() var err error - mconn, peerInfo, peerKey, err = c.handshake(ctx, nodeInfo, privKey) + mconn, peerInfo, peerKey, err = c.handshake(handshakeCtx, nodeInfo, privKey) select { case errCh <- err: - case <-ctx.Done(): + case <-handshakeCtx.Done(): } }() select { - case <-ctx.Done(): + case <-handshakeCtx.Done(): _ = c.Close() - return types.NodeInfo{}, nil, ctx.Err() + // Report handshakeCtx's error: when only the handshake timeout fires, + // the parent ctx is still live and ctx.Err() would be nil. + return types.NodeInfo{}, nil, handshakeCtx.Err() @@ -314,6 +324,10 @@ func (c *mConnConnection) Handshake( return types.NodeInfo{}, nil, err } c.mconn = mconn + // Start must not use the handshakeCtx. The handshakeCtx may have a + // timeout set that is intended to terminate only the handshake procedure. + // The context passed to Start controls the entire lifecycle of the + // mconn. if err = c.mconn.Start(ctx); err != nil { return types.NodeInfo{}, nil, err } diff --git a/internal/p2p/transport_memory.go b/internal/p2p/transport_memory.go index 3eb4c5b51..c321bc174 100644 --- a/internal/p2p/transport_memory.go +++ b/internal/p2p/transport_memory.go @@ -7,6 +7,7 @@ import ( "io" "net" "sync" + "time" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/libs/log" @@ -273,9 +274,16 @@ func (c *MemoryConnection) RemoteEndpoint() Endpoint { // Handshake implements Connection. 
func (c *MemoryConnection) Handshake( ctx context.Context, + timeout time.Duration, nodeInfo types.NodeInfo, privKey crypto.PrivKey, ) (types.NodeInfo, crypto.PubKey, error) { + if timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + select { case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}: c.logger.Debug("sent handshake", "nodeInfo", nodeInfo) diff --git a/internal/p2p/transport_test.go b/internal/p2p/transport_test.go index b4edf9bc9..cf7e8a3c4 100644 --- a/internal/p2p/transport_test.go +++ b/internal/p2p/transport_test.go @@ -296,7 +296,7 @@ func TestConnection_Handshake(t *testing.T) { errCh := make(chan error, 1) go func() { // Must use assert due to goroutine. - peerInfo, peerKey, err := ba.Handshake(ctx, bInfo, bKey) + peerInfo, peerKey, err := ba.Handshake(ctx, 0, bInfo, bKey) if err == nil { assert.Equal(t, aInfo, peerInfo) assert.Equal(t, aKey.PubKey(), peerKey) @@ -307,7 +307,7 @@ func TestConnection_Handshake(t *testing.T) { } }() - peerInfo, peerKey, err := ab.Handshake(ctx, aInfo, aKey) + peerInfo, peerKey, err := ab.Handshake(ctx, 0, aInfo, aKey) require.NoError(t, err) require.Equal(t, bInfo, peerInfo) require.Equal(t, bKey.PubKey(), peerKey) 
@@ -642,13 +642,13 @@ func dialAcceptHandshake(ctx context.Context, t *testing.T, a, b p2p.Transport) go func() { privKey := ed25519.GenPrivKey() nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())} - _, _, err := ba.Handshake(ctx, nodeInfo, privKey) + _, _, err := ba.Handshake(ctx, 0, nodeInfo, privKey) errCh <- err }() privKey := ed25519.GenPrivKey() nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())} - _, _, err := ab.Handshake(ctx, 0, nodeInfo, privKey) + _, _, err := ab.Handshake(ctx, 0, nodeInfo, privKey) require.NoError(t, err) timer := time.NewTimer(2 * time.Second)