diff --git a/p2p/pex_reactor.go b/p2p/pex_reactor.go index 1903a460f..63862d1a5 100644 --- a/p2p/pex_reactor.go +++ b/p2p/pex_reactor.go @@ -7,6 +7,7 @@ import ( "reflect" "time" + "github.com/pkg/errors" wire "github.com/tendermint/go-wire" cmn "github.com/tendermint/tmlibs/common" ) @@ -19,10 +20,6 @@ const ( defaultEnsurePeersPeriod = 30 * time.Second minNumOutboundPeers = 10 maxPexMessageSize = 1048576 // 1MB - - // maximum pex messages one peer can send to us during `msgCountByPeerFlushInterval` - defaultMaxMsgCountByPeer = 1000 - msgCountByPeerFlushInterval = 1 * time.Hour ) // PEXReactor handles PEX (peer exchange) and ensures that an @@ -32,15 +29,8 @@ const ( // // ## Preventing abuse // -// For now, it just limits the number of messages from one peer to -// `defaultMaxMsgCountByPeer` messages per `msgCountByPeerFlushInterval` (1000 -// msg/hour). -// -// NOTE [2017-01-17]: -// Limiting is fine for now. Maybe down the road we want to keep track of the -// quality of peer messages so if peerA keeps telling us about peers we can't -// connect to then maybe we should care less about peerA. But I don't think -// that kind of complexity is priority right now. +// Only accept pexAddrsMsg from peers we sent a corresponding pexRequestMsg too. +// Only accept one pexRequestMsg every ~defaultEnsurePeersPeriod. type PEXReactor struct { BaseReactor @@ -48,9 +38,9 @@ type PEXReactor struct { config *PEXReactorConfig ensurePeersPeriod time.Duration - // tracks message count by peer, so we can prevent abuse - msgCountByPeer *cmn.CMap - maxMsgCountByPeer uint16 + // maps to prevent abuse + requestsSent *cmn.CMap // unanswered send requests + lastReceivedRequests *cmn.CMap // last time peer requested from us } // PEXReactorConfig holds reactor specific configuration data. @@ -63,11 +53,11 @@ type PEXReactorConfig struct { // NewPEXReactor creates new PEX reactor. func NewPEXReactor(b *AddrBook, config *PEXReactorConfig) *PEXReactor { r := &PEXReactor{ - book: b, - config: config, - ensurePeersPeriod: defaultEnsurePeersPeriod, - msgCountByPeer: cmn.NewCMap(), - maxMsgCountByPeer: defaultMaxMsgCountByPeer, + book: b, + config: config, + ensurePeersPeriod: defaultEnsurePeersPeriod, + requestsSent: cmn.NewCMap(), + lastReceivedRequests: cmn.NewCMap(), } r.BaseReactor = *NewBaseReactor("PEXReactor", r) return r @@ -83,7 +73,6 @@ func (r *PEXReactor) OnStart() error { return err } go r.ensurePeersRoutine() - go r.flushMsgCountByPeer() return nil } @@ -108,15 +97,17 @@ func (r *PEXReactor) GetChannels() []*ChannelDescriptor { // or by requesting more addresses (if outbound). func (r *PEXReactor) AddPeer(p Peer) { if p.IsOutbound() { - // For outbound peers, the address is already in the books. - // Either it was added in DialPeersAsync or when we - // received the peer's address in r.Receive + // For outbound peers, the address is already in the books - + // either via DialPeersAsync or r.Receive. + // Ask it for more peers if we need. if r.book.NeedMoreAddrs() { r.RequestPEX(p) } } else { - // For inbound connections, the peer is its own source, - // and its NodeInfo has already been validated + // For inbound peers, the peer is its own source, + // and its NodeInfo has already been validated. + // Let the ensurePeersRoutine handle asking for more + // peers when we need - we don't trust inbound peers as much. addr := p.NodeInfo().NetAddress() r.book.AddAddress(addr, addr) } @@ -124,20 +115,13 @@ func (r *PEXReactor) AddPeer(p Peer) { // RemovePeer implements Reactor. func (r *PEXReactor) RemovePeer(p Peer, reason interface{}) { - // If we aren't keeping track of local temp data for each peer here, then we - // don't have to do anything. + id := string(p.ID()) + r.requestsSent.Delete(id) + r.lastReceivedRequests.Delete(id) } // Receive implements Reactor by handling incoming PEX messages. func (r *PEXReactor) Receive(chID byte, src Peer, msgBytes []byte) { - srcAddr := src.NodeInfo().NetAddress() - r.IncrementMsgCountForPeer(srcAddr.ID) - if r.ReachedMaxMsgCountForPeer(srcAddr.ID) { - r.Logger.Error("Maximum number of messages reached for peer", "peer", srcAddr) - // TODO remove src from peers? - return - } - _, msg, err := DecodeMessage(msgBytes) if err != nil { r.Logger.Error("Error decoding message", "err", err) @@ -147,27 +131,81 @@ func (r *PEXReactor) Receive(chID byte, src Peer, msgBytes []byte) { switch msg := msg.(type) { case *pexRequestMessage: - // src requested some peers. - // NOTE: we might send an empty selection + // We received a request for peers from src. + if err := r.receiveRequest(src); err != nil { + r.Switch.StopPeerForError(src, err) + return + } r.SendAddrs(src, r.book.GetSelection()) case *pexAddrsMessage: // We received some peer addresses from src. - // TODO: (We don't want to get spammed with bad peers) - for _, netAddr := range msg.Addrs { - if netAddr != nil { - r.book.AddAddress(netAddr, srcAddr) - } + if err := r.ReceivePEX(msg.Addrs, src); err != nil { + r.Switch.StopPeerForError(src, err) + return } default: r.Logger.Error(fmt.Sprintf("Unknown message type %v", reflect.TypeOf(msg))) } } -// RequestPEX asks peer for more addresses. +func (r *PEXReactor) receiveRequest(src Peer) error { + id := string(src.ID()) + v := r.lastReceivedRequests.Get(id) + if v == nil { + // initialize with empty time + lastReceived := time.Time{} + r.lastReceivedRequests.Set(id, lastReceived) + return nil + } + + lastReceived := v.(time.Time) + if lastReceived.Equal(time.Time{}) { + // first time gets a free pass. then we start tracking the time + lastReceived := time.Now() + r.lastReceivedRequests.Set(id, lastReceived) + return nil + } + + now := time.Now() + if now.Sub(lastReceived) < r.ensurePeersPeriod/3 { + return fmt.Errorf("Peer (%v) is sending too many PEX requests. Disconnecting", src.ID()) + } + r.lastReceivedRequests.Set(id, now) + return nil +} + +// RequestPEX asks peer for more addresses if we do not already +// have a request out for this peer. func (r *PEXReactor) RequestPEX(p Peer) { + id := string(p.ID()) + if r.requestsSent.Has(id) { + return + } + r.requestsSent.Set(id, struct{}{}) p.Send(PexChannel, struct{ PexMessage }{&pexRequestMessage{}}) } +// ReceivePEX adds the given addrs to the addrbook if theres an open +// request for this peer and deletes the open request. +// If there's no open request for the src peer, it returns an error. +func (r *PEXReactor) ReceivePEX(addrs []*NetAddress, src Peer) error { + id := string(src.ID()) + + if !r.requestsSent.Has(id) { + return errors.New("Received unsolicited pexAddrsMessage") + } + + r.requestsSent.Delete(id) + + srcAddr := src.NodeInfo().NetAddress() + for _, netAddr := range addrs { + if netAddr != nil { + r.book.AddAddress(netAddr, srcAddr) + } + } + return nil +} + // SendAddrs sends addrs to the peer. func (r *PEXReactor) SendAddrs(p Peer, netAddrs []*NetAddress) { p.Send(PexChannel, struct{ PexMessage }{&pexAddrsMessage{Addrs: netAddrs}}) @@ -178,41 +216,13 @@ func (r *PEXReactor) SetEnsurePeersPeriod(d time.Duration) { r.ensurePeersPeriod = d } -// SetMaxMsgCountByPeer sets maximum messages one peer can send to us during 'msgCountByPeerFlushInterval'. -func (r *PEXReactor) SetMaxMsgCountByPeer(v uint16) { - r.maxMsgCountByPeer = v -} - -// ReachedMaxMsgCountForPeer returns true if we received too many -// messages from peer with address `addr`. -// NOTE: assumes the value in the CMap is non-nil -func (r *PEXReactor) ReachedMaxMsgCountForPeer(peerID ID) bool { - return r.msgCountByPeer.Get(string(peerID)).(uint16) >= r.maxMsgCountByPeer -} - -// Increment or initialize the msg count for the peer in the CMap -func (r *PEXReactor) IncrementMsgCountForPeer(peerID ID) { - var count uint16 - countI := r.msgCountByPeer.Get(string(peerID)) - if countI != nil { - count = countI.(uint16) - } - count++ - r.msgCountByPeer.Set(string(peerID), count) -} - // Ensures that sufficient peers are connected. (continuous) func (r *PEXReactor) ensurePeersRoutine() { // Randomize when routine starts ensurePeersPeriodMs := r.ensurePeersPeriod.Nanoseconds() / 1e6 time.Sleep(time.Duration(rand.Int63n(ensurePeersPeriodMs)) * time.Millisecond) - // fire once immediately. - r.ensurePeers() - - // fire periodically ticker := time.NewTicker(r.ensurePeersPeriod) - for { select { case <-ticker.C: @@ -298,20 +308,6 @@ func (r *PEXReactor) ensurePeers() { } } -func (r *PEXReactor) flushMsgCountByPeer() { - ticker := time.NewTicker(msgCountByPeerFlushInterval) - - for { - select { - case <-ticker.C: - r.msgCountByPeer.Clear() - case <-r.Quit: - ticker.Stop() - return - } - } -} - //----------------------------------------------------------------------------- // Messages diff --git a/p2p/pex_reactor_test.go b/p2p/pex_reactor_test.go index 20c8b823a..0c1c17330 100644 --- a/p2p/pex_reactor_test.go +++ b/p2p/pex_reactor_test.go @@ -153,9 +153,11 @@ func TestPEXReactorReceive(t *testing.T) { peer := createRandomPeer(false) + // we have to send a request to receive responses + r.RequestPEX(peer) + size := book.Size() - netAddr, _ := NewNetAddressString(peer.NodeInfo().ListenAddr) - addrs := []*NetAddress{netAddr} + addrs := []*NetAddress{peer.NodeInfo().NetAddress()} msg := wire.BinaryBytes(struct{ PexMessage }{&pexAddrsMessage{Addrs: addrs}}) r.Receive(PexChannel, peer, msg) assert.Equal(size+1, book.Size()) @@ -164,7 +166,7 @@ func TestPEXReactorReceive(t *testing.T) { r.Receive(PexChannel, peer, msg) } -func TestPEXReactorAbuseFromPeer(t *testing.T) { +func TestPEXReactorRequestMessageAbuse(t *testing.T) { assert, require := assert.New(t), require.New(t) dir, err := ioutil.TempDir("", "pex_reactor") @@ -174,17 +176,66 @@ func TestPEXReactorAbuseFromPeer(t *testing.T) { book.SetLogger(log.TestingLogger()) r := NewPEXReactor(book, &PEXReactorConfig{}) + sw := makeSwitch(config, 0, "127.0.0.1", "123.123.123", func(i int, sw *Switch) *Switch { return sw }) + sw.SetLogger(log.TestingLogger()) + sw.AddReactor("PEX", r) + r.SetSwitch(sw) r.SetLogger(log.TestingLogger()) - r.SetMaxMsgCountByPeer(5) - peer := createRandomPeer(false) + peer := newMockPeer() + id := string(peer.ID()) msg := wire.BinaryBytes(struct{ PexMessage }{&pexRequestMessage{}}) - for i := 0; i < 10; i++ { - r.Receive(PexChannel, peer, msg) - } - assert.True(r.ReachedMaxMsgCountForPeer(peer.NodeInfo().ID())) + // first time creates the entry + r.Receive(PexChannel, peer, msg) + assert.True(r.lastReceivedRequests.Has(id)) + + // next time sets the last time value + r.Receive(PexChannel, peer, msg) + assert.True(r.lastReceivedRequests.Has(id)) + + // third time is too many too soon - peer is removed + r.Receive(PexChannel, peer, msg) + assert.False(r.lastReceivedRequests.Has(id)) + assert.False(sw.Peers().Has(peer.ID())) + +} + +func TestPEXReactorAddrsMessageAbuse(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + dir, err := ioutil.TempDir("", "pex_reactor") + require.Nil(err) + defer os.RemoveAll(dir) // nolint: errcheck + book := NewAddrBook(dir+"addrbook.json", true) + book.SetLogger(log.TestingLogger()) + + r := NewPEXReactor(book, &PEXReactorConfig{}) + sw := makeSwitch(config, 0, "127.0.0.1", "123.123.123", func(i int, sw *Switch) *Switch { return sw }) + sw.SetLogger(log.TestingLogger()) + sw.AddReactor("PEX", r) + r.SetSwitch(sw) + r.SetLogger(log.TestingLogger()) + + peer := newMockPeer() + + id := string(peer.ID()) + + // request addrs from the peer + r.RequestPEX(peer) + assert.True(r.requestsSent.Has(id)) + + addrs := []*NetAddress{peer.NodeInfo().NetAddress()} + msg := wire.BinaryBytes(struct{ PexMessage }{&pexAddrsMessage{Addrs: addrs}}) + + // receive some addrs. should clear the request + r.Receive(PexChannel, peer, msg) + assert.False(r.requestsSent.Has(id)) + + // receiving more addrs causes a disconnect + r.Receive(PexChannel, peer, msg) + assert.False(sw.Peers().Has(peer.ID())) } func TestPEXReactorUsesSeedsIfNeeded(t *testing.T) { @@ -252,3 +303,36 @@ func createRandomPeer(outbound bool) *peer { p.SetLogger(log.TestingLogger().With("peer", addr)) return p } + +type mockPeer struct { + *cmn.BaseService + pubKey crypto.PubKey + addr *NetAddress + outbound, persistent bool +} + +func newMockPeer() mockPeer { + _, netAddr := createRoutableAddr() + mp := mockPeer{ + addr: netAddr, + pubKey: crypto.GenPrivKeyEd25519().Wrap().PubKey(), + } + mp.BaseService = cmn.NewBaseService(nil, "MockPeer", mp) + mp.Start() + return mp +} + +func (mp mockPeer) ID() ID { return PubKeyToID(mp.pubKey) } +func (mp mockPeer) IsOutbound() bool { return mp.outbound } +func (mp mockPeer) IsPersistent() bool { return mp.persistent } +func (mp mockPeer) NodeInfo() NodeInfo { + return NodeInfo{ + PubKey: mp.pubKey, + ListenAddr: mp.addr.DialString(), + } +} +func (mp mockPeer) Status() ConnectionStatus { return ConnectionStatus{} } +func (mp mockPeer) Send(byte, interface{}) bool { return false } +func (mp mockPeer) TrySend(byte, interface{}) bool { return false } +func (mp mockPeer) Set(string, interface{}) {} +func (mp mockPeer) Get(string) interface{} { return nil }