From a0321633b0a3901d74b79ab617a360df813e4480 Mon Sep 17 00:00:00 2001 From: Sam Kleinman Date: Fri, 25 Feb 2022 12:37:11 -0500 Subject: [PATCH] p2p: backport changes in ping/pong tolerances (#8009) --- internal/p2p/conn/connection.go | 88 +++++++++-------- internal/p2p/conn/connection_test.go | 135 ++++++++++++++------------- 2 files changed, 117 insertions(+), 106 deletions(-) diff --git a/internal/p2p/conn/connection.go b/internal/p2p/conn/connection.go index 339ad7469..6a80417ad 100644 --- a/internal/p2p/conn/connection.go +++ b/internal/p2p/conn/connection.go @@ -9,6 +9,7 @@ import ( "net" "reflect" "runtime/debug" + "sync" "sync/atomic" "time" @@ -45,7 +46,7 @@ const ( defaultRecvRate = int64(512000) // 500KB/s defaultSendTimeout = 10 * time.Second defaultPingInterval = 60 * time.Second - defaultPongTimeout = 45 * time.Second + defaultPongTimeout = 90 * time.Second ) type receiveCbFunc func(chID byte, msgBytes []byte) @@ -108,8 +109,10 @@ type MConnection struct { pingTimer *time.Ticker // send pings periodically // close conn if pong is not received in pongTimeout - pongTimer *time.Timer - pongTimeoutCh chan bool // true - timeout, false - peer sent pong + lastMsgRecv struct { + sync.Mutex + at time.Time + } chStatsTimer *time.Ticker // update channel stats periodically @@ -171,10 +174,6 @@ func NewMConnectionWithConfig( onError errorCbFunc, config MConnConfig, ) *MConnection { - if config.PongTimeout >= config.PingInterval { - panic("pongTimeout must be less than pingInterval (otherwise, next ping will reset pong timer)") - } - mconn := &MConnection{ conn: conn, bufConnReader: bufio.NewReaderSize(conn, minReadBufferSize), @@ -223,16 +222,28 @@ func (c *MConnection) OnStart() error { } c.flushTimer = timer.NewThrottleTimer("flush", c.config.FlushThrottle) c.pingTimer = time.NewTicker(c.config.PingInterval) - c.pongTimeoutCh = make(chan bool, 1) c.chStatsTimer = time.NewTicker(updateStats) c.quitSendRoutine = make(chan struct{}) c.doneSendRoutine = make(chan struct{}) c.quitRecvRoutine = make(chan struct{}) + c.setRecvLastMsgAt(time.Now()) go c.sendRoutine() go c.recvRoutine() return nil } +func (c *MConnection) setRecvLastMsgAt(t time.Time) { + c.lastMsgRecv.Lock() + defer c.lastMsgRecv.Unlock() + c.lastMsgRecv.at = t +} + +func (c *MConnection) getLastMessageAt() time.Time { + c.lastMsgRecv.Lock() + defer c.lastMsgRecv.Unlock() + return c.lastMsgRecv.at +} + // stopServices stops the BaseService and timers and closes the quitSendRoutine. // if the quitSendRoutine was already closed, it returns true, otherwise it returns false. // It uses the stopMtx to ensure only one of FlushStop and OnStop can do this at a time. @@ -423,6 +434,8 @@ func (c *MConnection) sendRoutine() { defer c._recover() protoWriter := protoio.NewDelimitedWriter(c.bufConnWriter) + pongTimeout := time.NewTicker(c.config.PongTimeout) + defer pongTimeout.Stop() FOR_LOOP: for { var _n int @@ -445,21 +458,7 @@ FOR_LOOP: break SELECTION } c.sendMonitor.Update(_n) - c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout) - c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() { - select { - case c.pongTimeoutCh <- true: - default: - } - }) c.flush() - case timeout := <-c.pongTimeoutCh: - if timeout { - c.Logger.Debug("Pong timeout") - err = errors.New("pong timeout") - } else { - c.stopPongTimer() - } case <-c.pong: c.Logger.Debug("Send Pong") _n, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) @@ -471,6 +470,14 @@ FOR_LOOP: c.flush() case <-c.quitSendRoutine: break FOR_LOOP + case <-pongTimeout.C: + // the point of the pong timer is to check to + // see if we've seen a message recently, so we + // want to make sure that we escape this + // select statement on an interval to ensure + // that we avoid hanging on to dead + // connections for too long. + break SELECTION case <-c.send: // Send some PacketMsgs eof := c.sendSomePacketMsgs() @@ -483,18 +490,21 @@ FOR_LOOP: } } - if !c.IsRunning() { - break FOR_LOOP + if time.Since(c.getLastMessageAt()) > c.config.PongTimeout { + err = errors.New("pong timeout") } + if err != nil { c.Logger.Error("Connection failed @ sendRoutine", "conn", c, "err", err) c.stopForError(err) break FOR_LOOP } + if !c.IsRunning() { + break FOR_LOOP + } } // Cleanup - c.stopPongTimer() close(c.doneSendRoutine) } @@ -563,6 +573,14 @@ func (c *MConnection) recvRoutine() { FOR_LOOP: for { + select { + case <-c.quitRecvRoutine: + break FOR_LOOP + case <-c.doneSendRoutine: + break FOR_LOOP + default: + } + // Block until .recvMonitor says we can read. c.recvMonitor.Limit(c._maxPacketMsgSize, atomic.LoadInt64(&c.config.RecvRate), true) @@ -605,6 +623,9 @@ FOR_LOOP: break FOR_LOOP } + // record for pong/heartbeat + c.setRecvLastMsgAt(time.Now()) + // Read more depending on packet type. switch pkt := packet.Sum.(type) { case *tmp2p.Packet_PacketPing: @@ -617,12 +638,9 @@ FOR_LOOP: // never block } case *tmp2p.Packet_PacketPong: - c.Logger.Debug("Receive Pong") - select { - case c.pongTimeoutCh <- false: - default: - // never block - } + // do nothing, we updated the "last message + // received" timestamp above, so we can ignore + // this message case *tmp2p.Packet_PacketMsg: channelID := byte(pkt.PacketMsg.ChannelID) channel, ok := c.channelsIdx[channelID] @@ -661,14 +679,6 @@ FOR_LOOP: } } -// not goroutine-safe -func (c *MConnection) stopPongTimer() { - if c.pongTimer != nil { - _ = c.pongTimer.Stop() - c.pongTimer = nil - } -} - // maxPacketMsgSize returns a maximum size of PacketMsg func (c *MConnection) maxPacketMsgSize() int { bz, err := proto.Marshal(mustWrapPacket(&tmp2p.PacketMsg{ diff --git a/internal/p2p/conn/connection_test.go b/internal/p2p/conn/connection_test.go index 297ea9d18..09e6fe4b0 100644 --- a/internal/p2p/conn/connection_test.go +++ b/internal/p2p/conn/connection_test.go @@ -1,7 +1,9 @@ package conn import ( + "context" "encoding/hex" + "io" "net" "testing" "time" @@ -35,8 +37,8 @@ func createMConnectionWithCallbacks( onError func(r interface{}), ) *MConnection { cfg := DefaultMConnConfig() - cfg.PingInterval = 90 * time.Millisecond - cfg.PongTimeout = 45 * time.Millisecond + cfg.PingInterval = 250 * time.Millisecond + cfg.PongTimeout = 500 * time.Millisecond chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg) c.SetLogger(log.TestingLogger()) @@ -159,41 +161,43 @@ func TestMConnectionStatus(t *testing.T) { assert.Zero(t, status.Channels[0].SendQueueSize) } -func TestMConnectionPongTimeoutResultsInError(t *testing.T) { +func TestMConnectionWillEventuallyTimeout(t *testing.T) { server, client := net.Pipe() t.Cleanup(closeAll(t, client, server)) - receivedCh := make(chan []byte) - errorsCh := make(chan interface{}) - onReceive := func(chID byte, msgBytes []byte) { - receivedCh <- msgBytes - } - onError := func(r interface{}) { - errorsCh <- r - } - mconn := createMConnectionWithCallbacks(client, onReceive, onError) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mconn := createMConnectionWithCallbacks(client, nil, nil) err := mconn.Start() - require.Nil(t, err) - t.Cleanup(stopAll(t, mconn)) + require.NoError(t, err) + require.True(t, mconn.IsRunning()) - serverGotPing := make(chan struct{}) go func() { - // read ping - var pkt tmp2p.Packet - _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&pkt) - require.NoError(t, err) - serverGotPing <- struct{}{} - }() - <-serverGotPing + // read the send buffer so that the send receive + // doesn't get blocked. + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() - pongTimerExpired := mconn.config.PongTimeout + 200*time.Millisecond + for { + select { + case <-ticker.C: + _, _ = io.ReadAll(server) + case <-ctx.Done(): + return + } + } + }() + + // wait for the send routine to die because it doesn't select { - case msgBytes := <-receivedCh: - t.Fatalf("Expected error, but got %v", msgBytes) - case err := <-errorsCh: - assert.NotNil(t, err) - case <-time.After(pongTimerExpired): - t.Fatalf("Expected to receive error after %v", pongTimerExpired) + case <-mconn.doneSendRoutine: + require.True(t, time.Since(mconn.getLastMessageAt()) > mconn.config.PongTimeout, + "the connection state reflects that we've passed the pong timeout") + // since we hit the timeout, things should be shutdown + require.False(t, mconn.IsRunning()) + case <-time.After(2 * mconn.config.PongTimeout): + t.Fatal("connection did not hit timeout", mconn.config.PongTimeout) } } @@ -226,19 +230,14 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) require.NoError(t, err) - serverGotPing := make(chan struct{}) - go func() { - // read ping (one byte) - var packet tmp2p.Packet - _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet) - require.NoError(t, err) - serverGotPing <- struct{}{} + // read ping (one byte) + var packet tmp2p.Packet + _, err = protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet) + require.NoError(t, err) - // respond with pong - _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) - require.NoError(t, err) - }() - <-serverGotPing + // respond with pong + _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) + require.NoError(t, err) pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond select { @@ -299,52 +298,54 @@ func TestMConnectionPingPongs(t *testing.T) { // check that we are not leaking any go-routines t.Cleanup(leaktest.CheckTimeout(t, 10*time.Second)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + server, client := net.Pipe() t.Cleanup(closeAll(t, client, server)) receivedCh := make(chan []byte) errorsCh := make(chan interface{}) onReceive := func(chID byte, msgBytes []byte) { - receivedCh <- msgBytes + select { + case <-ctx.Done(): + case receivedCh <- msgBytes: + } } onError := func(r interface{}) { - errorsCh <- r + select { + case errorsCh <- r: + case <-ctx.Done(): + } } mconn := createMConnectionWithCallbacks(client, onReceive, onError) err := mconn.Start() require.Nil(t, err) t.Cleanup(stopAll(t, mconn)) - serverGotPing := make(chan struct{}) - go func() { - protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize) - protoWriter := protoio.NewDelimitedWriter(server) - var pkt tmp2p.PacketPing + protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize) + protoWriter := protoio.NewDelimitedWriter(server) + var pkt tmp2p.PacketPing - // read ping - _, err = protoReader.ReadMsg(&pkt) - require.NoError(t, err) - serverGotPing <- struct{}{} + // read ping + _, err = protoReader.ReadMsg(&pkt) + require.NoError(t, err) - // respond with pong - _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) - require.NoError(t, err) + // respond with pong + _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) + require.NoError(t, err) - time.Sleep(mconn.config.PingInterval) + time.Sleep(mconn.config.PingInterval) - // read ping - _, err = protoReader.ReadMsg(&pkt) - require.NoError(t, err) - serverGotPing <- struct{}{} + // read ping + _, err = protoReader.ReadMsg(&pkt) + require.NoError(t, err) - // respond with pong - _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) - require.NoError(t, err) - }() - <-serverGotPing - <-serverGotPing + // respond with pong + _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) + require.NoError(t, err) - pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2 + pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 4 select { case msgBytes := <-receivedCh: t.Fatalf("Expected no data, but got %v", msgBytes)