diff --git a/internal/p2p/conn/connection.go b/internal/p2p/conn/connection.go index 32e5ca6b8..3f2ee33a2 100644 --- a/internal/p2p/conn/connection.go +++ b/internal/p2p/conn/connection.go @@ -108,8 +108,10 @@ type MConnection struct { pingTimer *time.Ticker // send pings periodically // close conn if pong is not received in pongTimeout - pongTimer *time.Timer - pongTimeoutCh chan bool // true - timeout, false - peer sent pong + lastMsgRecv struct { + sync.Mutex + at time.Time + } chStatsTimer *time.Ticker // update channel stats periodically @@ -161,10 +163,6 @@ func NewMConnection( onError errorCbFunc, config MConnConfig, ) *MConnection { - if config.PongTimeout >= config.PingInterval { - panic("pongTimeout must be less than pingInterval (otherwise, next ping will reset pong timer)") - } - mconn := &MConnection{ logger: logger, conn: conn, @@ -205,16 +203,28 @@ func NewMConnection( func (c *MConnection) OnStart(ctx context.Context) error { c.flushTimer = timer.NewThrottleTimer("flush", c.config.FlushThrottle) c.pingTimer = time.NewTicker(c.config.PingInterval) - c.pongTimeoutCh = make(chan bool, 1) c.chStatsTimer = time.NewTicker(updateStats) c.quitSendRoutine = make(chan struct{}) c.doneSendRoutine = make(chan struct{}) c.quitRecvRoutine = make(chan struct{}) + c.setRecvLastMsgAt(time.Now()) go c.sendRoutine(ctx) go c.recvRoutine(ctx) return nil } +func (c *MConnection) setRecvLastMsgAt(t time.Time) { + c.lastMsgRecv.Lock() + defer c.lastMsgRecv.Unlock() + c.lastMsgRecv.at = t +} + +func (c *MConnection) getLastMessageAt() time.Time { + c.lastMsgRecv.Lock() + defer c.lastMsgRecv.Unlock() + return c.lastMsgRecv.at +} + // stopServices stops the BaseService and timers and closes the quitSendRoutine. // if the quitSendRoutine was already closed, it returns true, otherwise it returns false. // It uses the stopMtx to ensure only one of FlushStop and OnStop can do this at a time. 
@@ -323,6 +333,8 @@ func (c *MConnection) sendRoutine(ctx context.Context) { defer c._recover(ctx) protoWriter := protoio.NewDelimitedWriter(c.bufConnWriter) + pongTimeout := time.NewTicker(c.config.PongTimeout) + defer pongTimeout.Stop() FOR_LOOP: for { var _n int @@ -344,20 +356,7 @@ FOR_LOOP: break SELECTION } c.sendMonitor.Update(_n) - c.logger.Debug("Starting pong timer", "dur", c.config.PongTimeout) - c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() { - select { - case c.pongTimeoutCh <- true: - default: - } - }) c.flush() - case timeout := <-c.pongTimeoutCh: - if timeout { - err = errors.New("pong timeout") - } else { - c.stopPongTimer() - } case <-c.pong: _n, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) if err != nil { @@ -370,6 +369,14 @@ FOR_LOOP: break FOR_LOOP case <-c.quitSendRoutine: break FOR_LOOP + case <-pongTimeout.C: + // the point of the pong timer is to check to + // see if we've seen a message recently, so we + // want to make sure that we escape this + // select statement on an interval to ensure + // that we avoid hanging on to dead + // connections for too long. + break SELECTION case <-c.send: // Send some PacketMsgs eof := c.sendSomePacketMsgs(ctx) @@ -382,18 +389,21 @@ FOR_LOOP: } } - if !c.IsRunning() { - break FOR_LOOP + if time.Since(c.getLastMessageAt()) > c.config.PongTimeout { + err = errors.New("pong timeout") } + if err != nil { c.logger.Error("Connection failed @ sendRoutine", "conn", c, "err", err) c.stopForError(ctx, err) break FOR_LOOP } + if !c.IsRunning() { + break FOR_LOOP + } } // Cleanup - c.stopPongTimer() close(c.doneSendRoutine) } @@ -462,6 +472,14 @@ func (c *MConnection) recvRoutine(ctx context.Context) { FOR_LOOP: for { + select { + case <-ctx.Done(): + break FOR_LOOP + case <-c.doneSendRoutine: + break FOR_LOOP + default: + } + // Block until .recvMonitor says we can read. 
c.recvMonitor.Limit(c._maxPacketMsgSize, atomic.LoadInt64(&c.config.RecvRate), true) @@ -505,6 +523,9 @@ FOR_LOOP: break FOR_LOOP } + // record for pong/heartbeat + c.setRecvLastMsgAt(time.Now()) + // Read more depending on packet type. switch pkt := packet.Sum.(type) { case *tmp2p.Packet_PacketPing: @@ -516,11 +537,9 @@ FOR_LOOP: // never block } case *tmp2p.Packet_PacketPong: - select { - case c.pongTimeoutCh <- false: - default: - // never block - } + // do nothing, we updated the "last message + // received" timestamp above, so we can ignore + // this message case *tmp2p.Packet_PacketMsg: channelID := ChannelID(pkt.PacketMsg.ChannelID) channel, ok := c.channelsIdx[channelID] @@ -559,14 +578,6 @@ FOR_LOOP: } } -// not goroutine-safe -func (c *MConnection) stopPongTimer() { - if c.pongTimer != nil { - _ = c.pongTimer.Stop() - c.pongTimer = nil - } -} - // maxPacketMsgSize returns a maximum size of PacketMsg func (c *MConnection) maxPacketMsgSize() int { bz, err := proto.Marshal(mustWrapPacket(&tmp2p.PacketMsg{ diff --git a/internal/p2p/conn/connection_test.go b/internal/p2p/conn/connection_test.go index ab05eef21..e68b7584f 100644 --- a/internal/p2p/conn/connection_test.go +++ b/internal/p2p/conn/connection_test.go @@ -3,6 +3,7 @@ package conn import ( "context" "encoding/hex" + "io" "net" "sync" "testing" @@ -39,8 +40,8 @@ func createMConnectionWithCallbacks( onError func(ctx context.Context, r interface{}), ) *MConnection { cfg := DefaultMConnConfig() - cfg.PingInterval = 90 * time.Millisecond - cfg.PongTimeout = 45 * time.Millisecond + cfg.PingInterval = 250 * time.Millisecond + cfg.PongTimeout = 500 * time.Millisecond chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} c := NewMConnection(logger, conn, chDescs, onReceive, onError, cfg) return c @@ -160,51 +161,44 @@ func TestMConnectionReceive(t *testing.T) { } } -func TestMConnectionPongTimeoutResultsInError(t *testing.T) { +func TestMConnectionWillEventuallyTimeout(t *testing.T) { 
server, client := net.Pipe() t.Cleanup(closeAll(t, client, server)) - receivedCh := make(chan []byte) - errorsCh := make(chan interface{}) - onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) { - select { - case receivedCh <- msgBytes: - case <-ctx.Done(): - } - } - onError := func(ctx context.Context, r interface{}) { - select { - case errorsCh <- r: - case <-ctx.Done(): - } - } - ctx, cancel := context.WithCancel(context.Background()) defer cancel() - mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError) + mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, nil, nil) err := mconn.Start(ctx) require.NoError(t, err) t.Cleanup(waitAll(mconn)) + require.True(t, mconn.IsRunning()) - serverGotPing := make(chan struct{}) go func() { - // read ping - var pkt tmp2p.Packet - _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&pkt) - require.NoError(t, err) - serverGotPing <- struct{}{} - }() - <-serverGotPing + // read the send buffer so that the send routine 
+ ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() - pongTimerExpired := mconn.config.PongTimeout + 200*time.Millisecond + for { + select { + case <-ticker.C: + _, _ = io.ReadAll(server) + case <-ctx.Done(): + return + } + } + }() + + // wait for the send routine to die because it never receives a message from the peer select { - case msgBytes := <-receivedCh: - t.Fatalf("Expected error, but got %v", msgBytes) - case err := <-errorsCh: - assert.NotNil(t, err) - case <-time.After(pongTimerExpired): - t.Fatalf("Expected to receive error after %v", pongTimerExpired) + case <-mconn.doneSendRoutine: + require.True(t, time.Since(mconn.getLastMessageAt()) > mconn.config.PongTimeout, + "the connection state reflects that we've passed the pong timeout") + // since we hit the timeout, things should be shut down + require.False(t, mconn.IsRunning()) + case <-time.After(2 * mconn.config.PongTimeout): + t.Fatal("connection did not hit timeout", mconn.config.PongTimeout) } } @@ -247,19 +241,14 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) require.NoError(t, err) - serverGotPing := make(chan struct{}) - go func() { - // read ping (one byte) - var packet tmp2p.Packet - _, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet) - require.NoError(t, err) - serverGotPing <- struct{}{} + // read ping (one byte) + var packet tmp2p.Packet + _, err = protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet) + require.NoError(t, err) - // respond with pong - _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) - require.NoError(t, err) - }() - <-serverGotPing + // respond with pong + _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) + require.NoError(t, err) pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond select { @@ -355,36 +344,29 @@ func TestMConnectionPingPongs(t *testing.T) { require.NoError(t, err) 
t.Cleanup(waitAll(mconn)) - serverGotPing := make(chan struct{}) - go func() { - protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize) - protoWriter := protoio.NewDelimitedWriter(server) - var pkt tmp2p.PacketPing + protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize) + protoWriter := protoio.NewDelimitedWriter(server) + var pkt tmp2p.PacketPing - // read ping - _, err = protoReader.ReadMsg(&pkt) - require.NoError(t, err) - serverGotPing <- struct{}{} + // read ping + _, err = protoReader.ReadMsg(&pkt) + require.NoError(t, err) - // respond with pong - _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) - require.NoError(t, err) + // respond with pong + _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) + require.NoError(t, err) - time.Sleep(mconn.config.PingInterval) + time.Sleep(mconn.config.PingInterval) - // read ping - _, err = protoReader.ReadMsg(&pkt) - require.NoError(t, err) - serverGotPing <- struct{}{} + // read ping + _, err = protoReader.ReadMsg(&pkt) + require.NoError(t, err) - // respond with pong - _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) - require.NoError(t, err) - }() - <-serverGotPing - <-serverGotPing + // respond with pong + _, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{})) + require.NoError(t, err) - pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2 + pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 4 select { case msgBytes := <-receivedCh: t.Fatalf("Expected no data, but got %v", msgBytes)