From c62e320ffd71f3974ed1e8c1f32f5ed319825b79 Mon Sep 17 00:00:00 2001
From: Sam Kleinman
Date: Mon, 29 Mar 2021 17:07:05 -0400
Subject: [PATCH] p2p: rate-limit incoming connections by IP (#6286)

---
 p2p/conn_tracker.go      | 76 ++++++++++++++++++++++++++++++++++++++++
 p2p/conn_tracker_test.go | 73 ++++++++++++++++++++++++++++++++++++++
 p2p/router.go            | 56 ++++++++++++++++++++++--------
 p2p/router_test.go       |  3 ++
 4 files changed, 194 insertions(+), 14 deletions(-)
 create mode 100644 p2p/conn_tracker.go
 create mode 100644 p2p/conn_tracker_test.go

diff --git a/p2p/conn_tracker.go b/p2p/conn_tracker.go
new file mode 100644
index 000000000..09673c093
--- /dev/null
+++ b/p2p/conn_tracker.go
@@ -0,0 +1,76 @@
+package p2p
+
+import (
+	"fmt"
+	"net"
+	"sync"
+	"time"
+)
+
+type connectionTracker interface {
+	AddConn(net.IP) error
+	RemoveConn(net.IP)
+	Len() int
+}
+
+type connTrackerImpl struct {
+	cache       map[string]uint
+	lastConnect map[string]time.Time
+	mutex       sync.RWMutex
+	max         uint
+	window      time.Duration
+}
+
+func newConnTracker(max uint, window time.Duration) connectionTracker {
+	return &connTrackerImpl{
+		cache:       make(map[string]uint),
+		lastConnect: make(map[string]time.Time),
+		max:         max,
+		window:      window,
+	}
+}
+
+func (rat *connTrackerImpl) Len() int {
+	rat.mutex.RLock()
+	defer rat.mutex.RUnlock()
+	return len(rat.cache)
+}
+
+func (rat *connTrackerImpl) AddConn(addr net.IP) error {
+	address := addr.String()
+	rat.mutex.Lock()
+	defer rat.mutex.Unlock()
+
+	if num := rat.cache[address]; num >= rat.max {
+		return fmt.Errorf("%q has %d connections [max=%d]", address, num, rat.max)
+	} else if num == 0 {
+		// if this is the first active connection from this
+		// address, check whether the previous one was established
+		// within the window, and error if so.
+		if last := rat.lastConnect[address]; time.Since(last) < rat.window {
+			return fmt.Errorf("%q tried to connect again within %s of its last connection", address, rat.window)
+		}
+	}
+
+	rat.cache[address]++
+	rat.lastConnect[address] = time.Now()
+
+	return nil
+}
+
+func (rat *connTrackerImpl) RemoveConn(addr net.IP) {
+	address := addr.String()
+	rat.mutex.Lock()
+	defer rat.mutex.Unlock()
+
+	if num := rat.cache[address]; num > 0 {
+		rat.cache[address]--
+	}
+	if num := rat.cache[address]; num <= 0 {
+		delete(rat.cache, address)
+	}
+
+	if last, ok := rat.lastConnect[address]; ok && time.Since(last) > rat.window {
+		delete(rat.lastConnect, address)
+	}
+}
diff --git a/p2p/conn_tracker_test.go b/p2p/conn_tracker_test.go
new file mode 100644
index 000000000..66656e114
--- /dev/null
+++ b/p2p/conn_tracker_test.go
@@ -0,0 +1,73 @@
+package p2p
+
+import (
+	"math"
+	"math/rand"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func randByte() byte {
+	return byte(rand.Intn(math.MaxUint8))
+}
+
+func randLocalIPv4() net.IP {
+	return net.IPv4(127, randByte(), randByte(), randByte())
+}
+
+func TestConnTracker(t *testing.T) {
+	for name, factory := range map[string]func() connectionTracker{
+		"BaseSmall": func() connectionTracker {
+			return newConnTracker(10, time.Second)
+		},
+		"BaseLarge": func() connectionTracker {
+			return newConnTracker(100, time.Hour)
+		},
+	} {
+		t.Run(name, func(t *testing.T) {
+			factory := factory // nolint:scopelint
+			t.Run("Initialized", func(t *testing.T) {
+				ct := factory()
+				require.Equal(t, 0, ct.Len())
+			})
+			t.Run("RepeatedAdding", func(t *testing.T) {
+				ct := factory()
+				ip := randLocalIPv4()
+				require.NoError(t, ct.AddConn(ip))
+				for i := 0; i < 100; i++ {
+					_ = ct.AddConn(ip)
+				}
+				require.Equal(t, 1, ct.Len())
+			})
t.Run("AddingMany", func(t *testing.T) { + ct := factory() + for i := 0; i < 100; i++ { + _ = ct.AddConn(randLocalIPv4()) + } + require.Equal(t, 100, ct.Len()) + }) + t.Run("Cycle", func(t *testing.T) { + ct := factory() + for i := 0; i < 100; i++ { + ip := randLocalIPv4() + require.NoError(t, ct.AddConn(ip)) + ct.RemoveConn(ip) + } + require.Equal(t, 0, ct.Len()) + }) + }) + } + t.Run("VeryShort", func(t *testing.T) { + ct := newConnTracker(10, time.Microsecond) + for i := 0; i < 10; i++ { + ip := randLocalIPv4() + require.NoError(t, ct.AddConn(ip)) + time.Sleep(2 * time.Microsecond) + require.NoError(t, ct.AddConn(ip)) + } + require.Equal(t, 10, ct.Len()) + }) +} diff --git a/p2p/router.go b/p2p/router.go index beb185a08..2a3235ed6 100644 --- a/p2p/router.go +++ b/p2p/router.go @@ -130,6 +130,15 @@ type RouterOptions struct { // QueueType must be "wdrr" (Weighed Deficit Round Robin), // "priority", or FIFO. Defaults to FIFO. QueueType string + + // MaxIncommingConnectionsPerIP limits the number of incoming + // connections per IP address. Defaults to 100. + MaxIncommingConnectionsPerIP uint + + // IncomingConnectionWindow describes how often an IP address + // can attempt to create a new connection. Defaults to 10 + // milliseconds, and cannot be less than 1 millisecond. + IncomingConnectionWindow time.Duration } const ( @@ -149,6 +158,18 @@ func (o *RouterOptions) Validate() error { return fmt.Errorf("queue type %q is not supported", o.QueueType) } + switch { + case o.IncomingConnectionWindow == 0: + o.IncomingConnectionWindow = 100 * time.Millisecond + case o.IncomingConnectionWindow < time.Millisecond: + return fmt.Errorf("incomming connection window must be grater than 1m [%s]", + o.IncomingConnectionWindow) + } + + if o.MaxIncommingConnectionsPerIP == 0 { + o.MaxIncommingConnectionsPerIP = 100 + } + return nil } @@ -202,6 +223,7 @@ type Router struct { peerManager *PeerManager chDescs []ChannelDescriptor transports []Transport + connTracker connectionTracker protocolTransports map[Protocol]Transport stopCh chan struct{} // signals Router shutdown @@ -235,10 +257,13 @@ func NewRouter( } router := &Router{ - logger: logger, - metrics: metrics, - nodeInfo: nodeInfo, - privKey: privKey, + logger: logger, + metrics: metrics, + nodeInfo: nodeInfo, + privKey: privKey, + connTracker: newConnTracker( + options.MaxIncommingConnectionsPerIP, + options.IncomingConnectionWindow), chDescs: make([]ChannelDescriptor, 0), transports: transports, protocolTransports: map[Protocol]Transport{}, @@ -452,15 +477,6 @@ func (r *Router) acceptPeers(transport Transport) { r.logger.Debug("starting accept routine", "transport", transport) ctx := r.stopCtx() for { - // FIXME: We may need transports to enforce some sort of rate limiting - // here (e.g. by IP address), or alternatively have PeerManager.Accepted() - // do it for us. - // - // FIXME: Even though PeerManager enforces MaxConnected, we may want to - // limit the maximum number of active connections here too, since e.g. - // an adversary can open a ton of connections and then just hang during - // the handshake, taking up TCP socket descriptors. 
-		//
 		// FIXME: The old P2P stack rejected multiple connections for the same IP
 		// unless P2PConfig.AllowDuplicateIP is true -- it's better to limit this
 		// by peer ID rather than IP address, so this hasn't been implemented and
@@ -480,9 +496,21 @@ func (r *Router) acceptPeers(transport Transport) {
 			return
 		}
 
+		incomingIP := conn.RemoteEndpoint().IP
+		if err := r.connTracker.AddConn(incomingIP); err != nil {
+			closeErr := conn.Close()
+			r.logger.Debug("rate limiting incoming peer",
+				"err", err,
+				"ip", incomingIP.String(),
+				"closeErr", closeErr)
+
+			continue
+		}
+
 		// Spawn a goroutine for the handshake, to avoid head-of-line blocking.
 		go func() {
 			defer conn.Close()
+			defer r.connTracker.RemoveConn(incomingIP)
 
 			// FIXME: The peer manager may reject the peer during Accepted()
 			// after we've handshaked with the peer (to find out which peer it
@@ -514,7 +542,6 @@ func (r *Router) acceptPeers(transport Transport) {
 			}
 
 			r.metrics.Peers.Add(1)
-
 			queue := r.queueFactory(queueBufferDefault)
 
 			r.peerMtx.Lock()
@@ -692,6 +719,7 @@ func (r *Router) handshakePeer(ctx context.Context, conn Connection, expectID No
 		ctx, cancel = context.WithTimeout(ctx, r.options.HandshakeTimeout)
 		defer cancel()
 	}
+
 	peerInfo, peerKey, err := conn.Handshake(ctx, r.nodeInfo, r.privKey)
 	if err != nil {
 		return peerInfo, peerKey, err
diff --git a/p2p/router_test.go b/p2p/router_test.go
index ac020a7d7..748b4de2e 100644
--- a/p2p/router_test.go
+++ b/p2p/router_test.go
@@ -334,6 +334,7 @@ func TestRouter_AcceptPeers(t *testing.T) {
 			mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey).
 				Return(tc.peerInfo, tc.peerKey, nil)
 			mockConnection.On("Close").Run(func(_ mock.Arguments) { closer.Close() }).Return(nil)
+			mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{})
 			if tc.ok {
 				mockConnection.On("ReceiveMessage").Return(chID, nil, io.EOF)
 			}
@@ -462,6 +463,7 @@ func TestRouter_AcceptPeers_HeadOfLineBlocking(t *testing.T) {
 	mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey).
 		WaitUntil(closeCh).Return(p2p.NodeInfo{}, nil, io.EOF)
 	mockConnection.On("Close").Return(nil)
+	mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{})
 
 	mockTransport := &mocks.Transport{}
 	mockTransport.On("String").Maybe().Return("mock")
@@ -661,6 +663,7 @@ func TestRouter_EvictPeers(t *testing.T) {
 	mockConnection.On("Handshake", mock.Anything, selfInfo, selfKey).
 		Return(peerInfo, peerKey.PubKey(), nil)
 	mockConnection.On("ReceiveMessage").WaitUntil(closeCh).Return(chID, nil, io.EOF)
+	mockConnection.On("RemoteEndpoint").Return(p2p.Endpoint{})
 	mockConnection.On("Close").Run(func(_ mock.Arguments) {
 		closeOnce.Do(func() {
 			close(closeCh
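
Illustrative usage (not part of the patch): a minimal sketch of how a caller might tune the new per-IP limits through RouterOptions and let Validate fill in defaults. The standalone program, the import path, and the chosen values are assumptions for illustration; only the two option fields and the Validate behavior come from this change.

package main

import (
	"fmt"
	"time"

	"github.com/tendermint/tendermint/p2p" // assumed import path for this package
)

func main() {
	// Hypothetical values: allow at most 10 concurrent connections per IP,
	// and require 5 seconds between new connection attempts from the same
	// IP once its connection count has dropped back to zero.
	opts := p2p.RouterOptions{
		MaxIncommingConnectionsPerIP: 10,
		IncomingConnectionWindow:     5 * time.Second,
	}

	// Validate rejects windows below 1ms and fills in the defaults
	// (100 connections per IP, 100ms window) for zero-valued fields.
	if err := opts.Validate(); err != nil {
		fmt.Println("invalid router options:", err)
		return
	}

	fmt.Println("per-IP limit:", opts.MaxIncommingConnectionsPerIP,
		"window:", opts.IncomingConnectionWindow)
}

Leaving both fields at their zero values keeps existing callers working unchanged: Validate falls back to 100 connections per IP and a 100ms window, the same values NewRouter now passes to newConnTracker.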