This commit is contained in:
William Banfield
2022-06-23 15:44:48 -04:00
parent eb97fbbd72
commit 22e80187f6
5 changed files with 33 additions and 18 deletions

View File

@@ -710,14 +710,8 @@ func (r *Router) handshakePeer(
expectID types.NodeID,
) (types.NodeInfo, error) {
if r.options.HandshakeTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, r.options.HandshakeTimeout)
defer cancel()
}
nodeInfo := r.nodeInfoProducer()
peerInfo, peerKey, err := conn.Handshake(ctx, *nodeInfo, r.privKey)
peerInfo, peerKey, err := conn.Handshake(ctx, r.options.HandshakeTimeout, *nodeInfo, r.privKey)
if err != nil {
return peerInfo, err
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"time"
"github.com/tendermint/tendermint/crypto"
"github.com/tendermint/tendermint/types"
@@ -81,7 +82,7 @@ type Connection interface {
// FIXME: The handshake should really be the Router's responsibility, but
// that requires the connection interface to be byte-oriented rather than
// message-oriented (see comment above).
Handshake(context.Context, types.NodeInfo, crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error)
Handshake(context.Context, time.Duration, types.NodeInfo, crypto.PrivKey) (types.NodeInfo, crypto.PubKey, error)
// ReceiveMessage returns the next message received on the connection,
// blocking until one is available. Returns io.EOF if closed.

View File

@@ -9,6 +9,7 @@ import (
"net"
"strconv"
"sync"
"time"
"golang.org/x/net/netutil"
@@ -274,6 +275,7 @@ func newMConnConnection(
// Handshake implements Connection.
func (c *mConnConnection) Handshake(
ctx context.Context,
timeout time.Duration,
nodeInfo types.NodeInfo,
privKey crypto.PrivKey,
) (types.NodeInfo, crypto.PubKey, error) {
@@ -283,6 +285,12 @@ func (c *mConnConnection) Handshake(
peerKey crypto.PubKey
errCh = make(chan error, 1)
)
handshakeCtx := ctx
if timeout > 0 {
var cancel context.CancelFunc
handshakeCtx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
// To handle context cancellation, we need to do the handshake in a
// goroutine and abort the blocking network calls by closing the connection
// when the context is canceled.
@@ -295,17 +303,17 @@ func (c *mConnConnection) Handshake(
}
}()
var err error
mconn, peerInfo, peerKey, err = c.handshake(ctx, nodeInfo, privKey)
mconn, peerInfo, peerKey, err = c.handshake(handshakeCtx, nodeInfo, privKey)
select {
case errCh <- err:
case <-ctx.Done():
case <-handshakeCtx.Done():
}
}()
select {
case <-ctx.Done():
case <-handshakeCtx.Done():
_ = c.Close()
return types.NodeInfo{}, nil, ctx.Err()
@@ -314,6 +322,10 @@ func (c *mConnConnection) Handshake(
return types.NodeInfo{}, nil, err
}
c.mconn = mconn
// Start must not use the handshakeCtx. The handshakeCtx may have a
// timeout set that is intended to terminate only the handshake procedure.
// The context passed to Start controls the entire lifecycle of the
// mconn.
if err = c.mconn.Start(ctx); err != nil {
return types.NodeInfo{}, nil, err
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"net"
"sync"
"time"
"github.com/tendermint/tendermint/crypto"
"github.com/tendermint/tendermint/libs/log"
@@ -273,9 +274,16 @@ func (c *MemoryConnection) RemoteEndpoint() Endpoint {
// Handshake implements Connection.
func (c *MemoryConnection) Handshake(
ctx context.Context,
timeout time.Duration,
nodeInfo types.NodeInfo,
privKey crypto.PrivKey,
) (types.NodeInfo, crypto.PubKey, error) {
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
select {
case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}:
c.logger.Debug("sent handshake", "nodeInfo", nodeInfo)

View File

@@ -296,7 +296,7 @@ func TestConnection_Handshake(t *testing.T) {
errCh := make(chan error, 1)
go func() {
// Must use assert due to goroutine.
peerInfo, peerKey, err := ba.Handshake(ctx, bInfo, bKey)
peerInfo, peerKey, err := ba.Handshake(ctx, 0, bInfo, bKey)
if err == nil {
assert.Equal(t, aInfo, peerInfo)
assert.Equal(t, aKey.PubKey(), peerKey)
@@ -307,7 +307,7 @@ func TestConnection_Handshake(t *testing.T) {
}
}()
peerInfo, peerKey, err := ab.Handshake(ctx, aInfo, aKey)
peerInfo, peerKey, err := ab.Handshake(ctx, 0, aInfo, aKey)
require.NoError(t, err)
require.Equal(t, bInfo, peerInfo)
require.Equal(t, bKey.PubKey(), peerKey)
@@ -353,7 +353,7 @@ func TestConnection_FlushClose(t *testing.T) {
withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
ab, _ := dialAcceptHandshake(ctx, t, a, b)
ab, _ := dialAcceptHandshake(ctx, 0, t, a, b)
err := ab.Close()
require.NoError(t, err)
@@ -374,7 +374,7 @@ func TestConnection_LocalRemoteEndpoint(t *testing.T) {
withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
ab, ba := dialAcceptHandshake(ctx, t, a, b)
ab, ba := dialAcceptHandshake(ctx, 0, t, a, b)
// Local and remote connection endpoints correspond to each other.
require.NotEmpty(t, ab.LocalEndpoint())
@@ -391,7 +391,7 @@ func TestConnection_SendReceive(t *testing.T) {
withTransports(ctx, t, func(ctx context.Context, t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
ab, ba := dialAcceptHandshake(ctx, t, a, b)
ab, ba := dialAcceptHandshake(ctx, 0, t, a, b)
// Can send and receive a to b.
err := ab.SendMessage(ctx, chID, []byte("foo"))
@@ -642,13 +642,13 @@ func dialAcceptHandshake(ctx context.Context, t *testing.T, a, b p2p.Transport)
go func() {
privKey := ed25519.GenPrivKey()
nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())}
_, _, err := ba.Handshake(ctx, nodeInfo, privKey)
_, _, err := ba.Handshake(ctx, 0, nodeInfo, privKey)
errCh <- err
}()
privKey := ed25519.GenPrivKey()
nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())}
_, _, err := ab.Handshake(ctx, nodeInfo, privKey)
_, _, err := ab.Handshake(ctx, 0, nodeInfo, privKey)
require.NoError(t, err)
timer := time.NewTimer(2 * time.Second)