diff --git a/node/node.go b/node/node.go index 70e725d42..5fb0664ea 100644 --- a/node/node.go +++ b/node/node.go @@ -471,6 +471,11 @@ func createTransport( } p2p.MultiplexTransportConnFilters(connFilters...)(transport) + + // Limit the number of incoming connections. + max := config.P2P.MaxNumInboundPeers + len(splitAndTrimEmpty(config.P2P.UnconditionalPeerIDs, ",", " ")) + p2p.MultiplexTransportMaxIncomingConnections(max)(transport) + return transport, peerFilters } diff --git a/p2p/transport.go b/p2p/transport.go index 89d63da7b..6b749c61f 100644 --- a/p2p/transport.go +++ b/p2p/transport.go @@ -7,6 +7,7 @@ import ( "time" "github.com/pkg/errors" + "golang.org/x/net/netutil" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/p2p/conn" @@ -122,11 +123,18 @@ func MultiplexTransportResolver(resolver IPResolver) MultiplexTransportOption { return func(mt *MultiplexTransport) { mt.resolver = resolver } } +// MultiplexTransportMaxIncomingConnections sets the maximum number of +// simultaneous connections (incoming). Default: 0 (unlimited) +func MultiplexTransportMaxIncomingConnections(n int) MultiplexTransportOption { + return func(mt *MultiplexTransport) { mt.maxIncomingConnections = n } +} + // MultiplexTransport accepts and dials tcp connections and upgrades them to // multiplexed peers. type MultiplexTransport struct { - netAddr NetAddress - listener net.Listener + netAddr NetAddress + listener net.Listener + maxIncomingConnections int // see MaxIncomingConnections acceptc chan accept closec chan struct{} @@ -240,6 +248,10 @@ func (mt *MultiplexTransport) Listen(addr NetAddress) error { return err } + if mt.maxIncomingConnections > 0 { + ln = netutil.LimitListener(ln, mt.maxIncomingConnections) + } + mt.netAddr = addr mt.listener = ln diff --git a/p2p/transport_test.go b/p2p/transport_test.go index dd0457fb5..2fc69ce05 100644 --- a/p2p/transport_test.go +++ b/p2p/transport_test.go @@ -5,6 +5,7 @@ import ( "math/rand" "net" "reflect" + "strings" "testing" "time" @@ -134,6 +135,50 @@ func TestTransportMultiplexConnFilterTimeout(t *testing.T) { } } +func TestTransportMultiplexMaxIncomingConnections(t *testing.T) { + mt := newMultiplexTransport( + emptyNodeInfo(), + NodeKey{ + PrivKey: ed25519.GenPrivKey(), + }, + ) + id := mt.nodeKey.ID() + + MultiplexTransportMaxIncomingConnections(0)(mt) + + addr, err := NewNetAddressString(IDAddressString(id, "127.0.0.1:0")) + if err != nil { + t.Fatal(err) + } + + if err := mt.Listen(*addr); err != nil { + t.Fatal(err) + } + + errc := make(chan error) + + go func() { + addr := NewNetAddress(id, mt.listener.Addr()) + + _, err := addr.Dial() + if err != nil { + errc <- err + return + } + + close(errc) + }() + + if err := <-errc; err != nil { + t.Errorf("connection failed: %v", err) + } + + _, err = mt.Accept(peerConfig{}) + if err == nil || !strings.Contains(err.Error(), "connection reset by peer") { + t.Errorf("expected connection reset by peer error, got %v", err) + } +} + func TestTransportMultiplexAcceptMultiple(t *testing.T) { mt := testSetupMultiplexTransport(t) laddr := NewNetAddress(mt.nodeKey.ID(), mt.listener.Addr())