diff --git a/p2p/conn/connection.go b/p2p/conn/connection.go index b02e1ccb4..6fbb425e7 100644 --- a/p2p/conn/connection.go +++ b/p2p/conn/connection.go @@ -351,12 +351,7 @@ FOR_LOOP: c.sendMonitor.Update(int(n)) c.flush() case <-c.quit: - if c.pongTimer != nil { - if !c.pongTimer.Stop() { - <-c.pongTimer.C - } - drain(c.pongTimeoutCh) - } + c.stopPongTimer() break FOR_LOOP case <-c.send: // Send some msgPackets @@ -482,6 +477,7 @@ FOR_LOOP: switch pktType { case packetTypePing: // TODO: prevent abuse, as they cause flush()'s. + // https://github.com/tendermint/tendermint/issues/1190 c.Logger.Debug("Receive Ping") select { case c.pong <- struct{}{}: @@ -490,12 +486,7 @@ FOR_LOOP: } case packetTypePong: c.Logger.Debug("Receive Pong") - if c.pongTimer != nil { - if !c.pongTimer.Stop() { - <-c.pongTimer.C - } - drain(c.pongTimeoutCh) - } + c.stopPongTimer() case packetTypeMsg: pkt, n, err := msgPacket{}, int(0), error(nil) wire.ReadBinaryPtr(&pkt, c.bufReader, c.config.maxMsgPacketTotalSize(), &n, &err) @@ -543,6 +534,15 @@ FOR_LOOP: } } +func (c *MConnection) stopPongTimer() { + if c.pongTimer != nil { + if !c.pongTimer.Stop() { + <-c.pongTimer.C + } + drain(c.pongTimeoutCh) + } +} + type ConnectionStatus struct { Duration time.Duration SendMonitor flow.Status diff --git a/p2p/conn/connection_test.go b/p2p/conn/connection_test.go index acfa8032a..270b4ae92 100644 --- a/p2p/conn/connection_test.go +++ b/p2p/conn/connection_test.go @@ -145,7 +145,7 @@ func TestMConnectionPongTimeoutResultsInError(t *testing.T) { }() <-serverGotPing - pongTimerExpired := mconn.config.PongTimeout + 10*time.Millisecond + pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond select { case msgBytes := <-receivedCh: t.Fatalf("Expected error, but got %v", msgBytes) @@ -174,7 +174,7 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { require.Nil(t, err) defer mconn.Stop() - // sending 3 pongs in a row + // sending 3 pongs in a row (abuse) _, err = server.Write([]byte{packetTypePong}) require.Nil(t, err) _, err = server.Write([]byte{packetTypePong}) @@ -184,8 +184,9 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { serverGotPing := make(chan struct{}) go func() { - // read ping - server.Read(make([]byte, 1)) + // read ping (one byte) + _, err = server.Read(make([]byte, 1)) + require.Nil(t, err) serverGotPing <- struct{}{} // respond with pong _, err = server.Write([]byte{packetTypePong}) @@ -193,7 +194,7 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { }() <-serverGotPing - pongTimerExpired := mconn.config.PongTimeout + 10*time.Millisecond + pongTimerExpired := mconn.config.PongTimeout + 20*time.Millisecond select { case msgBytes := <-receivedCh: t.Fatalf("Expected no data, but got %v", msgBytes) @@ -204,6 +205,41 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { } } +func TestMConnectionMultiplePings(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + receivedCh := make(chan []byte) + errorsCh := make(chan interface{}) + onReceive := func(chID byte, msgBytes []byte) { + receivedCh <- msgBytes + } + onError := func(r interface{}) { + errorsCh <- r + } + mconn := createMConnectionWithCallbacks(client, onReceive, onError) + err := mconn.Start() + require.Nil(t, err) + defer mconn.Stop() + + // sending 3 pings in a row (abuse) + _, err = server.Write([]byte{packetTypePing}) + require.Nil(t, err) + _, err = server.Read(make([]byte, 1)) + require.Nil(t, err) + _, err = server.Write([]byte{packetTypePing}) + require.Nil(t, err) + _, err = server.Read(make([]byte, 1)) + require.Nil(t, err) + _, err = server.Write([]byte{packetTypePing}) + require.Nil(t, err) + _, err = server.Read(make([]byte, 1)) + require.Nil(t, err) + + assert.True(t, mconn.IsRunning()) +} + func TestMConnectionPingPongs(t *testing.T) { server, client := net.Pipe() defer server.Close() @@ -241,7 +277,7 @@ func TestMConnectionPingPongs(t *testing.T) { }() <-serverGotPing - pongTimerExpired := (mconn.config.PongTimeout + 10*time.Millisecond) * 2 + pongTimerExpired := (mconn.config.PongTimeout + 20*time.Millisecond) * 2 select { case msgBytes := <-receivedCh: t.Fatalf("Expected no data, but got %v", msgBytes)