mirror of
https://github.com/tendermint/tendermint.git
synced 2026-04-19 15:25:11 +00:00
p2p: mconn track last message for pongs (#7995)
* p2p: mconn track last message for pongs
* fix spell
* cr feedback
* test fix part one
* cleanup tests
* fix comment

Co-authored-by: M. J. Fromberger <fromberger@interchain.io>
This commit is contained in:
@@ -108,8 +108,10 @@ type MConnection struct {
|
||||
pingTimer *time.Ticker // send pings periodically
|
||||
|
||||
// close conn if pong is not received in pongTimeout
|
||||
pongTimer *time.Timer
|
||||
pongTimeoutCh chan bool // true - timeout, false - peer sent pong
|
||||
lastMsgRecv struct {
|
||||
sync.Mutex
|
||||
at time.Time
|
||||
}
|
||||
|
||||
chStatsTimer *time.Ticker // update channel stats periodically
|
||||
|
||||
@@ -161,10 +163,6 @@ func NewMConnection(
|
||||
onError errorCbFunc,
|
||||
config MConnConfig,
|
||||
) *MConnection {
|
||||
if config.PongTimeout >= config.PingInterval {
|
||||
panic("pongTimeout must be less than pingInterval (otherwise, next ping will reset pong timer)")
|
||||
}
|
||||
|
||||
mconn := &MConnection{
|
||||
logger: logger,
|
||||
conn: conn,
|
||||
@@ -205,16 +203,28 @@ func NewMConnection(
|
||||
// OnStart prepares the connection for use: it creates the flush throttle
// timer, the ping and channel-stats tickers, and the quit/done signalling
// channels, records the current time as the last-received-message time, and
// then launches sendRoutine and recvRoutine as goroutines. It always returns
// nil.
//
// NOTE(review): this listing is a diff view without +/- markers, so lines
// from before and after the change appear side by side; the pongTimeoutCh
// assignment presumably belongs to the pre-change side — confirm against the
// repository.
func (c *MConnection) OnStart(ctx context.Context) error {
|
||||
c.flushTimer = timer.NewThrottleTimer("flush", c.config.FlushThrottle)
|
||||
c.pingTimer = time.NewTicker(c.config.PingInterval)
|
||||
c.pongTimeoutCh = make(chan bool, 1)
|
||||
c.chStatsTimer = time.NewTicker(updateStats)
|
||||
c.quitSendRoutine = make(chan struct{})
|
||||
c.doneSendRoutine = make(chan struct{})
|
||||
c.quitRecvRoutine = make(chan struct{})
|
||||
// Seed the timestamp before the routines start: sendRoutine compares
// time.Since(c.getLastMessageAt()) against config.PongTimeout.
c.setRecvLastMsgAt(time.Now())
|
||||
go c.sendRoutine(ctx)
|
||||
go c.recvRoutine(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MConnection) setRecvLastMsgAt(t time.Time) {
|
||||
c.lastMsgRecv.Lock()
|
||||
defer c.lastMsgRecv.Unlock()
|
||||
c.lastMsgRecv.at = t
|
||||
}
|
||||
|
||||
func (c *MConnection) getLastMessageAt() time.Time {
|
||||
c.lastMsgRecv.Lock()
|
||||
defer c.lastMsgRecv.Unlock()
|
||||
return c.lastMsgRecv.at
|
||||
}
|
||||
|
||||
// stopServices stops the BaseService and timers and closes the quitSendRoutine.
|
||||
// if the quitSendRoutine was already closed, it returns true, otherwise it returns false.
|
||||
// It uses the stopMtx to ensure only one of FlushStop and OnStop can do this at a time.
|
||||
@@ -323,6 +333,8 @@ func (c *MConnection) sendRoutine(ctx context.Context) {
|
||||
defer c._recover(ctx)
|
||||
protoWriter := protoio.NewDelimitedWriter(c.bufConnWriter)
|
||||
|
||||
pongTimeout := time.NewTicker(c.config.PongTimeout)
|
||||
defer pongTimeout.Stop()
|
||||
FOR_LOOP:
|
||||
for {
|
||||
var _n int
|
||||
@@ -344,20 +356,7 @@ FOR_LOOP:
|
||||
break SELECTION
|
||||
}
|
||||
c.sendMonitor.Update(_n)
|
||||
c.logger.Debug("Starting pong timer", "dur", c.config.PongTimeout)
|
||||
c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() {
|
||||
select {
|
||||
case c.pongTimeoutCh <- true:
|
||||
default:
|
||||
}
|
||||
})
|
||||
c.flush()
|
||||
case timeout := <-c.pongTimeoutCh:
|
||||
if timeout {
|
||||
err = errors.New("pong timeout")
|
||||
} else {
|
||||
c.stopPongTimer()
|
||||
}
|
||||
case <-c.pong:
|
||||
_n, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
|
||||
if err != nil {
|
||||
@@ -370,6 +369,14 @@ FOR_LOOP:
|
||||
break FOR_LOOP
|
||||
case <-c.quitSendRoutine:
|
||||
break FOR_LOOP
|
||||
case <-pongTimeout.C:
|
||||
// the point of the pong timer is to check to
|
||||
// see if we've seen a message recently, so we
|
||||
// want to make sure that we escape this
|
||||
// select statement on an interval to ensure
|
||||
// that we avoid hanging on to dead
|
||||
// connections for too long.
|
||||
break SELECTION
|
||||
case <-c.send:
|
||||
// Send some PacketMsgs
|
||||
eof := c.sendSomePacketMsgs(ctx)
|
||||
@@ -382,18 +389,21 @@ FOR_LOOP:
|
||||
}
|
||||
}
|
||||
|
||||
if !c.IsRunning() {
|
||||
break FOR_LOOP
|
||||
if time.Since(c.getLastMessageAt()) > c.config.PongTimeout {
|
||||
err = errors.New("pong timeout")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.logger.Error("Connection failed @ sendRoutine", "conn", c, "err", err)
|
||||
c.stopForError(ctx, err)
|
||||
break FOR_LOOP
|
||||
}
|
||||
if !c.IsRunning() {
|
||||
break FOR_LOOP
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
c.stopPongTimer()
|
||||
close(c.doneSendRoutine)
|
||||
}
|
||||
|
||||
@@ -462,6 +472,14 @@ func (c *MConnection) recvRoutine(ctx context.Context) {
|
||||
|
||||
FOR_LOOP:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break FOR_LOOP
|
||||
case <-c.doneSendRoutine:
|
||||
break FOR_LOOP
|
||||
default:
|
||||
}
|
||||
|
||||
// Block until .recvMonitor says we can read.
|
||||
c.recvMonitor.Limit(c._maxPacketMsgSize, atomic.LoadInt64(&c.config.RecvRate), true)
|
||||
|
||||
@@ -505,6 +523,9 @@ FOR_LOOP:
|
||||
break FOR_LOOP
|
||||
}
|
||||
|
||||
// record for pong/heartbeat
|
||||
c.setRecvLastMsgAt(time.Now())
|
||||
|
||||
// Read more depending on packet type.
|
||||
switch pkt := packet.Sum.(type) {
|
||||
case *tmp2p.Packet_PacketPing:
|
||||
@@ -516,11 +537,9 @@ FOR_LOOP:
|
||||
// never block
|
||||
}
|
||||
case *tmp2p.Packet_PacketPong:
|
||||
select {
|
||||
case c.pongTimeoutCh <- false:
|
||||
default:
|
||||
// never block
|
||||
}
|
||||
// do nothing, we updated the "last message
|
||||
// received" timestamp above, so we can ignore
|
||||
// this message
|
||||
case *tmp2p.Packet_PacketMsg:
|
||||
channelID := ChannelID(pkt.PacketMsg.ChannelID)
|
||||
channel, ok := c.channelsIdx[channelID]
|
||||
@@ -559,14 +578,6 @@ FOR_LOOP:
|
||||
}
|
||||
}
|
||||
|
||||
// not goroutine-safe
|
||||
func (c *MConnection) stopPongTimer() {
|
||||
if c.pongTimer != nil {
|
||||
_ = c.pongTimer.Stop()
|
||||
c.pongTimer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// maxPacketMsgSize returns a maximum size of PacketMsg
|
||||
func (c *MConnection) maxPacketMsgSize() int {
|
||||
bz, err := proto.Marshal(mustWrapPacket(&tmp2p.PacketMsg{
|
||||
|
||||
@@ -3,6 +3,7 @@ package conn
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -39,8 +40,8 @@ func createMConnectionWithCallbacks(
|
||||
onError func(ctx context.Context, r interface{}),
|
||||
) *MConnection {
|
||||
cfg := DefaultMConnConfig()
|
||||
cfg.PingInterval = 90 * time.Millisecond
|
||||
cfg.PongTimeout = 45 * time.Millisecond
|
||||
cfg.PingInterval = 250 * time.Millisecond
|
||||
cfg.PongTimeout = 500 * time.Millisecond
|
||||
chDescs := []*ChannelDescriptor{{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
|
||||
c := NewMConnection(logger, conn, chDescs, onReceive, onError, cfg)
|
||||
return c
|
||||
@@ -160,51 +161,44 @@ func TestMConnectionReceive(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMConnectionPongTimeoutResultsInError(t *testing.T) {
|
||||
func TestMConnectionWillEventuallyTimeout(t *testing.T) {
|
||||
server, client := net.Pipe()
|
||||
t.Cleanup(closeAll(t, client, server))
|
||||
|
||||
receivedCh := make(chan []byte)
|
||||
errorsCh := make(chan interface{})
|
||||
onReceive := func(ctx context.Context, chID ChannelID, msgBytes []byte) {
|
||||
select {
|
||||
case receivedCh <- msgBytes:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
onError := func(ctx context.Context, r interface{}) {
|
||||
select {
|
||||
case errorsCh <- r:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, onReceive, onError)
|
||||
mconn := createMConnectionWithCallbacks(log.TestingLogger(), client, nil, nil)
|
||||
err := mconn.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(waitAll(mconn))
|
||||
require.True(t, mconn.IsRunning())
|
||||
|
||||
serverGotPing := make(chan struct{})
|
||||
go func() {
|
||||
// read ping
|
||||
var pkt tmp2p.Packet
|
||||
_, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&pkt)
|
||||
require.NoError(t, err)
|
||||
serverGotPing <- struct{}{}
|
||||
}()
|
||||
<-serverGotPing
|
||||
// read the send buffer so that the send routine
|
||||
// doesn't get blocked.
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
pongTimerExpired := mconn.config.PongTimeout + 200*time.Millisecond
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
_, _ = io.ReadAll(server)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for the send routine to die because it doesn't see any message within the pong timeout
|
||||
select {
|
||||
case msgBytes := <-receivedCh:
|
||||
t.Fatalf("Expected error, but got %v", msgBytes)
|
||||
case err := <-errorsCh:
|
||||
assert.NotNil(t, err)
|
||||
case <-time.After(pongTimerExpired):
|
||||
t.Fatalf("Expected to receive error after %v", pongTimerExpired)
|
||||
case <-mconn.doneSendRoutine:
|
||||
require.True(t, time.Since(mconn.getLastMessageAt()) > mconn.config.PongTimeout,
|
||||
"the connection state reflects that we've passed the pong timeout")
|
||||
// since we hit the timeout, things should be shutdown
|
||||
require.False(t, mconn.IsRunning())
|
||||
case <-time.After(2 * mconn.config.PongTimeout):
|
||||
t.Fatal("connection did not hit timeout", mconn.config.PongTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,19 +241,14 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) {
|
||||
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
|
||||
require.NoError(t, err)
|
||||
|
||||
serverGotPing := make(chan struct{})
|
||||
go func() {
|
||||
// read ping (one byte)
|
||||
var packet tmp2p.Packet
|
||||
_, err := protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
|
||||
require.NoError(t, err)
|
||||
serverGotPing <- struct{}{}
|
||||
// read ping (one byte)
|
||||
var packet tmp2p.Packet
|
||||
_, err = protoio.NewDelimitedReader(server, maxPingPongPacketSize).ReadMsg(&packet)
|
||||
require.NoError(t, err)
|
||||
|
||||
// respond with pong
|
||||
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
<-serverGotPing
|
||||
// respond with pong
|
||||
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
|
||||
require.NoError(t, err)
|
||||
|
||||
pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond
|
||||
select {
|
||||
@@ -355,36 +344,29 @@ func TestMConnectionPingPongs(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(waitAll(mconn))
|
||||
|
||||
serverGotPing := make(chan struct{})
|
||||
go func() {
|
||||
protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
|
||||
protoWriter := protoio.NewDelimitedWriter(server)
|
||||
var pkt tmp2p.PacketPing
|
||||
protoReader := protoio.NewDelimitedReader(server, maxPingPongPacketSize)
|
||||
protoWriter := protoio.NewDelimitedWriter(server)
|
||||
var pkt tmp2p.PacketPing
|
||||
|
||||
// read ping
|
||||
_, err = protoReader.ReadMsg(&pkt)
|
||||
require.NoError(t, err)
|
||||
serverGotPing <- struct{}{}
|
||||
// read ping
|
||||
_, err = protoReader.ReadMsg(&pkt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// respond with pong
|
||||
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
|
||||
require.NoError(t, err)
|
||||
// respond with pong
|
||||
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(mconn.config.PingInterval)
|
||||
time.Sleep(mconn.config.PingInterval)
|
||||
|
||||
// read ping
|
||||
_, err = protoReader.ReadMsg(&pkt)
|
||||
require.NoError(t, err)
|
||||
serverGotPing <- struct{}{}
|
||||
// read ping
|
||||
_, err = protoReader.ReadMsg(&pkt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// respond with pong
|
||||
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
<-serverGotPing
|
||||
<-serverGotPing
|
||||
// respond with pong
|
||||
_, err = protoWriter.WriteMsg(mustWrapPacket(&tmp2p.PacketPong{}))
|
||||
require.NoError(t, err)
|
||||
|
||||
pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2
|
||||
pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 4
|
||||
select {
|
||||
case msgBytes := <-receivedCh:
|
||||
t.Fatalf("Expected no data, but got %v", msgBytes)
|
||||
|
||||
Reference in New Issue
Block a user