From f4ff66de30232bb5c3f70d390843bbcab5543e8e Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Wed, 24 Jan 2018 14:41:31 +0400 Subject: [PATCH] rewrite pong timer to use time.AfterFunc --- p2p/conn/connection.go | 42 +++++++------- p2p/conn/connection_test.go | 107 ++++++++++++++++++++++++++++++++++-- 2 files changed, 126 insertions(+), 23 deletions(-) diff --git a/p2p/conn/connection.go b/p2p/conn/connection.go index 71d8608fb..4e461b35e 100644 --- a/p2p/conn/connection.go +++ b/p2p/conn/connection.go @@ -83,11 +83,15 @@ type MConnection struct { errored uint32 config *MConnConfig - quit chan struct{} - flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled. - pingTimer *cmn.RepeatTimer // send pings periodically - pongTimer *time.Timer // close conn if pong is not received in pongTimeout - chStatsTimer *cmn.RepeatTimer // update channel stats periodically + quit chan struct{} + flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled. + pingTimer *cmn.RepeatTimer // send pings periodically + + // close conn if pong is not received in pongTimeout + pongTimer *time.Timer + pongTimeoutCh chan struct{} + + chStatsTimer *cmn.RepeatTimer // update channel stats periodically created time.Time // time of creation } @@ -187,10 +191,7 @@ func (c *MConnection) OnStart() error { c.quit = make(chan struct{}) c.flushTimer = cmn.NewThrottleTimer("flush", c.config.FlushThrottle) c.pingTimer = cmn.NewRepeatTimer("ping", c.config.PingInterval) - c.pongTimer = time.NewTimer(c.config.PongTimeout) - // we start timer once we've send ping; needed here because we use start - // listening in recvRoutine - _ = c.pongTimer.Stop() + c.pongTimeoutCh = make(chan struct{}) c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateStats) go c.sendRoutine() go c.recvRoutine() @@ -200,13 +201,12 @@ func (c *MConnection) OnStart() error { // OnStop implements BaseService func (c *MConnection) OnStop() { c.BaseService.OnStop() - c.flushTimer.Stop() - c.pingTimer.Stop() - _ = c.pongTimer.Stop() - c.chStatsTimer.Stop() if c.quit != nil { close(c.quit) } + c.flushTimer.Stop() + c.pingTimer.Stop() + c.chStatsTimer.Stop() c.conn.Close() // nolint: errcheck // We can't close pong safely here because @@ -337,12 +337,13 @@ FOR_LOOP: c.Logger.Debug("Send Ping") wire.WriteByte(packetTypePing, c.bufWriter, &n, &err) c.sendMonitor.Update(int(n)) + c.Logger.Debug("Starting pong timer", "dur", c.config.PongTimeout) + c.pongTimer = time.AfterFunc(c.config.PongTimeout, func() { + c.pongTimeoutCh <- struct{}{} + }) c.flush() - c.Logger.Debug("Starting pong timer") - c.pongTimer.Reset(c.config.PongTimeout) - case <-c.pongTimer.C: + case <-c.pongTimeoutCh: c.Logger.Debug("Pong timeout") - // XXX: should we decrease peer score instead of closing connection? err = errors.New("pong timeout") case <-c.pong: c.Logger.Debug("Send Pong") @@ -350,6 +351,9 @@ FOR_LOOP: c.sendMonitor.Update(int(n)) c.flush() case <-c.quit: + if c.pongTimer != nil { + _ = c.pongTimer.Stop() + } break FOR_LOOP case <-c.send: // Send some msgPackets @@ -483,8 +487,8 @@ FOR_LOOP: } case packetTypePong: c.Logger.Debug("Receive Pong") - if !c.pongTimer.Stop() { - <-c.pongTimer.C + if c.pongTimer != nil { + _ = c.pongTimer.Stop() } case packetTypeMsg: pkt, n, err := msgPacket{}, int(0), error(nil) diff --git a/p2p/conn/connection_test.go b/p2p/conn/connection_test.go index d505805ed..acfa8032a 100644 --- a/p2p/conn/connection_test.go +++ b/p2p/conn/connection_test.go @@ -24,7 +24,7 @@ func createTestMConnection(conn net.Conn) *MConnection { func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *MConnection { chDescs := []*ChannelDescriptor{&ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} cfg := DefaultMConnConfig() - cfg.PingInterval = 60 * time.Millisecond + cfg.PingInterval = 90 * time.Millisecond cfg.PongTimeout = 45 * time.Millisecond c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, cfg) c.SetLogger(log.TestingLogger()) @@ -137,19 +137,118 @@ func TestMConnectionPongTimeoutResultsInError(t *testing.T) { require.Nil(t, err) defer mconn.Stop() + serverGotPing := make(chan struct{}) go func() { // read ping server.Read(make([]byte, 1)) + serverGotPing <- struct{}{} }() + <-serverGotPing - expectErrorAfter := (mconn.config.PingInterval + mconn.config.PongTimeout) * 2 + pongTimerExpired := mconn.config.PongTimeout + 10*time.Millisecond select { case msgBytes := <-receivedCh: t.Fatalf("Expected error, but got %v", msgBytes) case err := <-errorsCh: assert.NotNil(t, err) - case <-time.After(expectErrorAfter): - t.Fatalf("Expected to receive error after %v", expectErrorAfter) + case <-time.After(pongTimerExpired): + t.Fatalf("Expected to receive error after %v", pongTimerExpired) + } +} + +func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + receivedCh := make(chan []byte) + errorsCh := make(chan interface{}) + onReceive := func(chID byte, msgBytes []byte) { + receivedCh <- msgBytes + } + onError := func(r interface{}) { + errorsCh <- r + } + mconn := createMConnectionWithCallbacks(client, onReceive, onError) + err := mconn.Start() + require.Nil(t, err) + defer mconn.Stop() + + // sending 3 pongs in a row + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + + serverGotPing := make(chan struct{}) + go func() { + // read ping + server.Read(make([]byte, 1)) + serverGotPing <- struct{}{} + // respond with pong + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + }() + <-serverGotPing + + pongTimerExpired := mconn.config.PongTimeout + 10*time.Millisecond + select { + case msgBytes := <-receivedCh: + t.Fatalf("Expected no data, but got %v", msgBytes) + case err := <-errorsCh: + t.Fatalf("Expected no error, but got %v", err) + case <-time.After(pongTimerExpired): + assert.True(t, mconn.IsRunning()) + } +} + +func TestMConnectionPingPongs(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + receivedCh := make(chan []byte) + errorsCh := make(chan interface{}) + onReceive := func(chID byte, msgBytes []byte) { + receivedCh <- msgBytes + } + onError := func(r interface{}) { + errorsCh <- r + } + mconn := createMConnectionWithCallbacks(client, onReceive, onError) + err := mconn.Start() + require.Nil(t, err) + defer mconn.Stop() + + serverGotPing := make(chan struct{}) + go func() { + // read ping + server.Read(make([]byte, 1)) + serverGotPing <- struct{}{} + // respond with pong + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + + time.Sleep(mconn.config.PingInterval) + + // read ping + server.Read(make([]byte, 1)) + // respond with pong + _, err = server.Write([]byte{packetTypePong}) + require.Nil(t, err) + }() + <-serverGotPing + + pongTimerExpired := (mconn.config.PongTimeout + 10*time.Millisecond) * 2 + select { + case msgBytes := <-receivedCh: + t.Fatalf("Expected no data, but got %v", msgBytes) + case err := <-errorsCh: + t.Fatalf("Expected no error, but got %v", err) + case <-time.After(2 * pongTimerExpired): + assert.True(t, mconn.IsRunning()) } }