From 1f39f808e1cb85bb3c0eec4fc503b0cc1c35db19 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Mon, 1 Feb 2021 09:24:31 +0100 Subject: [PATCH] p2p: tighten up and test Transport API (#6020) This tightens up the new P2P `Transport` API and infrastructure, fixes a bunch of bugs and inconsistencies, and adds tests. --- mempool/reactor_test.go | 8 +- p2p/key.go | 7 +- p2p/peer.go | 39 ++- p2p/router.go | 16 +- p2p/router_test.go | 6 +- p2p/switch.go | 3 +- p2p/switch_test.go | 2 +- p2p/transport.go | 121 ++++--- p2p/transport_mconn.go | 236 +++++++------ p2p/transport_mconn_test.go | 208 ++++++++++++ p2p/transport_memory.go | 340 +++++++++---------- p2p/transport_memory_test.go | 129 ++----- p2p/transport_test.go | 637 +++++++++++++++++++++++++++++++++++ 13 files changed, 1272 insertions(+), 480 deletions(-) create mode 100644 p2p/transport_mconn_test.go create mode 100644 p2p/transport_test.go diff --git a/mempool/reactor_test.go b/mempool/reactor_test.go index c878a5dca..d1163bf09 100644 --- a/mempool/reactor_test.go +++ b/mempool/reactor_test.go @@ -39,7 +39,7 @@ type reactorTestSuite struct { func setup(t *testing.T, cfg *cfg.MempoolConfig, logger log.Logger, chBuf uint) *reactorTestSuite { t.Helper() - pID := make([]byte, 16) + pID := make([]byte, 20) _, err := rng.Read(pID) require.NoError(t, err) @@ -313,7 +313,7 @@ func TestReactorNoBroadcastToSender(t *testing.T) { func TestMempoolIDsBasic(t *testing.T) { ids := newMempoolIDs() - peerID, err := p2p.NewNodeID("00ffaa") + peerID, err := p2p.NewNodeID("0011223344556677889900112233445566778899") require.NoError(t, err) ids.ReserveForPeer(peerID) @@ -399,7 +399,7 @@ func TestDontExhaustMaxActiveIDs(t *testing.T) { } }() - peerID, err := p2p.NewNodeID("00ffaa") + peerID, err := p2p.NewNodeID("0011223344556677889900112233445566778899") require.NoError(t, err) // ensure the reactor does not panic (i.e. 
exhaust active IDs) @@ -427,7 +427,7 @@ func TestMempoolIDsPanicsIfNodeRequestsOvermaxActiveIDs(t *testing.T) { // 0 is already reserved for UnknownPeerID ids := newMempoolIDs() - peerID, err := p2p.NewNodeID("00ffaa") + peerID, err := p2p.NewNodeID("0011223344556677889900112233445566778899") require.NoError(t, err) for i := 0; i < maxActiveIDs-1; i++ { diff --git a/p2p/key.go b/p2p/key.go index b6a53d94e..6b591b04f 100644 --- a/p2p/key.go +++ b/p2p/key.go @@ -22,11 +22,8 @@ type NodeID string // NewNodeID returns a lowercased (normalized) NodeID. func NewNodeID(nodeID string) (NodeID, error) { - if _, err := NodeID(nodeID).Bytes(); err != nil { - return NodeID(""), err - } - - return NodeID(strings.ToLower(nodeID)), nil + n := NodeID(strings.ToLower(nodeID)) + return n, n.Validate() } // NodeIDFromPubKey returns the noe ID corresponding to the given PubKey. It's diff --git a/p2p/peer.go b/p2p/peer.go index 0a70a0f3b..c8e8c345f 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -156,21 +156,26 @@ func (a PeerAddress) Validate() error { // String formats the address as a URL string. func (a PeerAddress) String() string { - // Handle opaque URLs. - if a.Hostname == "" { - s := fmt.Sprintf("%s:%s", a.Protocol, a.NodeID) - if a.Path != "" { - s += "@" + a.Path + u := url.URL{Scheme: string(a.Protocol)} + if a.NodeID != "" { + u.User = url.User(string(a.NodeID)) + } + switch { + case a.Hostname != "": + if a.Port > 0 { + u.Host = net.JoinHostPort(a.Hostname, strconv.Itoa(int(a.Port))) + } else { + u.Host = a.Hostname } - return s + u.Path = a.Path + case a.Protocol != "": + u.Opaque = a.Path // e.g. memory:foo + case a.Path != "" && a.Path[0] != '/': + u.Path = "/" + a.Path // e.g. some/path + default: + u.Path = a.Path // e.g. 
/some/path } - - s := fmt.Sprintf("%s://%s@%s", a.Protocol, a.NodeID, a.Hostname) - if a.Port > 0 { - s += ":" + strconv.Itoa(int(a.Port)) - } - s += a.Path // We've already normalized the path with appropriate prefix in ParsePeerAddress() - return s + return strings.TrimPrefix(u.String(), "//") } // PeerStatus specifies peer statuses. @@ -1475,12 +1480,12 @@ func (p *peer) processMessages() { p.onError(err) return } - reactor, ok := p.reactors[chID] + reactor, ok := p.reactors[byte(chID)] if !ok { p.onError(fmt.Errorf("unknown channel %v", chID)) return } - reactor.Receive(chID, p, msg) + reactor.Receive(byte(chID), p, msg) } } @@ -1555,7 +1560,7 @@ func (p *peer) Send(chID byte, msgBytes []byte) bool { } else if !p.hasChannel(chID) { return false } - res, err := p.conn.SendMessage(chID, msgBytes) + res, err := p.conn.SendMessage(ChannelID(chID), msgBytes) if err == io.EOF { return false } else if err != nil { @@ -1580,7 +1585,7 @@ func (p *peer) TrySend(chID byte, msgBytes []byte) bool { } else if !p.hasChannel(chID) { return false } - res, err := p.conn.TrySendMessage(chID, msgBytes) + res, err := p.conn.TrySendMessage(ChannelID(chID), msgBytes) if err == io.EOF { return false } else if err != nil { diff --git a/p2p/router.go b/p2p/router.go index 4f8be6924..7c9a6141c 100644 --- a/p2p/router.go +++ b/p2p/router.go @@ -293,10 +293,10 @@ func (r *Router) acceptPeers(transport Transport) { // FIXME: The old P2P stack supported ABCI-based IP address filtering via // /p2p/filter/addr/ queries, do we want to implement this here as well? // Filtering by node ID is probably better. 
- conn, err := transport.Accept(ctx) + conn, err := transport.Accept() switch err { case nil: - case ErrTransportClosed{}, io.EOF, context.Canceled: + case io.EOF: r.logger.Debug("stopping accept routine", "transport", transport) return default: @@ -536,8 +536,8 @@ func (r *Router) receivePeer(peerID NodeID, conn Connection) error { } r.channelMtx.RLock() - queue, ok := r.channelQueues[ChannelID(chID)] - messageType := r.channelMessages[ChannelID(chID)] + queue, ok := r.channelQueues[chID] + messageType := r.channelMessages[chID] r.channelMtx.RUnlock() if !ok { r.logger.Error("dropping message for unknown channel", "peer", peerID, "channel", chID) @@ -558,8 +558,7 @@ func (r *Router) receivePeer(peerID NodeID, conn Connection) error { } select { - // FIXME: ReceiveMessage() should return ChannelID. - case queue.enqueue() <- Envelope{channelID: ChannelID(chID), From: peerID, Message: msg}: + case queue.enqueue() <- Envelope{channelID: chID, From: peerID, Message: msg}: r.logger.Debug("received message", "peer", peerID, "message", msg) case <-queue.closed(): r.logger.Error("channel closed, dropping message", "peer", peerID, "channel", chID) @@ -580,8 +579,7 @@ func (r *Router) sendPeer(peerID NodeID, conn Connection, queue queue) error { continue } - // FIXME: SendMessage() should take ChannelID. - _, err = conn.SendMessage(byte(envelope.channelID), bz) + _, err = conn.SendMessage(envelope.channelID, bz) if err != nil { return err } @@ -631,6 +629,8 @@ func (r *Router) OnStart() error { } // OnStop implements service.Service. +// +// FIXME: This needs to close transports as well. func (r *Router) OnStop() { // Collect all active queues, so we can wait for them to close. 
queues := []queue{} diff --git a/p2p/router_test.go b/p2p/router_test.go index f1a18bf10..3b79b672a 100644 --- a/p2p/router_test.go +++ b/p2p/router_test.go @@ -50,8 +50,7 @@ func TestRouter(t *testing.T) { logger := log.TestingLogger() network := p2p.NewMemoryNetwork(logger) nodeInfo, privKey := generateNode() - transport, err := network.CreateTransport(nodeInfo.NodeID) - require.NoError(t, err) + transport := network.CreateTransport(nodeInfo.NodeID) defer transport.Close() chID := p2p.ChannelID(1) @@ -62,8 +61,7 @@ func TestRouter(t *testing.T) { peerManager, err := p2p.NewPeerManager(dbm.NewMemDB(), p2p.PeerManagerOptions{}) require.NoError(t, err) peerInfo, peerKey := generateNode() - peerTransport, err := network.CreateTransport(peerInfo.NodeID) - require.NoError(t, err) + peerTransport := network.CreateTransport(peerInfo.NodeID) defer peerTransport.Close() peerRouter, err := p2p.NewRouter( logger.With("peerID", i), diff --git a/p2p/switch.go b/p2p/switch.go index e345b3572..71f4a0e1f 100644 --- a/p2p/switch.go +++ b/p2p/switch.go @@ -669,8 +669,7 @@ func (sw *Switch) IsPeerPersistent(na *NetAddress) bool { func (sw *Switch) acceptRoutine() { for { var peerNodeInfo NodeInfo - ctx := context.Background() - c, err := sw.transport.Accept(ctx) + c, err := sw.transport.Accept() if err == nil { // NOTE: The legacy MConn transport did handshaking in Accept(), // which was asynchronous and avoided head-of-line-blocking. 
diff --git a/p2p/switch_test.go b/p2p/switch_test.go index ab8882f83..db83bfd07 100644 --- a/p2p/switch_test.go +++ b/p2p/switch_test.go @@ -706,7 +706,7 @@ func (et errorTransport) Protocols() []Protocol { return []Protocol{"error"} } -func (et errorTransport) Accept(context.Context) (Connection, error) { +func (et errorTransport) Accept() (Connection, error) { return nil, et.acceptErr } func (errorTransport) Dial(context.Context, Endpoint) (Connection, error) { diff --git a/p2p/transport.go b/p2p/transport.go index d278cbeae..373ce7118 100644 --- a/p2p/transport.go +++ b/p2p/transport.go @@ -5,13 +5,14 @@ import ( "errors" "fmt" "net" - "strconv" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/p2p/conn" ) const ( + // defaultProtocol is the default protocol used for PeerAddress when + // a protocol isn't explicitly given as a URL scheme. defaultProtocol Protocol = MConnProtocol ) @@ -20,69 +21,77 @@ type Protocol string // Transport is a connection-oriented mechanism for exchanging data with a peer. type Transport interface { - // Protocols returns the protocols the transport supports, which the - // router uses to pick a transport for a PeerAddress. + // Protocols returns the protocols supported by the transport. The Router + // uses this to pick a transport for an Endpoint. Protocols() []Protocol - // Accept waits for the next inbound connection on a listening endpoint, or - // returns io.EOF if the transport is closed. - Accept(context.Context) (Connection, error) + // Endpoints returns the local endpoints the transport is listening on, if any. + // + // How to listen is transport-dependent, e.g. MConnTransport uses Listen() while + // MemoryTransport starts listening via MemoryNetwork.CreateTransport(). + Endpoints() []Endpoint + + // Accept waits for the next inbound connection on a listening endpoint, blocking + // until either a connection is available or the transport is closed. 
On closure, + // io.EOF is returned and further Accept calls are futile. + Accept() (Connection, error) // Dial creates an outbound connection to an endpoint. Dial(context.Context, Endpoint) (Connection, error) - // Endpoints lists endpoints the transport is listening on. - Endpoints() []Endpoint - // Close stops accepting new connections, but does not close active connections. Close() error // Stringer is used to display the transport, e.g. in logs. // // Without this, the logger may use reflection to access and display - // internal fields -- these are written concurrently, which can trigger the - // race detector or even cause a panic. + // internal fields. These can be written to concurrently, which can trigger + // the race detector or even cause a panic. fmt.Stringer } // Connection represents an established connection between two endpoints. // -// FIXME: This is a temporary interface while we figure out whether we'll be -// adopting QUIC or not. If we do, this should be a byte-oriented multi-stream -// interface with one goroutine consuming each stream, and the MConnection -// transport either needs protocol changes or a shim. For details, see: +// FIXME: This is a temporary interface for backwards-compatibility with the +// current MConnection-protocol, which is message-oriented. It should be +// migrated to a byte-oriented multi-stream interface instead, which would allow +// e.g. adopting QUIC and making message framing, traffic scheduling, and node +// handshakes a Router concern shared across all transports. However, this +// requires MConnection protocol changes or a shim. For details, see: // https://github.com/tendermint/spec/pull/227 // // FIXME: The interface is currently very broad in order to accommodate -// MConnection behavior that the rest of the P2P stack relies on. This should be -// removed once the P2P core is rewritten. +// MConnection behavior that the legacy P2P stack relies on. 
It should be +// cleaned up when the legacy stack is removed. type Connection interface { - // Handshake handshakes with the remote peer. It must be called immediately - // after the connection is established, and returns the remote peer's node - // info and public key. The caller is responsible for validation. + // Handshake executes a node handshake with the remote peer. It must be + // called immediately after the connection is established, and returns the + // remote peer's node info and public key. The caller is responsible for + // validation. // - // FIXME: The handshaking should really be the Router's responsibility, but + // FIXME: The handshake should really be the Router's responsibility, but // that requires the connection interface to be byte-oriented rather than // message-oriented (see comment above). Handshake(context.Context, NodeInfo, crypto.PrivKey) (NodeInfo, crypto.PubKey, error) // ReceiveMessage returns the next message received on the connection, - // blocking until one is available. io.EOF is returned when closed. - ReceiveMessage() (chID byte, msg []byte, err error) + // blocking until one is available. Returns io.EOF if closed. + ReceiveMessage() (ChannelID, []byte, error) - // SendMessage sends a message on the connection. - // FIXME: For compatibility with the current Peer, it returns an additional - // boolean false if the message timed out waiting to be accepted into the - // send buffer. - SendMessage(chID byte, msg []byte) (bool, error) + // SendMessage sends a message on the connection. Returns io.EOF if closed. + // + // FIXME: For compatibility with the legacy P2P stack, it returns an + // additional boolean false if the message timed out waiting to be accepted + // into the send buffer. This should be removed. + SendMessage(ChannelID, []byte) (bool, error) // TrySendMessage is a non-blocking version of SendMessage that returns // immediately if the message buffer is full. It returns true if the message // was accepted. 
// - // FIXME: This is here for backwards-compatibility with the current Peer - // code, and should be removed when possible. - TrySendMessage(chID byte, msg []byte) (bool, error) + // FIXME: This method is here for backwards-compatibility with the legacy + // P2P stack and should be removed. + TrySendMessage(ChannelID, []byte) (bool, error) // LocalEndpoint returns the local endpoint for the connection. LocalEndpoint() Endpoint @@ -98,68 +107,76 @@ type Connection interface { // FIXME: This only exists for backwards-compatibility with the current // MConnection implementation. There should really be a separate Flush() // method, but there is no easy way to synchronously flush pending data with - // the current MConnection structure. + // the current MConnection code. FlushClose() error // Status returns the current connection status. // FIXME: Only here for compatibility with the current Peer code. Status() conn.ConnectionStatus + + // Stringer is used to display the connection, e.g. in logs. + // + // Without this, the logger may use reflection to access and display + // internal fields. These can be written to concurrently, which can trigger + // the race detector or even cause a panic. + fmt.Stringer } // Endpoint represents a transport connection endpoint, either local or remote. +// +// Endpoints are not necessarily networked (see e.g. MemoryTransport) but all +// networked endpoints must use IP as the underlying transport protocol to allow +// e.g. IP address filtering. Either IP or Path (or both) must be set. type Endpoint struct { - // Protocol specifies the transport protocol, used by the router to pick a - // transport for an endpoint. + // Protocol specifies the transport protocol. Protocol Protocol - // Path is an optional, arbitrary transport-specific path or identifier. - Path string - // IP is an IP address (v4 or v6) to connect to. If set, this defines the // endpoint as a networked endpoint. 
IP net.IP - // Port is a network port (either TCP or UDP). If not set, a default port - // may be used depending on the protocol. + // Port is a network port (either TCP or UDP). If 0, a default port may be + // used depending on the protocol. Port uint16 + + // Path is an optional transport-specific path or identifier. + Path string } -// PeerAddress converts the endpoint into a peer address for a given node ID. +// PeerAddress converts the endpoint into a PeerAddress for the given node ID. func (e Endpoint) PeerAddress(nodeID NodeID) PeerAddress { address := PeerAddress{ NodeID: nodeID, Protocol: e.Protocol, Path: e.Path, } - if e.IP != nil { + if len(e.IP) > 0 { address.Hostname = e.IP.String() address.Port = e.Port } return address } -// String formats an endpoint as a URL string. +// String formats the endpoint as a URL string. func (e Endpoint) String() string { - if e.IP == nil { - return fmt.Sprintf("%s:%s", e.Protocol, e.Path) - } - s := fmt.Sprintf("%s://%s", e.Protocol, e.IP) - if e.Port > 0 { - s += strconv.Itoa(int(e.Port)) - } - s += e.Path - return s + return e.PeerAddress("").String() } -// Validate validates an endpoint. +// Validate validates the endpoint. 
func (e Endpoint) Validate() error { switch { case e.Protocol == "": return errors.New("endpoint has no protocol") + + case len(e.IP) > 0 && e.IP.To16() == nil: + return fmt.Errorf("invalid IP address %v", e.IP) + case e.Port > 0 && len(e.IP) == 0: return fmt.Errorf("endpoint has port %v but no IP", e.Port) + case len(e.IP) == 0 && e.Path == "": return errors.New("endpoint has neither path nor IP") + default: return nil } diff --git a/p2p/transport_mconn.go b/p2p/transport_mconn.go index 567bdec6f..c0945fb94 100644 --- a/p2p/transport_mconn.go +++ b/p2p/transport_mconn.go @@ -5,9 +5,10 @@ import ( "errors" "fmt" "io" + "math" "net" + "strconv" "sync" - "time" "golang.org/x/net/netutil" @@ -76,51 +77,70 @@ func (m *MConnTransport) Protocols() []Protocol { return []Protocol{MConnProtocol, TCPProtocol} } +// Endpoints implements Transport. +func (m *MConnTransport) Endpoints() []Endpoint { + if m.listener == nil { + return []Endpoint{} + } + select { + case <-m.closeCh: + return []Endpoint{} + default: + } + endpoint := Endpoint{ + Protocol: MConnProtocol, + } + if addr, ok := m.listener.Addr().(*net.TCPAddr); ok { + endpoint.IP = addr.IP + endpoint.Port = uint16(addr.Port) + } + return []Endpoint{endpoint} +} + // Listen asynchronously listens for inbound connections on the given endpoint. // It must be called exactly once before calling Accept(), and the caller must // call Close() to shut down the listener. +// +// FIXME: Listen currently only supports listening on a single endpoint, it +// might be useful to support listening on multiple addresses (e.g. IPv4 and +// IPv6, or a private and public address) via multiple Listen() calls. 
func (m *MConnTransport) Listen(endpoint Endpoint) error { if m.listener != nil { return errors.New("transport is already listening") } - endpoint, err := m.normalizeEndpoint(endpoint) - if err != nil { - return fmt.Errorf("invalid MConn listen endpoint %q: %w", endpoint, err) + if err := m.validateEndpoint(endpoint); err != nil { + return err } - m.listener, err = net.Listen("tcp", fmt.Sprintf("%v:%v", endpoint.IP, endpoint.Port)) + listener, err := net.Listen("tcp", net.JoinHostPort( + endpoint.IP.String(), strconv.Itoa(int(endpoint.Port)))) if err != nil { return err } if m.options.MaxAcceptedConnections > 0 { - m.listener = netutil.LimitListener(m.listener, int(m.options.MaxAcceptedConnections)) + // FIXME: This will establish the inbound connection but simply hang it + // until another connection is released. It would probably be better to + // return an error to the remote peer or close the connection. This is + // also a DoS vector since the connection will take up kernel resources. + // This was just carried over from the legacy P2P stack. + listener = netutil.LimitListener(listener, int(m.options.MaxAcceptedConnections)) } + m.listener = listener + return nil } // Accept implements Transport. -func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) { +func (m *MConnTransport) Accept() (Connection, error) { if m.listener == nil { return nil, errors.New("transport is not listening") } - if deadline, ok := ctx.Deadline(); ok { - if tcpListener, ok := m.listener.(*net.TCPListener); ok { - // FIXME: This probably needs to have a goroutine that overrides the - // deadline on context cancellation as well. 
- if err := tcpListener.SetDeadline(deadline); err != nil { - return nil, err - } - } - } - tcpConn, err := m.listener.Accept() if err != nil { select { case <-m.closeCh: return nil, io.EOF - case <-ctx.Done(): - return nil, ctx.Err() default: return nil, err } @@ -131,36 +151,28 @@ func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) { // Dial implements Transport. func (m *MConnTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) { - endpoint, err := m.normalizeEndpoint(endpoint) - if err != nil { + if err := m.validateEndpoint(endpoint); err != nil { return nil, err } + if endpoint.Port == 0 { + endpoint.Port = 26657 + } dialer := net.Dialer{} - tcpConn, err := dialer.DialContext(ctx, "tcp", - net.JoinHostPort(endpoint.IP.String(), fmt.Sprintf("%v", endpoint.Port))) + tcpConn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort( + endpoint.IP.String(), strconv.Itoa(int(endpoint.Port)))) if err != nil { - return nil, err + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + return nil, err + } } return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil } -// Endpoints implements Transport. -func (m *MConnTransport) Endpoints() []Endpoint { - if m.listener == nil { - return []Endpoint{} - } - endpoint := Endpoint{ - Protocol: MConnProtocol, - } - if addr, ok := m.listener.Addr().(*net.TCPAddr); ok { - endpoint.IP = addr.IP - endpoint.Port = uint16(addr.Port) - } - return []Endpoint{endpoint} -} - // Close implements Transport. func (m *MConnTransport) Close() error { var err error @@ -173,24 +185,21 @@ func (m *MConnTransport) Close() error { return err } -// normalizeEndpoint normalizes and validates an endpoint. -func (m *MConnTransport) normalizeEndpoint(endpoint Endpoint) (Endpoint, error) { +// validateEndpoint validates an endpoint. 
+func (m *MConnTransport) validateEndpoint(endpoint Endpoint) error { if err := endpoint.Validate(); err != nil { - return Endpoint{}, err + return err } if endpoint.Protocol != MConnProtocol && endpoint.Protocol != TCPProtocol { - return Endpoint{}, fmt.Errorf("unsupported protocol %q", endpoint.Protocol) + return fmt.Errorf("unsupported protocol %q", endpoint.Protocol) } if len(endpoint.IP) == 0 { - return Endpoint{}, errors.New("endpoint must have an IP address") + return errors.New("endpoint has no IP address") } if endpoint.Path != "" { - return Endpoint{}, fmt.Errorf("endpoint cannot have path (got %q)", endpoint.Path) + return fmt.Errorf("endpoints with path not supported (got %q)", endpoint.Path) } - if endpoint.Port == 0 { - endpoint.Port = 26657 - } - return endpoint, nil + return nil } // mConnConnection implements Connection for MConnTransport. @@ -209,7 +218,7 @@ type mConnConnection struct { // mConnMessage passes MConnection messages through internal channels. type mConnMessage struct { - channelID byte + channelID ChannelID payload []byte } @@ -226,52 +235,72 @@ func newMConnConnection( mConnConfig: mConnConfig, channelDescs: channelDescs, receiveCh: make(chan mConnMessage), - errorCh: make(chan error), + errorCh: make(chan error, 1), // buffered to avoid onError leak closeCh: make(chan struct{}), } } // Handshake implements Connection. -// -// FIXME: Since the MConnection code panics, we need to recover it and turn it -// into an error. We should remove panics instead. 
func (c *mConnConnection) Handshake( ctx context.Context, nodeInfo NodeInfo, privKey crypto.PrivKey, -) (peerInfo NodeInfo, peerKey crypto.PubKey, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("recovered from panic: %v", r) - } +) (NodeInfo, crypto.PubKey, error) { + var ( + mconn *conn.MConnection + peerInfo NodeInfo + peerKey crypto.PubKey + errCh = make(chan error, 1) + ) + // To handle context cancellation, we need to do the handshake in a + // goroutine and abort the blocking network calls by closing the connection + // when the context is cancelled. + go func() { + // FIXME: Since the MConnection code panics, we need to recover it and turn it + // into an error. We should remove panics instead. + defer func() { + if r := recover(); r != nil { + errCh <- fmt.Errorf("recovered from panic: %v", r) + } + }() + var err error + mconn, peerInfo, peerKey, err = c.handshake(ctx, nodeInfo, privKey) + errCh <- err }() - peerInfo, peerKey, err = c.handshake(ctx, nodeInfo, privKey) - return + select { + case <-ctx.Done(): + _ = c.Close() + return NodeInfo{}, nil, ctx.Err() + + case err := <-errCh: + if err != nil { + return NodeInfo{}, nil, err + } + c.mconn = mconn + c.logger = mconn.Logger + if err = c.mconn.Start(); err != nil { + return NodeInfo{}, nil, err + } + return peerInfo, peerKey, nil + } } // handshake is a helper for Handshake, simplifying error handling so we can -// keep panic recovery in Handshake. It sets c.mconn. -// -// FIXME: Move this into Handshake() when MConnection no longer panics. +// keep context handling and panic recovery in Handshake. It returns an +// unstarted but handshaked MConnection, to avoid concurrent field writes. 
func (c *mConnConnection) handshake( ctx context.Context, nodeInfo NodeInfo, privKey crypto.PrivKey, -) (NodeInfo, crypto.PubKey, error) { +) (*conn.MConnection, NodeInfo, crypto.PubKey, error) { if c.mconn != nil { - return NodeInfo{}, nil, errors.New("connection is already handshaked") - } - - if deadline, ok := ctx.Deadline(); ok { - if err := c.conn.SetDeadline(deadline); err != nil { - return NodeInfo{}, nil, err - } + return nil, NodeInfo{}, nil, errors.New("connection is already handshaked") } secretConn, err := conn.MakeSecretConnection(c.conn, privKey) if err != nil { - return NodeInfo{}, nil, err + return nil, NodeInfo{}, nil, err } var pbPeerInfo p2pproto.NodeInfo @@ -286,20 +315,14 @@ func (c *mConnConnection) handshake( }() for i := 0; i < cap(errCh); i++ { if err = <-errCh; err != nil { - return NodeInfo{}, nil, err + return nil, NodeInfo{}, nil, err } } peerInfo, err := NodeInfoFromProto(&pbPeerInfo) if err != nil { - return NodeInfo{}, nil, err + return nil, NodeInfo{}, nil, err } - if err = c.conn.SetDeadline(time.Time{}); err != nil { - return NodeInfo{}, nil, err - } - - c.logger = c.logger.With("peer", c.RemoteEndpoint().PeerAddress(peerInfo.NodeID)) - mconn := conn.NewMConnectionWithConfig( secretConn, c.channelDescs, @@ -307,31 +330,29 @@ func (c *mConnConnection) handshake( c.onError, c.mConnConfig, ) - mconn.SetLogger(c.logger) - if err = mconn.Start(); err != nil { - return NodeInfo{}, nil, err - } - c.mconn = mconn + mconn.SetLogger(c.logger.With("peer", c.RemoteEndpoint().PeerAddress(peerInfo.NodeID))) - return peerInfo, secretConn.RemotePubKey(), nil + return mconn, peerInfo, secretConn.RemotePubKey(), nil } // onReceive is a callback for MConnection received messages. 
-func (c *mConnConnection) onReceive(channelID byte, payload []byte) { +func (c *mConnConnection) onReceive(chID byte, payload []byte) { select { - case c.receiveCh <- mConnMessage{channelID: channelID, payload: payload}: + case c.receiveCh <- mConnMessage{channelID: ChannelID(chID), payload: payload}: case <-c.closeCh: } } -// onError is a callback for MConnection errors. The error is passed to errorCh, -// which is only consumed by ReceiveMessage() for parity with the old -// MConnection behavior. +// onError is a callback for MConnection errors. The error is passed via errorCh +// to whichever of ReceiveMessage, SendMessage or TrySendMessage reads it first. func (c *mConnConnection) onError(e interface{}) { err, ok := e.(error) if !ok { err = fmt.Errorf("%v", err) } + // We have to close the connection here, since MConnection will have stopped + // the service on any errors. + _ = c.Close() select { case c.errorCh <- err: case <-c.closeCh: @@ -339,37 +360,42 @@ func (c *mConnConnection) onError(e interface{}) { } // String displays connection information. -// FIXME: This is here for backwards compatibility with existing logging, -// it should probably just return RemoteEndpoint().String(), if anything. func (c *mConnConnection) String() string { - endpoint := c.RemoteEndpoint() - return fmt.Sprintf("MConn{%v:%v}", endpoint.IP, endpoint.Port) + return c.RemoteEndpoint().String() } // SendMessage implements Connection. -func (c *mConnConnection) SendMessage(channelID byte, msg []byte) (bool, error) { - // We don't check errorCh here, to preserve old MConnection behavior. 
+func (c *mConnConnection) SendMessage(chID ChannelID, msg []byte) (bool, error) { + if chID > math.MaxUint8 { + return false, fmt.Errorf("MConnection only supports 1-byte channel IDs (got %v)", chID) + } select { + case err := <-c.errorCh: + return false, err case <-c.closeCh: return false, io.EOF default: - return c.mconn.Send(channelID, msg), nil + return c.mconn.Send(byte(chID), msg), nil } } // TrySendMessage implements Connection. -func (c *mConnConnection) TrySendMessage(channelID byte, msg []byte) (bool, error) { - // We don't check errorCh here, to preserve old MConnection behavior. +func (c *mConnConnection) TrySendMessage(chID ChannelID, msg []byte) (bool, error) { + if chID > math.MaxUint8 { + return false, fmt.Errorf("MConnection only supports 1-byte channel IDs (got %v)", chID) + } select { + case err := <-c.errorCh: + return false, err case <-c.closeCh: return false, io.EOF default: - return c.mconn.TrySend(channelID, msg), nil + return c.mconn.TrySend(byte(chID), msg), nil } } // ReceiveMessage implements Connection. 
-func (c *mConnConnection) ReceiveMessage() (byte, []byte, error) { +func (c *mConnConnection) ReceiveMessage() (ChannelID, []byte, error) { select { case err := <-c.errorCh: return 0, nil, err @@ -416,7 +442,7 @@ func (c *mConnConnection) Status() conn.ConnectionStatus { func (c *mConnConnection) Close() error { var err error c.closeOnce.Do(func() { - if c.mconn != nil { + if c.mconn != nil && c.mconn.IsRunning() { err = c.mconn.Stop() } else { err = c.conn.Close() @@ -430,7 +456,7 @@ func (c *mConnConnection) Close() error { func (c *mConnConnection) FlushClose() error { var err error c.closeOnce.Do(func() { - if c.mconn != nil { + if c.mconn != nil && c.mconn.IsRunning() { c.mconn.FlushStop() } else { err = c.conn.Close() diff --git a/p2p/transport_mconn_test.go b/p2p/transport_mconn_test.go new file mode 100644 index 000000000..207b68563 --- /dev/null +++ b/p2p/transport_mconn_test.go @@ -0,0 +1,208 @@ +package p2p_test + +import ( + "io" + "net" + "testing" + "time" + + "github.com/fortytw2/leaktest" + "github.com/stretchr/testify/require" + + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/p2p" + "github.com/tendermint/tendermint/p2p/conn" +) + +// Transports are mainly tested by common tests in transport_test.go, we +// register a transport factory here to get included in those tests. 
+func init() { + testTransports["mconn"] = func(t *testing.T) p2p.Transport { + transport := p2p.NewMConnTransport( + log.TestingLogger(), + conn.DefaultMConnConfig(), + []*p2p.ChannelDescriptor{{ID: byte(chID), Priority: 1}}, + p2p.MConnTransportOptions{}, + ) + err := transport.Listen(p2p.Endpoint{ + Protocol: p2p.MConnProtocol, + IP: net.IPv4(127, 0, 0, 1), + Port: 0, // assign a random port + }) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, transport.Close()) + }) + + return transport + } +} + +func TestMConnTransport_AcceptBeforeListen(t *testing.T) { + transport := p2p.NewMConnTransport( + log.TestingLogger(), + conn.DefaultMConnConfig(), + []*p2p.ChannelDescriptor{{ID: byte(chID), Priority: 1}}, + p2p.MConnTransportOptions{ + MaxAcceptedConnections: 2, + }, + ) + t.Cleanup(func() { + _ = transport.Close() + }) + + _, err := transport.Accept() + require.Error(t, err) + require.NotEqual(t, io.EOF, err) // io.EOF should be returned after Close() +} + +func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) { + transport := p2p.NewMConnTransport( + log.TestingLogger(), + conn.DefaultMConnConfig(), + []*p2p.ChannelDescriptor{{ID: byte(chID), Priority: 1}}, + p2p.MConnTransportOptions{ + MaxAcceptedConnections: 2, + }, + ) + t.Cleanup(func() { + _ = transport.Close() + }) + err := transport.Listen(p2p.Endpoint{ + Protocol: p2p.MConnProtocol, + IP: net.IPv4(127, 0, 0, 1), + }) + require.NoError(t, err) + require.NotEmpty(t, transport.Endpoints()) + endpoint := transport.Endpoints()[0] + + // Start a goroutine to just accept any connections. + acceptCh := make(chan p2p.Connection, 10) + go func() { + for { + conn, err := transport.Accept() + if err != nil { + return + } + acceptCh <- conn + } + }() + + // The first two connections should be accepted just fine. 
+ dial1, err := transport.Dial(ctx, endpoint) + require.NoError(t, err) + defer dial1.Close() + accept1 := <-acceptCh + defer accept1.Close() + require.Equal(t, dial1.LocalEndpoint(), accept1.RemoteEndpoint()) + + dial2, err := transport.Dial(ctx, endpoint) + require.NoError(t, err) + defer dial2.Close() + accept2 := <-acceptCh + defer accept2.Close() + require.Equal(t, dial2.LocalEndpoint(), accept2.RemoteEndpoint()) + + // The third connection will be dialed successfully, but the accept should + // not go through. + dial3, err := transport.Dial(ctx, endpoint) + require.NoError(t, err) + defer dial3.Close() + select { + case <-acceptCh: + require.Fail(t, "unexpected accept") + case <-time.After(time.Second): + } + + // However, once either of the other connections are closed, the accept + // goes through. + require.NoError(t, accept1.Close()) + accept3 := <-acceptCh + defer accept3.Close() + require.Equal(t, dial3.LocalEndpoint(), accept3.RemoteEndpoint()) +} + +func TestMConnTransport_Listen(t *testing.T) { + testcases := []struct { + endpoint p2p.Endpoint + ok bool + }{ + // Valid v4 and v6 addresses, with mconn and tcp protocols. + {p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4zero}, true}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4(127, 0, 0, 1)}, true}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv6zero}, true}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv6loopback}, true}, + {p2p.Endpoint{Protocol: p2p.TCPProtocol, IP: net.IPv4zero}, true}, + + // Invalid endpoints. 
+ {p2p.Endpoint{}, false}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, Path: "foo"}, false}, + {p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4zero, Path: "foo"}, false}, + } + for _, tc := range testcases { + tc := tc + t.Run(tc.endpoint.String(), func(t *testing.T) { + t.Cleanup(leaktest.Check(t)) + + transport := p2p.NewMConnTransport( + log.TestingLogger(), + conn.DefaultMConnConfig(), + []*p2p.ChannelDescriptor{{ID: byte(chID), Priority: 1}}, + p2p.MConnTransportOptions{}, + ) + t.Cleanup(func() { + _ = transport.Close() + }) + + // Transport should not listen on any endpoints yet. + require.Empty(t, transport.Endpoints()) + + // Start listening, and check any expected errors. + err := transport.Listen(tc.endpoint) + if !tc.ok { + require.Error(t, err) + return + } + require.NoError(t, err) + + // Start a goroutine to just accept any connections. + go func() { + for { + conn, err := transport.Accept() + if err != nil { + return + } + defer func() { + _ = conn.Close() + }() + } + }() + + // Check the endpoint. + endpoints := transport.Endpoints() + require.Len(t, endpoints, 1) + endpoint := endpoints[0] + + require.Equal(t, p2p.MConnProtocol, endpoint.Protocol) + if tc.endpoint.IP.IsUnspecified() { + require.True(t, endpoint.IP.IsUnspecified(), + "expected unspecified IP, got %v", endpoint.IP) + } else { + require.True(t, tc.endpoint.IP.Equal(endpoint.IP), + "expected %v, got %v", tc.endpoint.IP, endpoint.IP) + } + require.NotZero(t, endpoint.Port) + require.Empty(t, endpoint.Path) + + // Dialing the endpoint should work. + conn, err := transport.Dial(ctx, endpoint) + require.NoError(t, err) + require.NoError(t, conn.Close()) + + // Trying to listen again should error. 
+ err = transport.Listen(tc.endpoint) + require.Error(t, err) + }) + } +} diff --git a/p2p/transport_memory.go b/p2p/transport_memory.go index 7faa46860..5dd30f257 100644 --- a/p2p/transport_memory.go +++ b/p2p/transport_memory.go @@ -15,11 +15,16 @@ import ( const ( MemoryProtocol Protocol = "memory" + + // bufferSize is the channel buffer size of MemoryConnection. + bufferSize = 1 ) -// MemoryNetwork is an in-memory "network" that uses Go channels to communicate -// between endpoints. Transport endpoints are created with CreateTransport. It -// is primarily used for testing. +// MemoryNetwork is an in-memory "network" that uses buffered Go channels to +// communicate between endpoints. It is primarily meant for testing. +// +// Network endpoints are allocated via CreateTransport(), which takes a node ID, +// and the endpoint is then immediately accessible via the URL "memory:". type MemoryNetwork struct { logger log.Logger @@ -35,19 +40,19 @@ func NewMemoryNetwork(logger log.Logger) *MemoryNetwork { } } -// CreateTransport creates a new memory transport and endpoint with the given -// node ID. It immediately begins listening on the endpoint "memory:", and -// can be accessed by other transports in the same memory network. -func (n *MemoryNetwork) CreateTransport(nodeID NodeID) (*MemoryTransport, error) { +// CreateTransport creates a new memory transport endpoint with the given node +// ID and immediately begins listening on the address "memory:". It panics +// if the node ID is already in use (which is fine, since this is for tests). 
+func (n *MemoryNetwork) CreateTransport(nodeID NodeID) *MemoryTransport { t := newMemoryTransport(n, nodeID) n.mtx.Lock() defer n.mtx.Unlock() if _, ok := n.transports[nodeID]; ok { - return nil, fmt.Errorf("transport with node ID %q already exists", nodeID) + panic(fmt.Sprintf("memory transport with node ID %q already exists", nodeID)) } n.transports[nodeID] = t - return t, nil + return t } // GetTransport looks up a transport in the network, returning nil if not found. @@ -58,7 +63,7 @@ func (n *MemoryNetwork) GetTransport(id NodeID) *MemoryTransport { } // RemoveTransport removes a transport from the network and closes it. -func (n *MemoryNetwork) RemoveTransport(id NodeID) error { +func (n *MemoryNetwork) RemoveTransport(id NodeID) { n.mtx.Lock() t, ok := n.transports[id] delete(n.transports, id) @@ -67,39 +72,46 @@ func (n *MemoryNetwork) RemoveTransport(id NodeID) error { if ok { // Close may recursively call RemoveTransport() again, but this is safe // because we've already removed the transport from the map above. - return t.Close() + if err := t.Close(); err != nil { + n.logger.Error("failed to close memory transport", "id", id, "err", err) + } } - return nil } -// MemoryTransport is an in-memory transport that's primarily meant for testing. -// It communicates between endpoints using Go channels. To dial a different -// endpoint, both endpoints/transports must be in the same MemoryNetwork. +// Size returns the number of transports in the network. +func (n *MemoryNetwork) Size() int { + return len(n.transports) +} + +// MemoryTransport is an in-memory transport that uses buffered Go channels to +// communicate between endpoints. It is primarily meant for testing. +// +// New transports are allocated with MemoryNetwork.CreateTransport(). To contact +// a different endpoint, both transports must be in the same MemoryNetwork. 
type MemoryTransport struct { + logger log.Logger network *MemoryNetwork nodeID NodeID - logger log.Logger acceptCh chan *MemoryConnection closeCh chan struct{} closeOnce sync.Once } -// newMemoryTransport creates a new in-memory transport in the given network. -// Callers should use MemoryNetwork.CreateTransport() or GenerateTransport() -// to create transports, this is for internal use by MemoryNetwork. +// newMemoryTransport creates a new MemoryTransport. This is for internal use by +// MemoryNetwork, use MemoryNetwork.CreateTransport() instead. func newMemoryTransport(network *MemoryNetwork, nodeID NodeID) *MemoryTransport { return &MemoryTransport{ + logger: network.logger.With("local", nodeID), network: network, nodeID: nodeID, - logger: network.logger.With("local", fmt.Sprintf("%v:%v", MemoryProtocol, nodeID)), acceptCh: make(chan *MemoryConnection), closeCh: make(chan struct{}), } } -// String displays the transport. +// String implements Transport. func (t *MemoryTransport) String() string { return string(MemoryProtocol) } @@ -109,95 +121,6 @@ func (t *MemoryTransport) Protocols() []Protocol { return []Protocol{MemoryProtocol} } -// Accept implements Transport. -func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) { - select { - case conn := <-t.acceptCh: - t.logger.Info("accepted connection from peer", "remote", conn.RemoteEndpoint()) - return conn, nil - case <-t.closeCh: - return nil, io.EOF - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -// Dial implements Transport. 
-func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) { - if endpoint.Protocol != MemoryProtocol { - return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol) - } - if endpoint.Path == "" { - return nil, errors.New("no path") - } - nodeID, err := NewNodeID(endpoint.Path) - if err != nil { - return nil, err - } - t.logger.Info("dialing peer", "remote", endpoint) - - peerTransport := t.network.GetTransport(nodeID) - if peerTransport == nil { - return nil, fmt.Errorf("unknown peer %q", nodeID) - } - inCh := make(chan memoryMessage, 1) - outCh := make(chan memoryMessage, 1) - closer := tmsync.NewCloser() - - outConn := newMemoryConnection(t, peerTransport, inCh, outCh, closer) - inConn := newMemoryConnection(peerTransport, t, outCh, inCh, closer) - - select { - case peerTransport.acceptCh <- inConn: - return outConn, nil - case <-peerTransport.closeCh: - return nil, ErrTransportClosed{} - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -// DialAccept is a convenience function that dials a peer MemoryTransport and -// returns both ends of the connection (A to B and B to A). -func (t *MemoryTransport) DialAccept( - ctx context.Context, - peer *MemoryTransport, -) (Connection, Connection, error) { - endpoints := peer.Endpoints() - if len(endpoints) == 0 { - return nil, nil, fmt.Errorf("peer %q not listening on any endpoints", peer.nodeID) - } - - acceptCh := make(chan Connection, 1) - errCh := make(chan error, 1) - go func() { - conn, err := peer.Accept(ctx) - errCh <- err - acceptCh <- conn - }() - - outConn, err := t.Dial(ctx, endpoints[0]) - if err != nil { - return nil, nil, err - } - if err = <-errCh; err != nil { - return nil, nil, err - } - inConn := <-acceptCh - - return outConn, inConn, nil -} - -// Close implements Transport. 
-func (t *MemoryTransport) Close() error { - err := t.network.RemoveTransport(t.nodeID) - t.closeOnce.Do(func() { - close(t.closeCh) - }) - t.logger.Info("stopped accepting connections") - return err -} - // Endpoints implements Transport. func (t *MemoryTransport) Endpoints() []Endpoint { select { @@ -211,47 +134,127 @@ func (t *MemoryTransport) Endpoints() []Endpoint { } } -// MemoryConnection is an in-memory connection between two transports (nodes). +// Accept implements Transport. +func (t *MemoryTransport) Accept() (Connection, error) { + select { + case conn := <-t.acceptCh: + t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path) + return conn, nil + case <-t.closeCh: + return nil, io.EOF + } +} + +// Dial implements Transport. +func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) { + if endpoint.Protocol != MemoryProtocol { + return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol) + } + if endpoint.Path == "" { + return nil, errors.New("no path") + } + nodeID, err := NewNodeID(endpoint.Path) + if err != nil { + return nil, err + } + + t.logger.Info("dialing peer", "remote", nodeID) + peer := t.network.GetTransport(nodeID) + if peer == nil { + return nil, fmt.Errorf("unknown peer %q", nodeID) + } + + inCh := make(chan memoryMessage, bufferSize) + outCh := make(chan memoryMessage, bufferSize) + closer := tmsync.NewCloser() + + outConn := newMemoryConnection(t.logger, t.nodeID, peer.nodeID, inCh, outCh, closer) + inConn := newMemoryConnection(peer.logger, peer.nodeID, t.nodeID, outCh, inCh, closer) + + select { + case peer.acceptCh <- inConn: + return outConn, nil + case <-peer.closeCh: + return nil, io.EOF + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Close implements Transport. 
+func (t *MemoryTransport) Close() error { + t.network.RemoveTransport(t.nodeID) + t.closeOnce.Do(func() { + close(t.closeCh) + t.logger.Info("closed transport") + }) + return nil +} + +// MemoryConnection is an in-memory connection between two transport endpoints. type MemoryConnection struct { - logger log.Logger - local *MemoryTransport - remote *MemoryTransport + logger log.Logger + localID NodeID + remoteID NodeID receiveCh <-chan memoryMessage sendCh chan<- memoryMessage closer *tmsync.Closer } -// memoryMessage is used to pass messages internally in the connection. -// For handshakes, nodeInfo and pubKey are set instead of channel and message. +// memoryMessage is passed internally, containing either a message or handshake. type memoryMessage struct { - channel byte - message []byte + channelID ChannelID + message []byte // For handshakes. - nodeInfo NodeInfo + nodeInfo *NodeInfo pubKey crypto.PubKey } -// newMemoryConnection creates a new MemoryConnection. It takes all channels -// (including the closeCh signal channel) on construction, such that they can be -// shared between both ends of the connection. +// newMemoryConnection creates a new MemoryConnection. func newMemoryConnection( - local *MemoryTransport, - remote *MemoryTransport, + logger log.Logger, + localID NodeID, + remoteID NodeID, receiveCh <-chan memoryMessage, sendCh chan<- memoryMessage, closer *tmsync.Closer, ) *MemoryConnection { - c := &MemoryConnection{ - local: local, - remote: remote, + return &MemoryConnection{ + logger: logger.With("remote", remoteID), + localID: localID, + remoteID: remoteID, receiveCh: receiveCh, sendCh: sendCh, closer: closer, } - c.logger = c.local.logger.With("remote", c.RemoteEndpoint()) - return c +} + +// String implements Connection. +func (c *MemoryConnection) String() string { + return c.RemoteEndpoint().String() +} + +// LocalEndpoint implements Connection. 
+func (c *MemoryConnection) LocalEndpoint() Endpoint { + return Endpoint{ + Protocol: MemoryProtocol, + Path: string(c.localID), + } +} + +// RemoteEndpoint implements Connection. +func (c *MemoryConnection) RemoteEndpoint() Endpoint { + return Endpoint{ + Protocol: MemoryProtocol, + Path: string(c.remoteID), + } +} + +// Status implements Connection. +func (c *MemoryConnection) Status() conn.ConnectionStatus { + return conn.ConnectionStatus{} } // Handshake implements Connection. @@ -261,27 +264,32 @@ func (c *MemoryConnection) Handshake( privKey crypto.PrivKey, ) (NodeInfo, crypto.PubKey, error) { select { - case c.sendCh <- memoryMessage{nodeInfo: nodeInfo, pubKey: privKey.PubKey()}: - case <-ctx.Done(): - return NodeInfo{}, nil, ctx.Err() + case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}: + c.logger.Debug("sent handshake", "nodeInfo", nodeInfo) case <-c.closer.Done(): return NodeInfo{}, nil, io.EOF + case <-ctx.Done(): + return NodeInfo{}, nil, ctx.Err() } select { case msg := <-c.receiveCh: - c.logger.Debug("handshake complete") - return msg.nodeInfo, msg.pubKey, nil - case <-ctx.Done(): - return NodeInfo{}, nil, ctx.Err() + if msg.nodeInfo == nil { + return NodeInfo{}, nil, errors.New("no NodeInfo in handshake") + } + c.logger.Debug("received handshake", "peerInfo", msg.nodeInfo) + return *msg.nodeInfo, msg.pubKey, nil case <-c.closer.Done(): return NodeInfo{}, nil, io.EOF + case <-ctx.Done(): + return NodeInfo{}, nil, ctx.Err() } } // ReceiveMessage implements Connection. -func (c *MemoryConnection) ReceiveMessage() (chID byte, msg []byte, err error) { - // check close first, since channels are buffered +func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) { + // Check close first, since channels are buffered. Otherwise, below select + // may non-deterministically return non-error even when closed. 
select { case <-c.closer.Done(): return 0, nil, io.EOF @@ -290,16 +298,17 @@ func (c *MemoryConnection) ReceiveMessage() (chID byte, msg []byte, err error) { select { case msg := <-c.receiveCh: - c.logger.Debug("received message", "channel", msg.channel, "message", msg.message) - return msg.channel, msg.message, nil + c.logger.Debug("received message", "chID", msg.channelID, "msg", msg.message) + return msg.channelID, msg.message, nil case <-c.closer.Done(): return 0, nil, io.EOF } } // SendMessage implements Connection. -func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) { - // check close first, since channels are buffered +func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) (bool, error) { + // Check close first, since channels are buffered. Otherwise, below select + // may non-deterministically return non-error even when closed. select { case <-c.closer.Done(): return false, io.EOF @@ -307,8 +316,8 @@ func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) { } select { - case c.sendCh <- memoryMessage{channel: chID, message: msg}: - c.logger.Debug("sent message", "channel", chID, "message", msg) + case c.sendCh <- memoryMessage{channelID: chID, message: msg}: + c.logger.Debug("sent message", "chID", chID, "msg", msg) return true, nil case <-c.closer.Done(): return false, io.EOF @@ -316,8 +325,9 @@ func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) { } // TrySendMessage implements Connection. -func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) { - // check close first, since channels are buffered +func (c *MemoryConnection) TrySendMessage(chID ChannelID, msg []byte) (bool, error) { + // Check close first, since channels are buffered. Otherwise, below select + // may non-deterministically return non-error even when closed. 
select { case <-c.closer.Done(): return false, io.EOF @@ -325,8 +335,8 @@ func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) { } select { - case c.sendCh <- memoryMessage{channel: chID, message: msg}: - c.logger.Debug("sent message", "channel", chID, "message", msg) + case c.sendCh <- memoryMessage{channelID: chID, message: msg}: + c.logger.Debug("sent message", "chID", chID, "msg", msg) return true, nil case <-c.closer.Done(): return false, io.EOF @@ -335,35 +345,19 @@ func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) { } } -// Close closes the connection. +// Close implements Connection. func (c *MemoryConnection) Close() error { - c.closer.Close() - c.logger.Info("closed connection") + select { + case <-c.closer.Done(): + return nil + default: + c.closer.Close() + c.logger.Info("closed connection") + } return nil } -// FlushClose flushes all pending sends and then closes the connection. +// FlushClose implements Connection. func (c *MemoryConnection) FlushClose() error { return c.Close() } - -// LocalEndpoint returns the local endpoint for the connection. -func (c *MemoryConnection) LocalEndpoint() Endpoint { - return Endpoint{ - Protocol: MemoryProtocol, - Path: string(c.local.nodeID), - } -} - -// RemoteEndpoint returns the remote endpoint for the connection. -func (c *MemoryConnection) RemoteEndpoint() Endpoint { - return Endpoint{ - Protocol: MemoryProtocol, - Path: string(c.remote.nodeID), - } -} - -// Status returns the current connection status. 
-func (c *MemoryConnection) Status() conn.ConnectionStatus { - return conn.ConnectionStatus{} -} diff --git a/p2p/transport_memory_test.go b/p2p/transport_memory_test.go index 6b79005e8..31f896d19 100644 --- a/p2p/transport_memory_test.go +++ b/p2p/transport_memory_test.go @@ -1,8 +1,8 @@ package p2p_test import ( - "context" - "io" + "bytes" + "encoding/hex" "testing" "github.com/stretchr/testify/require" @@ -10,114 +10,25 @@ import ( "github.com/tendermint/tendermint/p2p" ) -func TestMemoryTransport(t *testing.T) { - ctx := context.Background() - network := p2p.NewMemoryNetwork(log.TestingLogger()) - a, err := network.CreateTransport("0a") - require.NoError(t, err) - b, err := network.CreateTransport("0b") - require.NoError(t, err) - c, err := network.CreateTransport("0c") - require.NoError(t, err) +// Transports are mainly tested by common tests in transport_test.go, we +// register a transport factory here to get included in those tests. +func init() { + var network *p2p.MemoryNetwork // shared by transports in the same test - // Dialing a missing endpoint should fail. - _, err = a.Dial(ctx, p2p.Endpoint{ - Protocol: p2p.MemoryProtocol, - Path: "foo", - }) - require.Error(t, err) + testTransports["memory"] = func(t *testing.T) p2p.Transport { + if network == nil { + network = p2p.NewMemoryNetwork(log.TestingLogger()) + } + i := byte(network.Size()) + nodeID, err := p2p.NewNodeID(hex.EncodeToString(bytes.Repeat([]byte{i<<4 + i}, 20))) + require.NoError(t, err) + transport := network.CreateTransport(nodeID) - // Dialing and accepting a→b and a→c should work. 
- aToB, bToA, err := a.DialAccept(ctx, b) - require.NoError(t, err) - defer aToB.Close() - defer bToA.Close() + t.Cleanup(func() { + require.NoError(t, transport.Close()) + network = nil // set up a new memory network for the next test + }) - aToC, cToA, err := a.DialAccept(ctx, c) - require.NoError(t, err) - defer aToC.Close() - defer cToA.Close() - - // Send and receive a message both ways a→b and b→a - sent, err := aToB.SendMessage(1, []byte{0x01}) - require.NoError(t, err) - require.True(t, sent) - - ch, msg, err := bToA.ReceiveMessage() - require.NoError(t, err) - require.EqualValues(t, 1, ch) - require.EqualValues(t, []byte{0x01}, msg) - - sent, err = bToA.SendMessage(1, []byte{0x02}) - require.NoError(t, err) - require.True(t, sent) - - ch, msg, err = aToB.ReceiveMessage() - require.NoError(t, err) - require.EqualValues(t, 1, ch) - require.EqualValues(t, []byte{0x02}, msg) - - // Send and receive a message both ways a→c and c→a - sent, err = aToC.SendMessage(1, []byte{0x03}) - require.NoError(t, err) - require.True(t, sent) - - ch, msg, err = cToA.ReceiveMessage() - require.NoError(t, err) - require.EqualValues(t, 1, ch) - require.EqualValues(t, []byte{0x03}, msg) - - sent, err = cToA.SendMessage(1, []byte{0x04}) - require.NoError(t, err) - require.True(t, sent) - - ch, msg, err = aToC.ReceiveMessage() - require.NoError(t, err) - require.EqualValues(t, 1, ch) - require.EqualValues(t, []byte{0x04}, msg) - - // If we close aToB, sending and receiving on either end will error. - err = aToB.Close() - require.NoError(t, err) - - _, err = aToB.SendMessage(1, []byte{0x05}) - require.Equal(t, io.EOF, err) - - _, _, err = aToB.ReceiveMessage() - require.Equal(t, io.EOF, err) - - _, err = bToA.SendMessage(1, []byte{0x06}) - require.Equal(t, io.EOF, err) - - _, _, err = bToA.ReceiveMessage() - require.Equal(t, io.EOF, err) - - // We can still send aToC. 
- sent, err = aToC.SendMessage(1, []byte{0x07}) - require.NoError(t, err) - require.True(t, sent) - - ch, msg, err = cToA.ReceiveMessage() - require.NoError(t, err) - require.EqualValues(t, 1, ch) - require.EqualValues(t, []byte{0x07}, msg) - - // If we close the c transport, it will no longer accept connections, - // but we can still use the open connection. - endpoint := c.Endpoints()[0] - err = c.Close() - require.NoError(t, err) - require.Empty(t, c.Endpoints()) - - _, err = a.Dial(ctx, endpoint) - require.Error(t, err) - - sent, err = aToC.SendMessage(1, []byte{0x08}) - require.NoError(t, err) - require.True(t, sent) - - ch, msg, err = cToA.ReceiveMessage() - require.NoError(t, err) - require.EqualValues(t, 1, ch) - require.EqualValues(t, []byte{0x08}, msg) + return transport + } } diff --git a/p2p/transport_test.go b/p2p/transport_test.go new file mode 100644 index 000000000..1b30d15a0 --- /dev/null +++ b/p2p/transport_test.go @@ -0,0 +1,637 @@ +package p2p_test + +import ( + "context" + "io" + "net" + "testing" + "time" + + "github.com/fortytw2/leaktest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/tendermint/tendermint/crypto/ed25519" + "github.com/tendermint/tendermint/libs/bytes" + "github.com/tendermint/tendermint/p2p" +) + +// transportFactory is used to set up transports for tests. +type transportFactory func(t *testing.T) p2p.Transport + +var ( + ctx = context.Background() // convenience context + chID = p2p.ChannelID(1) // channel ID for use in tests + testTransports = map[string]transportFactory{} // registry for withTransports +) + +// withTransports is a test helper that runs a test against all transports +// registered in testTransports. 
+func withTransports(t *testing.T, tester func(*testing.T, transportFactory)) { + t.Helper() + for name, transportFactory := range testTransports { + transportFactory := transportFactory + t.Run(name, func(t *testing.T) { + t.Cleanup(leaktest.Check(t)) + tester(t, transportFactory) + }) + } +} + +func TestTransport_AcceptClose(t *testing.T) { + // Just test accept unblock on close, happy path is tested widely elsewhere. + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + + // In-progress Accept should error on concurrent close. + errCh := make(chan error, 1) + go func() { + time.Sleep(200 * time.Millisecond) + errCh <- a.Close() + }() + + _, err := a.Accept() + require.Error(t, err) + require.Equal(t, io.EOF, err) + require.NoError(t, <-errCh) + + // Closed transport should return error immediately. + _, err = a.Accept() + require.Error(t, err) + require.Equal(t, io.EOF, err) + }) +} + +func TestTransport_DialEndpoints(t *testing.T) { + ipTestCases := []struct { + ip net.IP + ok bool + }{ + {net.IPv4zero, true}, + {net.IPv6zero, true}, + + {nil, false}, + {net.IPv4bcast, false}, + {net.IPv4allsys, false}, + {[]byte{1, 2, 3}, false}, + {[]byte{1, 2, 3, 4, 5}, false}, + } + + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + endpoints := a.Endpoints() + require.NotEmpty(t, endpoints) + endpoint := endpoints[0] + + // Spawn a goroutine to simply accept any connections until closed. + go func() { + for { + conn, err := a.Accept() + if err != nil { + return + } + _ = conn.Close() + } + }() + + // Dialing self should work. + conn, err := a.Dial(ctx, endpoint) + require.NoError(t, err) + require.NoError(t, conn.Close()) + + // Dialing empty endpoint should error. + _, err = a.Dial(ctx, p2p.Endpoint{}) + require.Error(t, err) + + // Dialing without protocol should error. 
+ noProtocol := endpoint + noProtocol.Protocol = "" + _, err = a.Dial(ctx, noProtocol) + require.Error(t, err) + + // Dialing with invalid protocol should error. + fooProtocol := endpoint + fooProtocol.Protocol = "foo" + _, err = a.Dial(ctx, fooProtocol) + require.Error(t, err) + + // Tests for networked endpoints (with IP). + if len(endpoint.IP) > 0 { + for _, tc := range ipTestCases { + tc := tc + t.Run(tc.ip.String(), func(t *testing.T) { + e := endpoint + e.IP = tc.ip + conn, err := a.Dial(ctx, e) + if tc.ok { + require.NoError(t, conn.Close()) + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } + + // Non-networked endpoints should error. + noIP := endpoint + noIP.IP = nil + noIP.Port = 0 + noIP.Path = "foo" + _, err := a.Dial(ctx, noIP) + require.Error(t, err) + + } else { + // Tests for non-networked endpoints (no IP). + noPath := endpoint + noPath.Path = "" + _, err = a.Dial(ctx, noPath) + require.Error(t, err) + } + }) +} + +func TestTransport_Dial(t *testing.T) { + // Most just tests dial failures, happy path is tested widely elsewhere. + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + b := makeTransport(t) + + require.NotEmpty(t, a.Endpoints()) + require.NotEmpty(t, b.Endpoints()) + aEndpoint := a.Endpoints()[0] + bEndpoint := b.Endpoints()[0] + + // Context cancellation should error. We can't test timeouts since we'd + // need a non-responsive endpoint. + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + _, err := a.Dial(cancelCtx, bEndpoint) + require.Error(t, err) + require.Equal(t, err, context.Canceled) + + // Unavailable endpoint should error. + err = b.Close() + require.NoError(t, err) + _, err = a.Dial(ctx, bEndpoint) + require.Error(t, err) + + // Dialing from a closed transport should still work. 
+ errCh := make(chan error, 1) + go func() { + conn, err := a.Accept() + if err == nil { + _ = conn.Close() + } + errCh <- err + }() + conn, err := b.Dial(ctx, aEndpoint) + require.NoError(t, err) + require.NoError(t, conn.Close()) + require.NoError(t, <-errCh) + }) +} + +func TestTransport_Endpoints(t *testing.T) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + b := makeTransport(t) + + // Both transports return valid and different endpoints. + aEndpoints := a.Endpoints() + bEndpoints := b.Endpoints() + require.NotEmpty(t, aEndpoints) + require.NotEmpty(t, bEndpoints) + require.NotEqual(t, aEndpoints, bEndpoints) + for _, endpoint := range append(aEndpoints, bEndpoints...) { + err := endpoint.Validate() + require.NoError(t, err, "invalid endpoint %q", endpoint) + } + + // When closed, the transport should no longer return any endpoints. + err := a.Close() + require.NoError(t, err) + require.Empty(t, a.Endpoints()) + require.NotEmpty(t, b.Endpoints()) + }) +} + +func TestTransport_Protocols(t *testing.T) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + protocols := a.Protocols() + endpoints := a.Endpoints() + require.NotEmpty(t, protocols) + require.NotEmpty(t, endpoints) + + for _, endpoint := range endpoints { + require.Contains(t, protocols, endpoint.Protocol) + } + }) +} + +func TestTransport_String(t *testing.T) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + require.NotEmpty(t, a.String()) + }) +} + +func TestConnection_Handshake(t *testing.T) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + b := makeTransport(t) + ab, ba := dialAccept(t, a, b) + + // A handshake should pass the given keys and NodeInfo. 
+ aKey := ed25519.GenPrivKey() + aInfo := p2p.NodeInfo{ + NodeID: p2p.NodeIDFromPubKey(aKey.PubKey()), + ProtocolVersion: p2p.NewProtocolVersion(1, 2, 3), + ListenAddr: "listenaddr", + Network: "network", + Version: "1.2.3", + Channels: bytes.HexBytes([]byte{0xf0, 0x0f}), + Moniker: "moniker", + Other: p2p.NodeInfoOther{ + TxIndex: "txindex", + RPCAddress: "rpc.domain.com", + }, + } + bKey := ed25519.GenPrivKey() + bInfo := p2p.NodeInfo{NodeID: p2p.NodeIDFromPubKey(bKey.PubKey())} + + errCh := make(chan error, 1) + go func() { + // Must use assert due to goroutine. + peerInfo, peerKey, err := ba.Handshake(ctx, bInfo, bKey) + if err == nil { + assert.Equal(t, aInfo, peerInfo) + assert.Equal(t, aKey.PubKey(), peerKey) + } + errCh <- err + }() + + peerInfo, peerKey, err := ab.Handshake(ctx, aInfo, aKey) + require.NoError(t, err) + require.Equal(t, bInfo, peerInfo) + require.Equal(t, bKey.PubKey(), peerKey) + + require.NoError(t, <-errCh) + }) +} + +func TestConnection_HandshakeCancel(t *testing.T) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + b := makeTransport(t) + + // Handshake should error on context cancellation. + ab, ba := dialAccept(t, a, b) + timeoutCtx, cancel := context.WithTimeout(ctx, 1*time.Minute) + cancel() + _, _, err := ab.Handshake(timeoutCtx, p2p.NodeInfo{}, ed25519.GenPrivKey()) + require.Error(t, err) + require.Equal(t, context.Canceled, err) + _ = ab.Close() + _ = ba.Close() + + // Handshake should error on context timeout. 
+ ab, ba = dialAccept(t, a, b) + timeoutCtx, cancel = context.WithTimeout(ctx, 200*time.Millisecond) + defer cancel() + _, _, err = ab.Handshake(timeoutCtx, p2p.NodeInfo{}, ed25519.GenPrivKey()) + require.Error(t, err) + require.Equal(t, context.DeadlineExceeded, err) + _ = ab.Close() + _ = ba.Close() + }) +} + +func TestConnection_FlushClose(t *testing.T) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + b := makeTransport(t) + ab, _ := dialAcceptHandshake(t, a, b) + + // FIXME: FlushClose should be removed (and replaced by separate Flush + // and Close calls if necessary). We can't reliably test it, so we just + // make sure it closes both ends and that it's idempotent. + err := ab.FlushClose() + require.NoError(t, err) + + _, _, err = ab.ReceiveMessage() + require.Error(t, err) + require.Equal(t, io.EOF, err) + + _, err = ab.SendMessage(chID, []byte("closed")) + require.Error(t, err) + require.Equal(t, io.EOF, err) + + err = ab.FlushClose() + require.NoError(t, err) + }) +} + +func TestConnection_LocalRemoteEndpoint(t *testing.T) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + b := makeTransport(t) + ab, ba := dialAcceptHandshake(t, a, b) + + // Local and remote connection endpoints correspond to each other. + require.NotEmpty(t, ab.LocalEndpoint()) + require.NotEmpty(t, ba.LocalEndpoint()) + require.Equal(t, ab.LocalEndpoint(), ba.RemoteEndpoint()) + require.Equal(t, ab.RemoteEndpoint(), ba.LocalEndpoint()) + }) +} + +func TestConnection_SendReceive(t *testing.T) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + b := makeTransport(t) + ab, ba := dialAcceptHandshake(t, a, b) + + // Can send and receive a to b. 
+ ok, err := ab.SendMessage(chID, []byte("foo")) + require.NoError(t, err) + require.True(t, ok) + + ch, msg, err := ba.ReceiveMessage() + require.NoError(t, err) + require.Equal(t, []byte("foo"), msg) + require.Equal(t, chID, ch) + + // Can send and receive b to a. + _, err = ba.SendMessage(chID, []byte("bar")) + require.NoError(t, err) + + _, msg, err = ab.ReceiveMessage() + require.NoError(t, err) + require.Equal(t, []byte("bar"), msg) + + // TrySendMessage also works. + ok, err = ba.TrySendMessage(chID, []byte("try")) + require.NoError(t, err) + require.True(t, ok) + + ch, msg, err = ab.ReceiveMessage() + require.NoError(t, err) + require.Equal(t, []byte("try"), msg) + require.Equal(t, chID, ch) + + // Connections should still be active after closing the transports. + err = a.Close() + require.NoError(t, err) + err = b.Close() + require.NoError(t, err) + + _, err = ab.SendMessage(chID, []byte("still here")) + require.NoError(t, err) + ch, msg, err = ba.ReceiveMessage() + require.NoError(t, err) + require.Equal(t, chID, ch) + require.Equal(t, []byte("still here"), msg) + + // Close one side of the connection. Both sides should then error + // with io.EOF when trying to send or receive. 
+ err = ba.Close() + require.NoError(t, err) + + _, _, err = ab.ReceiveMessage() + require.Error(t, err) + require.Equal(t, io.EOF, err) + _, err = ab.SendMessage(chID, []byte("closed")) + require.Error(t, err) + require.Equal(t, io.EOF, err) + + _, _, err = ba.ReceiveMessage() + require.Error(t, err) + require.Equal(t, io.EOF, err) + _, err = ba.SendMessage(chID, []byte("closed")) + require.Error(t, err) + require.Equal(t, io.EOF, err) + }) +} + +func TestConnection_Status(t *testing.T) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + b := makeTransport(t) + ab, _ := dialAcceptHandshake(t, a, b) + + // FIXME: This isn't implemented in all transports, so for now we just + // check that it doesn't panic, which isn't really much of a test. + ab.Status() + }) +} + +func TestConnection_String(t *testing.T) { + withTransports(t, func(t *testing.T, makeTransport transportFactory) { + a := makeTransport(t) + b := makeTransport(t) + ab, _ := dialAccept(t, a, b) + require.NotEmpty(t, ab.String()) + }) +} + +func TestEndpoint_PeerAddress(t *testing.T) { + var ( + ip4 = []byte{1, 2, 3, 4} + ip4in6 = net.IPv4(1, 2, 3, 4) + ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01} + ) + + testcases := []struct { + endpoint p2p.Endpoint + expect p2p.PeerAddress + }{ + // Valid endpoints. + { + p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"}, + p2p.PeerAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"}, + }, + { + p2p.Endpoint{Protocol: "tcp", IP: ip4in6, Port: 8080, Path: "path"}, + p2p.PeerAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"}, + }, + { + p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "path"}, + p2p.PeerAddress{Protocol: "tcp", Hostname: "b10c::1", Port: 8080, Path: "path"}, + }, + { + p2p.Endpoint{Protocol: "memory", Path: "foo"}, + p2p.PeerAddress{Protocol: "memory", Path: "foo"}, + }, + + // Partial (invalid) endpoints. 
+ {p2p.Endpoint{}, p2p.PeerAddress{}}, + {p2p.Endpoint{Protocol: "tcp"}, p2p.PeerAddress{Protocol: "tcp"}}, + {p2p.Endpoint{IP: net.IPv4(1, 2, 3, 4)}, p2p.PeerAddress{Hostname: "1.2.3.4"}}, + {p2p.Endpoint{Port: 8080}, p2p.PeerAddress{}}, + {p2p.Endpoint{Path: "path"}, p2p.PeerAddress{Path: "path"}}, + } + for _, tc := range testcases { + tc := tc + t.Run(tc.endpoint.String(), func(t *testing.T) { + // Without NodeID. + expect := tc.expect + require.Equal(t, expect, tc.endpoint.PeerAddress("")) + + // With NodeID. + expect.NodeID = p2p.NodeID("b10c") + require.Equal(t, expect, tc.endpoint.PeerAddress(expect.NodeID)) + }) + } +} + +func TestEndpoint_String(t *testing.T) { + var ( + ip4 = []byte{1, 2, 3, 4} + ip4in6 = net.IPv4(1, 2, 3, 4) + ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01} + ) + + testcases := []struct { + endpoint p2p.Endpoint + expect string + }{ + // Non-networked endpoints. + {p2p.Endpoint{Protocol: "memory", Path: "foo"}, "memory:foo"}, + {p2p.Endpoint{Protocol: "memory", Path: "👋"}, "memory:👋"}, + + // IPv4 endpoints. + {p2p.Endpoint{Protocol: "tcp", IP: ip4}, "tcp://1.2.3.4"}, + {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, "tcp://1.2.3.4"}, + {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080}, "tcp://1.2.3.4:8080"}, + {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "/path"}, "tcp://1.2.3.4:8080/path"}, + {p2p.Endpoint{Protocol: "tcp", IP: ip4, Path: "path/👋"}, "tcp://1.2.3.4/path/%F0%9F%91%8B"}, + + // IPv6 endpoints. + {p2p.Endpoint{Protocol: "tcp", IP: ip6}, "tcp://b10c::1"}, + {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080}, "tcp://[b10c::1]:8080"}, + {p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "/path"}, "tcp://[b10c::1]:8080/path"}, + {p2p.Endpoint{Protocol: "tcp", IP: ip6, Path: "path/👋"}, "tcp://b10c::1/path/%F0%9F%91%8B"}, + + // Partial (invalid) endpoints. 
+ {p2p.Endpoint{}, ""}, + {p2p.Endpoint{Protocol: "tcp"}, "tcp:"}, + {p2p.Endpoint{IP: []byte{1, 2, 3, 4}}, "1.2.3.4"}, + {p2p.Endpoint{IP: []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}}, "b10c::1"}, + {p2p.Endpoint{Port: 8080}, ""}, + {p2p.Endpoint{Path: "foo"}, "/foo"}, + } + for _, tc := range testcases { + tc := tc + t.Run(tc.expect, func(t *testing.T) { + require.Equal(t, tc.expect, tc.endpoint.String()) + }) + } +} + +func TestEndpoint_Validate(t *testing.T) { + var ( + ip4 = []byte{1, 2, 3, 4} + ip4in6 = net.IPv4(1, 2, 3, 4) + ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01} + ) + + testcases := []struct { + endpoint p2p.Endpoint + expectValid bool + }{ + // Valid endpoints. + {p2p.Endpoint{Protocol: "tcp", IP: ip4}, true}, + {p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, true}, + {p2p.Endpoint{Protocol: "tcp", IP: ip6}, true}, + {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8008}, true}, + {p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"}, true}, + {p2p.Endpoint{Protocol: "memory", Path: "path"}, true}, + + // Invalid endpoints. + {p2p.Endpoint{}, false}, + {p2p.Endpoint{IP: ip4}, false}, + {p2p.Endpoint{Protocol: "tcp"}, false}, + {p2p.Endpoint{Protocol: "tcp", IP: []byte{1, 2, 3}}, false}, + {p2p.Endpoint{Protocol: "tcp", Port: 8080, Path: "path"}, false}, + } + for _, tc := range testcases { + tc := tc + t.Run(tc.endpoint.String(), func(t *testing.T) { + err := tc.endpoint.Validate() + if tc.expectValid { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} + +// dialAccept is a helper that dials b from a and returns both sides of the +// connection. 
+func dialAccept(t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) { + t.Helper() + + endpoints := b.Endpoints() + require.NotEmpty(t, endpoints, "peer not listening on any endpoints") + + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + acceptCh := make(chan p2p.Connection, 1) + errCh := make(chan error, 1) + go func() { + conn, err := b.Accept() + errCh <- err + acceptCh <- conn + }() + + dialConn, err := a.Dial(ctx, endpoints[0]) + require.NoError(t, err) + + acceptConn := <-acceptCh + require.NoError(t, <-errCh) + + t.Cleanup(func() { + _ = dialConn.Close() + _ = acceptConn.Close() + }) + + return dialConn, acceptConn +} + +// dialAcceptHandshake is a helper that dials and handshakes b from a and +// returns both sides of the connection. +func dialAcceptHandshake(t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) { + t.Helper() + + ab, ba := dialAccept(t, a, b) + + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + errCh := make(chan error, 1) + go func() { + privKey := ed25519.GenPrivKey() + nodeInfo := p2p.NodeInfo{NodeID: p2p.NodeIDFromPubKey(privKey.PubKey())} + _, _, err := ba.Handshake(ctx, nodeInfo, privKey) + errCh <- err + }() + + privKey := ed25519.GenPrivKey() + nodeInfo := p2p.NodeInfo{NodeID: p2p.NodeIDFromPubKey(privKey.PubKey())} + _, _, err := ab.Handshake(ctx, nodeInfo, privKey) + require.NoError(t, err) + + timer := time.NewTimer(2 * time.Second) + defer timer.Stop() + select { + case err := <-errCh: + require.NoError(t, err) + case <-timer.C: + require.Fail(t, "handshake timed out") + } + + return ab, ba +}