diff --git a/blocksync/msgs.go b/blocksync/msgs.go index e3d6e551c..fe1c87d7f 100644 --- a/blocksync/msgs.go +++ b/blocksync/msgs.go @@ -19,8 +19,7 @@ const ( BlockResponseMessageFieldKeySize ) -// EncodeMsg encodes a Protobuf message -func EncodeMsg(pb proto.Message) ([]byte, error) { +func wrapMsg(pb proto.Message) (proto.Message, error) { msg := bcproto.Message{} switch pb := pb.(type) { @@ -38,12 +37,7 @@ func EncodeMsg(pb proto.Message) ([]byte, error) { return nil, fmt.Errorf("unknown message type %T", pb) } - bz, err := proto.Marshal(&msg) - if err != nil { - return nil, fmt.Errorf("unable to marshal %T: %w", pb, err) - } - - return bz, nil + return &msg, nil } // DecodeMsg decodes a Protobuf message. @@ -54,7 +48,10 @@ func DecodeMsg(bz []byte) (proto.Message, error) { if err != nil { return nil, err } + return UnwrapMessage(pb) +} +func UnwrapMessage(pb *bcproto.Message) (proto.Message, error) { switch msg := pb.Sum.(type) { case *bcproto.Message_BlockRequest: return msg.BlockRequest, nil diff --git a/blocksync/reactor.go b/blocksync/reactor.go index 09dd2ef90..9bac4229b 100644 --- a/blocksync/reactor.go +++ b/blocksync/reactor.go @@ -143,21 +143,26 @@ func (bcR *Reactor) GetChannels() []*p2p.ChannelDescriptor { SendQueueCapacity: 1000, RecvBufferCapacity: 50 * 4096, RecvMessageCapacity: MaxMsgSize, + MessageType: &bcproto.Message{}, }, } } // AddPeer implements Reactor by sending our state to peer. func (bcR *Reactor) AddPeer(peer p2p.Peer) { - msgBytes, err := EncodeMsg(&bcproto.StatusResponse{ + msg, err := wrapMsg(&bcproto.StatusResponse{ Base: bcR.store.Base(), - Height: bcR.store.Height()}) + Height: bcR.store.Height(), + }) if err != nil { bcR.Logger.Error("could not convert msg to protobuf", "err", err) return } - peer.Send(BlocksyncChannel, msgBytes) + peer.Send(p2p.Envelope{ + ChannelID: BlocksyncChannel, + Message: msg, + }) // it's OK if send fails. will try later in poolRoutine // peer is added to the pool once we receive the first @@ -182,69 +187,79 @@ func (bcR *Reactor) respondToPeer(msg *bcproto.BlockRequest, return false } - msgBytes, err := EncodeMsg(&bcproto.BlockResponse{Block: bl}) + wm, err := wrapMsg(&bcproto.BlockResponse{Block: bl}) if err != nil { - bcR.Logger.Error("could not marshal msg", "err", err) + bcR.Logger.Error("could not convert msg to proto message", "err", err) return false } - return src.TrySend(BlocksyncChannel, msgBytes) + return src.TrySend(p2p.Envelope{ + ChannelID: BlocksyncChannel, + Message: wm, + }) } bcR.Logger.Info("Peer asking for a block we don't have", "src", src, "height", msg.Height) - msgBytes, err := EncodeMsg(&bcproto.NoBlockResponse{Height: msg.Height}) + wm, err := wrapMsg(&bcproto.NoBlockResponse{Height: msg.Height}) if err != nil { bcR.Logger.Error("could not convert msg to protobuf", "err", err) return false } - return src.TrySend(BlocksyncChannel, msgBytes) + return src.TrySend(p2p.Envelope{ + ChannelID: BlocksyncChannel, + Message: wm, + }) } // Receive implements Reactor by handling 4 types of messages (look below). -func (bcR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { - msg, err := DecodeMsg(msgBytes) +func (bcR *Reactor) Receive(e p2p.Envelope) { + msg, err := UnwrapMessage(e.Message.(*bcproto.Message)) if err != nil { - bcR.Logger.Error("Error decoding message", "src", src, "chId", chID, "err", err) - bcR.Switch.StopPeerForError(src, err) + bcR.Logger.Error("Error decoding message", "src", e.Src, "chId", e.ChannelID, "err", err) + bcR.Switch.StopPeerForError(e.Src, err) return } if err = ValidateMsg(msg); err != nil { - bcR.Logger.Error("Peer sent us invalid msg", "peer", src, "msg", msg, "err", err) - bcR.Switch.StopPeerForError(src, err) + bcR.Logger.Error("Peer sent us invalid msg", "peer", e.Src, "msg", msg, "err", err) + bcR.Switch.StopPeerForError(e.Src, err) return } - bcR.Logger.Debug("Receive", "src", src, "chID", chID, "msg", msg) + bcR.Logger.Debug("Receive", "e.Src", e.Src, "chID", e.ChannelID, "msg", msg) switch msg := msg.(type) { case *bcproto.BlockRequest: - bcR.respondToPeer(msg, src) + bcR.respondToPeer(msg, e.Src) case *bcproto.BlockResponse: bi, err := types.BlockFromProto(msg.Block) if err != nil { bcR.Logger.Error("Block content is invalid", "err", err) return } - bcR.pool.AddBlock(src.ID(), bi, len(msgBytes)) + bcR.pool.AddBlock(e.Src.ID(), bi, msg.Block.Size()) case *bcproto.StatusRequest: // Send peer our state. - msgBytes, err := EncodeMsg(&bcproto.StatusResponse{ + wm, err := wrapMsg(&bcproto.StatusResponse{ Height: bcR.store.Height(), Base: bcR.store.Base(), }) if err != nil { - bcR.Logger.Error("could not convert msg to protobut", "err", err) + bcR.Logger.Error("could not convert msg to proto message", "err", err) return } - src.TrySend(BlocksyncChannel, msgBytes) + + e.Src.TrySend(p2p.Envelope{ + ChannelID: BlocksyncChannel, + Message: wm, + }) case *bcproto.StatusResponse: // Got a peer status. Unverified. - bcR.pool.SetPeerRange(src.ID(), msg.Base, msg.Height) + bcR.pool.SetPeerRange(e.Src.ID(), msg.Base, msg.Height) case *bcproto.NoBlockResponse: - bcR.Logger.Debug("Peer does not have requested block", "peer", src, "height", msg.Height) + bcR.Logger.Debug("Peer does not have requested block", "peer", e.Src, "height", msg.Height) default: bcR.Logger.Error(fmt.Sprintf("Unknown message type %v", reflect.TypeOf(msg))) } @@ -285,13 +300,15 @@ func (bcR *Reactor) poolRoutine(stateSynced bool) { if peer == nil { continue } - msgBytes, err := EncodeMsg(&bcproto.BlockRequest{Height: request.Height}) + wm, err := wrapMsg(&bcproto.BlockRequest{Height: request.Height}) if err != nil { bcR.Logger.Error("could not convert msg to proto", "err", err) continue } - - queued := peer.TrySend(BlocksyncChannel, msgBytes) + queued := peer.TrySend(p2p.Envelope{ + ChannelID: BlocksyncChannel, + Message: wm, + }) if !queued { bcR.Logger.Debug("Send queue is full, drop block request", "peer", peer.ID(), "height", request.Height) } @@ -430,13 +447,15 @@ FOR_LOOP: // BroadcastStatusRequest broadcasts `BlockStore` base and height. func (bcR *Reactor) BroadcastStatusRequest() error { - bm, err := EncodeMsg(&bcproto.StatusRequest{}) + wm, err := wrapMsg(&bcproto.StatusRequest{}) if err != nil { - bcR.Logger.Error("could not convert msg to proto", "err", err) - return fmt.Errorf("could not convert msg to proto: %w", err) + bcR.Logger.Error("could not convert msg to proto message", "err", err) + return fmt.Errorf("could not convert msg to proto message: %w", err) } - - bcR.Switch.Broadcast(BlocksyncChannel, bm) + bcR.Switch.NewBroadcast(p2p.Envelope{ + ChannelID: BlocksyncChannel, + Message: wm, + }) return nil } diff --git a/consensus/byzantine_test.go b/consensus/byzantine_test.go index fe0c36a14..92ac8889e 100644 --- a/consensus/byzantine_test.go +++ b/consensus/byzantine_test.go @@ -165,10 +165,16 @@ func TestByzantinePrevoteEquivocation(t *testing.T) { for i, peer := range peerList { if i < len(peerList)/2 { bcs.Logger.Info("Signed and pushed vote", "vote", prevote1, "peer", peer) - peer.Send(VoteChannel, MustEncode(&VoteMessage{prevote1})) + peer.Send(p2p.Envelope{ + Message: MustMsgToProto(&VoteMessage{prevote1}), + ChannelID: VoteChannel, + }) } else { bcs.Logger.Info("Signed and pushed vote", "vote", prevote2, "peer", peer) - peer.Send(VoteChannel, MustEncode(&VoteMessage{prevote2})) + peer.Send(p2p.Envelope{ + Message: MustMsgToProto(&VoteMessage{prevote2}), + ChannelID: VoteChannel, + }) } } } else { @@ -521,7 +527,10 @@ func sendProposalAndParts( ) { // proposal msg := &ProposalMessage{Proposal: proposal} - peer.Send(DataChannel, MustEncode(msg)) + peer.Send(p2p.Envelope{ + ChannelID: DataChannel, + Message: MustMsgToProto(msg), + }) // parts for i := 0; i < int(parts.Total()); i++ { @@ -531,7 +540,10 @@ func sendProposalAndParts( Round: round, // This tells peer that this part applies to us. Part: part, } - peer.Send(DataChannel, MustEncode(msg)) + peer.Send(p2p.Envelope{ + ChannelID: DataChannel, + Message: MustMsgToProto(msg), + }) } // votes @@ -539,9 +551,14 @@ func sendProposalAndParts( prevote, _ := cs.signVote(tmproto.PrevoteType, blockHash, parts.Header()) precommit, _ := cs.signVote(tmproto.PrecommitType, blockHash, parts.Header()) cs.mtx.Unlock() - - peer.Send(VoteChannel, MustEncode(&VoteMessage{prevote})) - peer.Send(VoteChannel, MustEncode(&VoteMessage{precommit})) + peer.Send(p2p.Envelope{ + ChannelID: VoteChannel, + Message: MustMsgToProto(&VoteMessage{prevote}), + }) + peer.Send(p2p.Envelope{ + ChannelID: VoteChannel, + Message: MustMsgToProto(&VoteMessage{precommit}), + }) } //---------------------------------------- @@ -579,7 +596,7 @@ func (br *ByzantineReactor) AddPeer(peer p2p.Peer) { func (br *ByzantineReactor) RemovePeer(peer p2p.Peer, reason interface{}) { br.reactor.RemovePeer(peer, reason) } -func (br *ByzantineReactor) Receive(chID byte, peer p2p.Peer, msgBytes []byte) { - br.reactor.Receive(chID, peer, msgBytes) +func (br *ByzantineReactor) Receive(e p2p.Envelope) { + br.reactor.Receive(e) } func (br *ByzantineReactor) InitPeer(peer p2p.Peer) p2p.Peer { return peer } diff --git a/consensus/invalid_test.go b/consensus/invalid_test.go index f96018157..8fef4e8b2 100644 --- a/consensus/invalid_test.go +++ b/consensus/invalid_test.go @@ -94,7 +94,10 @@ func invalidDoPrevoteFunc(t *testing.T, height int64, round int32, cs *State, sw peers := sw.Peers().List() for _, peer := range peers { cs.Logger.Info("Sending bad vote", "block", blockHash, "peer", peer) - peer.Send(VoteChannel, MustEncode(&VoteMessage{precommit})) + peer.Send(p2p.Envelope{ + Message: MustMsgToProto(&VoteMessage{precommit}), + ChannelID: VoteChannel, + }) } }() } diff --git a/consensus/msgs.go b/consensus/msgs.go index 5d22905cd..6aaacba83 100644 --- a/consensus/msgs.go +++ b/consensus/msgs.go @@ -15,7 +15,7 @@ import ( "github.com/tendermint/tendermint/types" ) -// MsgToProto takes a consensus message type and returns the proto defined consensus message +// MsgToProto takes a consensus message type and returns the proto defined consensus message. func MsgToProto(msg Message) (*tmcons.Message, error) { if msg == nil { return nil, errors.New("consensus: message is nil") @@ -143,6 +143,14 @@ func MsgToProto(msg Message) (*tmcons.Message, error) { return &pb, nil } +func MustMsgToProto(msg Message) *tmcons.Message { + m, err := MsgToProto(msg) + if err != nil { + panic(err) + } + return m +} + // MsgFromProto takes a consensus proto message and returns the native go type func MsgFromProto(msg *tmcons.Message) (Message, error) { if msg == nil { diff --git a/consensus/reactor.go b/consensus/reactor.go index b0d3e3675..d2b1dad6b 100644 --- a/consensus/reactor.go +++ b/consensus/reactor.go @@ -148,6 +148,7 @@ func (conR *Reactor) GetChannels() []*p2p.ChannelDescriptor { Priority: 6, SendQueueCapacity: 100, RecvMessageCapacity: maxMsgSize, + MessageType: &tmcons.Message{}, }, { ID: DataChannel, // maybe split between gossiping current block and catchup stuff @@ -156,6 +157,7 @@ func (conR *Reactor) GetChannels() []*p2p.ChannelDescriptor { SendQueueCapacity: 100, RecvBufferCapacity: 50 * 4096, RecvMessageCapacity: maxMsgSize, + MessageType: &tmcons.Message{}, }, { ID: VoteChannel, @@ -163,6 +165,7 @@ func (conR *Reactor) GetChannels() []*p2p.ChannelDescriptor { SendQueueCapacity: 100, RecvBufferCapacity: 100 * 100, RecvMessageCapacity: maxMsgSize, + MessageType: &tmcons.Message{}, }, { ID: VoteSetBitsChannel, @@ -170,6 +173,7 @@ func (conR *Reactor) GetChannels() []*p2p.ChannelDescriptor { SendQueueCapacity: 2, RecvBufferCapacity: 1024, RecvMessageCapacity: maxMsgSize, + MessageType: &tmcons.Message{}, }, } } @@ -223,34 +227,34 @@ func (conR *Reactor) RemovePeer(peer p2p.Peer, reason interface{}) { // Peer state updates can happen in parallel, but processing of // proposals, block parts, and votes are ordered by the receiveRoutine // NOTE: blocks on consensus state for proposals, block parts, and votes -func (conR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { +func (conR *Reactor) Receive(e p2p.Envelope) { if !conR.IsRunning() { - conR.Logger.Debug("Receive", "src", src, "chId", chID, "bytes", msgBytes) + conR.Logger.Debug("Receive", "src", e.Src, "chId", e.ChannelID) return } - msg, err := decodeMsg(msgBytes) + msg, err := MsgFromProto(e.Message.(*tmcons.Message)) if err != nil { - conR.Logger.Error("Error decoding message", "src", src, "chId", chID, "err", err) - conR.Switch.StopPeerForError(src, err) + conR.Logger.Error("Error decoding message", "src", e.Src, "chId", e.ChannelID, "err", err) + conR.Switch.StopPeerForError(e.Src, err) return } if err = msg.ValidateBasic(); err != nil { - conR.Logger.Error("Peer sent us invalid msg", "peer", src, "msg", msg, "err", err) - conR.Switch.StopPeerForError(src, err) + conR.Logger.Error("Peer sent us invalid msg", "peer", e.Src, "msg", e.Message, "err", err) + conR.Switch.StopPeerForError(e.Src, err) return } - conR.Logger.Debug("Receive", "src", src, "chId", chID, "msg", msg) + conR.Logger.Debug("Receive", "src", e.Src, "chId", e.ChannelID, "msg", msg) // Get peer states - ps, ok := src.Get(types.PeerStateKey).(*PeerState) + ps, ok := e.Src.Get(types.PeerStateKey).(*PeerState) if !ok { - panic(fmt.Sprintf("Peer %v has no state", src)) + panic(fmt.Sprintf("Peer %v has no state", e.Src)) } - switch chID { + switch e.ChannelID { case StateChannel: switch msg := msg.(type) { case *NewRoundStepMessage: @@ -258,8 +262,8 @@ func (conR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { initialHeight := conR.conS.state.InitialHeight conR.conS.mtx.Unlock() if err = msg.ValidateHeight(initialHeight); err != nil { - conR.Logger.Error("Peer sent us invalid msg", "peer", src, "msg", msg, "err", err) - conR.Switch.StopPeerForError(src, err) + conR.Logger.Error("Peer sent us invalid msg", "peer", e.Src, "msg", msg, "err", err) + conR.Switch.StopPeerForError(e.Src, err) return } ps.ApplyNewRoundStepMessage(msg) @@ -278,7 +282,7 @@ func (conR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { // Peer claims to have a maj23 for some BlockID at H,R,S, err := votes.SetPeerMaj23(msg.Round, msg.Type, ps.peer.ID(), msg.BlockID) if err != nil { - conR.Switch.StopPeerForError(src, err) + conR.Switch.StopPeerForError(e.Src, err) return } // Respond with a VoteSetBitsMessage showing which votes we have. @@ -292,13 +296,16 @@ func (conR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { default: panic("Bad VoteSetBitsMessage field Type. Forgot to add a check in ValidateBasic?") } - src.TrySend(VoteSetBitsChannel, MustEncode(&VoteSetBitsMessage{ - Height: msg.Height, - Round: msg.Round, - Type: msg.Type, - BlockID: msg.BlockID, - Votes: ourVotes, - })) + e.Src.TrySend(p2p.Envelope{ + ChannelID: VoteSetBitsChannel, + Message: MustMsgToProto(&VoteSetBitsMessage{ + Height: msg.Height, + Round: msg.Round, + Type: msg.Type, + BlockID: msg.BlockID, + Votes: ourVotes, + }), + }) default: conR.Logger.Error(fmt.Sprintf("Unknown message type %v", reflect.TypeOf(msg))) } @@ -311,13 +318,13 @@ func (conR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { switch msg := msg.(type) { case *ProposalMessage: ps.SetHasProposal(msg.Proposal) - conR.conS.peerMsgQueue <- msgInfo{msg, src.ID()} + conR.conS.peerMsgQueue <- msgInfo{msg, e.Src.ID()} case *ProposalPOLMessage: ps.ApplyProposalPOLMessage(msg) case *BlockPartMessage: ps.SetHasProposalBlockPart(msg.Height, msg.Round, int(msg.Part.Index)) - conR.Metrics.BlockParts.With("peer_id", string(src.ID())).Add(1) - conR.conS.peerMsgQueue <- msgInfo{msg, src.ID()} + conR.Metrics.BlockParts.With("peer_id", string(e.Src.ID())).Add(1) + conR.conS.peerMsgQueue <- msgInfo{msg, e.Src.ID()} default: conR.Logger.Error(fmt.Sprintf("Unknown message type %v", reflect.TypeOf(msg))) } @@ -337,7 +344,7 @@ func (conR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { ps.EnsureVoteBitArrays(height-1, lastCommitSize) ps.SetHasVote(msg.Vote) - cs.peerMsgQueue <- msgInfo{msg, src.ID()} + cs.peerMsgQueue <- msgInfo{msg, e.Src.ID()} default: // don't punish (leave room for soft upgrades) @@ -376,7 +383,7 @@ func (conR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { } default: - conR.Logger.Error(fmt.Sprintf("Unknown chId %X", chID)) + conR.Logger.Error(fmt.Sprintf("Unknown chId %X", e.ChannelID)) } } @@ -430,7 +437,10 @@ func (conR *Reactor) unsubscribeFromBroadcastEvents() { func (conR *Reactor) broadcastNewRoundStepMessage(rs *cstypes.RoundState) { nrsMsg := makeRoundStepMessage(rs) - conR.Switch.Broadcast(StateChannel, MustEncode(nrsMsg)) + conR.Switch.NewBroadcast(p2p.Envelope{ + ChannelID: StateChannel, + Message: MustMsgToProto(nrsMsg), + }) } func (conR *Reactor) broadcastNewValidBlockMessage(rs *cstypes.RoundState) { @@ -441,7 +451,11 @@ func (conR *Reactor) broadcastNewValidBlockMessage(rs *cstypes.RoundState) { BlockParts: rs.ProposalBlockParts.BitArray(), IsCommit: rs.Step == cstypes.RoundStepCommit, } - conR.Switch.Broadcast(StateChannel, MustEncode(csMsg)) + MsgToProto(csMsg) + conR.Switch.NewBroadcast(p2p.Envelope{ + ChannelID: StateChannel, + Message: MustMsgToProto(csMsg), + }) } // Broadcasts HasVoteMessage to peers that care. @@ -452,7 +466,10 @@ func (conR *Reactor) broadcastHasVoteMessage(vote *types.Vote) { Type: vote.Type, Index: vote.ValidatorIndex, } - conR.Switch.Broadcast(StateChannel, MustEncode(msg)) + conR.Switch.NewBroadcast(p2p.Envelope{ + ChannelID: StateChannel, + Message: MustMsgToProto(msg), + }) /* // TODO: Make this broadcast more selective. for _, peer := range conR.Switch.Peers().List() { @@ -463,7 +480,11 @@ func (conR *Reactor) broadcastHasVoteMessage(vote *types.Vote) { prs := ps.GetRoundState() if prs.Height == vote.Height { // TODO: Also filter on round? - peer.TrySend(StateChannel, struct{ ConsensusMessage }{msg}) + e := p2p.Envelope{ + ChannelID: StateChannel, struct{ ConsensusMessage }{msg}, + Message: p, + } + peer.TrySend(e) } else { // Height doesn't match // TODO: check a field, maybe CatchupCommitRound? @@ -487,7 +508,10 @@ func makeRoundStepMessage(rs *cstypes.RoundState) (nrsMsg *NewRoundStepMessage) func (conR *Reactor) sendNewRoundStepMessage(peer p2p.Peer) { rs := conR.getRoundState() nrsMsg := makeRoundStepMessage(rs) - peer.Send(StateChannel, MustEncode(nrsMsg)) + peer.Send(p2p.Envelope{ + ChannelID: StateChannel, + Message: MustMsgToProto(nrsMsg), + }) } func (conR *Reactor) updateRoundStateRoutine() { @@ -532,7 +556,10 @@ OUTER_LOOP: Part: part, } logger.Debug("Sending block part", "height", prs.Height, "round", prs.Round) - if peer.Send(DataChannel, MustEncode(msg)) { + if peer.Send(p2p.Envelope{ + ChannelID: DataChannel, + Message: MustMsgToProto(msg), + }) { ps.SetHasProposalBlockPart(prs.Height, prs.Round, index) } continue OUTER_LOOP @@ -580,7 +607,10 @@ OUTER_LOOP: { msg := &ProposalMessage{Proposal: rs.Proposal} logger.Debug("Sending proposal", "height", prs.Height, "round", prs.Round) - if peer.Send(DataChannel, MustEncode(msg)) { + if peer.Send(p2p.Envelope{ + ChannelID: DataChannel, + Message: MustMsgToProto(msg), + }) { // NOTE[ZM]: A peer might have received different proposal msg so this Proposal msg will be rejected! ps.SetHasProposal(rs.Proposal) } @@ -596,7 +626,10 @@ OUTER_LOOP: ProposalPOL: rs.Votes.Prevotes(rs.Proposal.POLRound).BitArray(), } logger.Debug("Sending POL", "height", prs.Height, "round", prs.Round) - peer.Send(DataChannel, MustEncode(msg)) + peer.Send(p2p.Envelope{ + ChannelID: DataChannel, + Message: MustMsgToProto(msg), + }) } continue OUTER_LOOP } @@ -639,7 +672,10 @@ func (conR *Reactor) gossipDataForCatchup(logger log.Logger, rs *cstypes.RoundSt Part: part, } logger.Debug("Sending block part for catchup", "round", prs.Round, "index", index) - if peer.Send(DataChannel, MustEncode(msg)) { + if peer.Send(p2p.Envelope{ + ChannelID: DataChannel, + Message: MustMsgToProto(msg), + }) { ps.SetHasProposalBlockPart(prs.Height, prs.Round, index) } else { logger.Debug("Sending block part for catchup failed") @@ -798,12 +834,16 @@ OUTER_LOOP: prs := ps.GetRoundState() if rs.Height == prs.Height { if maj23, ok := rs.Votes.Prevotes(prs.Round).TwoThirdsMajority(); ok { - peer.TrySend(StateChannel, MustEncode(&VoteSetMaj23Message{ - Height: prs.Height, - Round: prs.Round, - Type: tmproto.PrevoteType, - BlockID: maj23, - })) + + peer.TrySend(p2p.Envelope{ + ChannelID: StateChannel, + Message: MustMsgToProto(&VoteSetMaj23Message{ + Height: prs.Height, + Round: prs.Round, + Type: tmproto.PrevoteType, + BlockID: maj23, + }), + }) time.Sleep(conR.conS.config.PeerQueryMaj23SleepDuration) } } @@ -815,12 +855,16 @@ OUTER_LOOP: prs := ps.GetRoundState() if rs.Height == prs.Height { if maj23, ok := rs.Votes.Precommits(prs.Round).TwoThirdsMajority(); ok { - peer.TrySend(StateChannel, MustEncode(&VoteSetMaj23Message{ - Height: prs.Height, - Round: prs.Round, - Type: tmproto.PrecommitType, - BlockID: maj23, - })) + + peer.TrySend(p2p.Envelope{ + ChannelID: StateChannel, + Message: MustMsgToProto(&VoteSetMaj23Message{ + Height: prs.Height, + Round: prs.Round, + Type: tmproto.PrecommitType, + BlockID: maj23, + }), + }) time.Sleep(conR.conS.config.PeerQueryMaj23SleepDuration) } } @@ -832,12 +876,16 @@ OUTER_LOOP: prs := ps.GetRoundState() if rs.Height == prs.Height && prs.ProposalPOLRound >= 0 { if maj23, ok := rs.Votes.Prevotes(prs.ProposalPOLRound).TwoThirdsMajority(); ok { - peer.TrySend(StateChannel, MustEncode(&VoteSetMaj23Message{ - Height: prs.Height, - Round: prs.ProposalPOLRound, - Type: tmproto.PrevoteType, - BlockID: maj23, - })) + + peer.TrySend(p2p.Envelope{ + ChannelID: StateChannel, + Message: MustMsgToProto(&VoteSetMaj23Message{ + Height: prs.Height, + Round: prs.ProposalPOLRound, + Type: tmproto.PrevoteType, + BlockID: maj23, + }), + }) time.Sleep(conR.conS.config.PeerQueryMaj23SleepDuration) } } @@ -852,12 +900,15 @@ OUTER_LOOP: if prs.CatchupCommitRound != -1 && prs.Height > 0 && prs.Height <= conR.conS.blockStore.Height() && prs.Height >= conR.conS.blockStore.Base() { if commit := conR.conS.LoadCommit(prs.Height); commit != nil { - peer.TrySend(StateChannel, MustEncode(&VoteSetMaj23Message{ - Height: prs.Height, - Round: commit.Round, - Type: tmproto.PrecommitType, - BlockID: commit.BlockID, - })) + peer.TrySend(p2p.Envelope{ + ChannelID: StateChannel, + Message: MustMsgToProto(&VoteSetMaj23Message{ + Height: prs.Height, + Round: commit.Round, + Type: tmproto.PrecommitType, + BlockID: commit.BlockID, + }), + }) time.Sleep(conR.conS.config.PeerQueryMaj23SleepDuration) } } @@ -1073,7 +1124,10 @@ func (ps *PeerState) PickSendVote(votes types.VoteSetReader) bool { if vote, ok := ps.PickVoteToSend(votes); ok { msg := &VoteMessage{vote} ps.logger.Debug("Sending vote message", "ps", ps, "vote", vote) - if ps.peer.Send(VoteChannel, MustEncode(msg)) { + if ps.peer.Send(p2p.Envelope{ + ChannelID: VoteChannel, + Message: MustMsgToProto(msg), + }) { ps.SetHasVote(vote) return true } diff --git a/consensus/reactor_test.go b/consensus/reactor_test.go index 303f5e6e2..3d52e95fd 100644 --- a/consensus/reactor_test.go +++ b/consensus/reactor_test.go @@ -297,6 +297,12 @@ func TestReactorReceivePanicsIfInitPeerHasntBeenCalledYet(t *testing.T) { // simulate switch calling Receive before AddPeer assert.Panics(t, func() { reactor.Receive(StateChannel, peer, msg) + reactor.NewReceive(p2p.Envelope{ + ChannelID: StateChannel, + Src: peer, + Message: MustMsgToProto(&HasVoteMessage{Height: 1, + Round: 1, Index: 1, Type: tmproto.PrevoteType}), + }) }) } diff --git a/evidence/reactor.go b/evidence/reactor.go index 2a136dbfb..c82a713d3 100644 --- a/evidence/reactor.go +++ b/evidence/reactor.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/cosmos/gogoproto/proto" clist "github.com/tendermint/tendermint/libs/clist" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/p2p" @@ -55,6 +56,7 @@ func (evR *Reactor) GetChannels() []*p2p.ChannelDescriptor { ID: EvidenceChannel, Priority: 6, RecvMessageCapacity: maxMsgSize, + MessageType: &tmproto.EvidenceList{}, }, } } @@ -66,11 +68,11 @@ func (evR *Reactor) AddPeer(peer p2p.Peer) { // Receive implements Reactor. // It adds any received evidence to the evpool. -func (evR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { - evis, err := decodeMsg(msgBytes) +func (evR *Reactor) Receive(e p2p.Envelope) { + evis, err := evidenceListFromProto(e.Message) if err != nil { - evR.Logger.Error("Error decoding message", "src", src, "chId", chID, "err", err) - evR.Switch.StopPeerForError(src, err) + evR.Logger.Error("Error decoding message", "src", e.Src, "chId", e.ChannelID, "err", err) + evR.Switch.StopPeerForError(e.Src, err) return } @@ -80,7 +82,7 @@ func (evR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { case *types.ErrInvalidEvidence: evR.Logger.Error(err.Error()) // punish peer - evR.Switch.StopPeerForError(src, err) + evR.Switch.StopPeerForError(e.Src, err) return case nil: default: @@ -126,11 +128,15 @@ func (evR *Reactor) broadcastEvidenceRoutine(peer p2p.Peer) { evis := evR.prepareEvidenceMessage(peer, ev) if len(evis) > 0 { evR.Logger.Debug("Gossiping evidence to peer", "ev", ev, "peer", peer) - msgBytes, err := encodeMsg(evis) + evp, err := evidenceListToProto(evis) if err != nil { panic(err) } - success := peer.Send(EvidenceChannel, msgBytes) + + success := peer.Send(p2p.Envelope{ + ChannelID: EvidenceChannel, + Message: evp, + }) if !success { time.Sleep(peerRetryMessageIntervalMS * time.Millisecond) continue @@ -226,6 +232,23 @@ func encodeMsg(evis []types.Evidence) ([]byte, error) { return epl.Marshal() } +// encodemsg takes a array of evidence +// returns the byte encoding of the List Message +func evidenceListToProto(evis []types.Evidence) (*tmproto.EvidenceList, error) { + evi := make([]tmproto.Evidence, len(evis)) + for i := 0; i < len(evis); i++ { + ev, err := types.EvidenceToProto(evis[i]) + if err != nil { + return nil, err + } + evi[i] = *ev + } + epl := tmproto.EvidenceList{ + Evidence: evi, + } + return &epl, nil +} + // decodemsg takes an array of bytes // returns an array of evidence func decodeMsg(bz []byte) (evis []types.Evidence, err error) { @@ -251,3 +274,23 @@ func decodeMsg(bz []byte) (evis []types.Evidence, err error) { return evis, nil } +func evidenceListFromProto(m proto.Message) ([]types.Evidence, error) { + lm := m.(*tmproto.EvidenceList) + + evis := make([]types.Evidence, len(lm.Evidence)) + for i := 0; i < len(lm.Evidence); i++ { + ev, err := types.EvidenceFromProto(&lm.Evidence[i]) + if err != nil { + return nil, err + } + evis[i] = ev + } + + for i, ev := range evis { + if err := ev.ValidateBasic(); err != nil { + return nil, fmt.Errorf("invalid evidence (#%d): %v", i, err) + } + } + + return evis, nil +} diff --git a/evidence/reactor_test.go b/evidence/reactor_test.go index a2d82bf71..0d7d1110d 100644 --- a/evidence/reactor_test.go +++ b/evidence/reactor_test.go @@ -208,7 +208,10 @@ func TestReactorBroadcastEvidenceMemoryLeak(t *testing.T) { // i.e. broadcastEvidenceRoutine finishes when peer is stopped defer leaktest.CheckTimeout(t, 10*time.Second)() - p.On("Send", evidence.EvidenceChannel, mock.AnythingOfType("[]uint8")).Return(false) + p.On("Send", mock.MatchedBy(func(i interface{}) bool { + e, ok := i.(p2p.Envelope) + return ok && e.ChannelID == evidence.EvidenceChannel + })).Return(false) quitChan := make(<-chan struct{}) p.On("Quit").Return(quitChan) ps := peerState{2} diff --git a/mempool/v0/reactor.go b/mempool/v0/reactor.go index 3fc850641..07d3cd1b6 100644 --- a/mempool/v0/reactor.go +++ b/mempool/v0/reactor.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/cosmos/gogoproto/proto" cfg "github.com/tendermint/tendermint/config" "github.com/tendermint/tendermint/libs/clist" "github.com/tendermint/tendermint/libs/log" @@ -134,6 +135,7 @@ func (memR *Reactor) GetChannels() []*p2p.ChannelDescriptor { ID: mempool.MempoolChannel, Priority: 5, RecvMessageCapacity: batchMsg.Size(), + MessageType: &protomem.Message{}, }, } } @@ -154,18 +156,18 @@ func (memR *Reactor) RemovePeer(peer p2p.Peer, reason interface{}) { // Receive implements Reactor. // It adds any received transactions to the mempool. -func (memR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { - msg, err := memR.decodeMsg(msgBytes) +func (memR *Reactor) Receive(e p2p.Envelope) { + msg, err := msgFromProto(e.Message) if err != nil { - memR.Logger.Error("Error decoding message", "src", src, "chId", chID, "err", err) - memR.Switch.StopPeerForError(src, err) + memR.Logger.Error("Error decoding message", "src", e.Src, "chId", e.ChannelID, "err", err) + memR.Switch.StopPeerForError(e.Src, err) return } - memR.Logger.Debug("Receive", "src", src, "chId", chID, "msg", msg) + memR.Logger.Debug("Receive", "src", e.Src, "chId", e.ChannelID, "msg", msg) - txInfo := mempool.TxInfo{SenderID: memR.ids.GetForPeer(src)} - if src != nil { - txInfo.SenderP2PID = src.ID() + txInfo := mempool.TxInfo{SenderID: memR.ids.GetForPeer(e.Src)} + if e.Src != nil { + txInfo.SenderP2PID = e.Src.ID() } for _, tx := range msg.Txs { @@ -234,18 +236,14 @@ func (memR *Reactor) broadcastTxRoutine(peer p2p.Peer) { // https://github.com/tendermint/tendermint/issues/5796 if _, ok := memTx.senders.Load(peerID); !ok { - msg := protomem.Message{ - Sum: &protomem.Message_Txs{ - Txs: &protomem.Txs{Txs: [][]byte{memTx.tx}}, + success := peer.Send(p2p.Envelope{ + ChannelID: mempool.MempoolChannel, + Message: &protomem.Message{ + Sum: &protomem.Message_Txs{ + Txs: &protomem.Txs{Txs: [][]byte{memTx.tx}}, + }, }, - } - - bz, err := msg.Marshal() - if err != nil { - panic(err) - } - - success := peer.Send(mempool.MempoolChannel, bz) + }) if !success { time.Sleep(mempool.PeerCatchupSleepIntervalMS * time.Millisecond) continue @@ -264,15 +262,19 @@ func (memR *Reactor) broadcastTxRoutine(peer p2p.Peer) { } } -func (memR *Reactor) decodeMsg(bz []byte) (TxsMessage, error) { +func decodeMsg(bz []byte) (TxsMessage, error) { msg := protomem.Message{} err := msg.Unmarshal(bz) if err != nil { return TxsMessage{}, err } - var message TxsMessage + return msgFromProto(&msg) +} +func msgFromProto(m proto.Message) (TxsMessage, error) { + msg := m.(*protomem.Message) + var message TxsMessage if i, ok := msg.Sum.(*protomem.Message_Txs); ok { txs := i.Txs.GetTxs() diff --git a/mempool/v1/reactor.go b/mempool/v1/reactor.go index 4da51bab8..e2f82ab84 100644 --- a/mempool/v1/reactor.go +++ b/mempool/v1/reactor.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "github.com/cosmos/gogoproto/proto" + cfg "github.com/tendermint/tendermint/config" "github.com/tendermint/tendermint/libs/clist" "github.com/tendermint/tendermint/libs/log" @@ -133,6 +135,7 @@ func (memR *Reactor) GetChannels() []*p2p.ChannelDescriptor { ID: mempool.MempoolChannel, Priority: 5, RecvMessageCapacity: batchMsg.Size(), + MessageType: &protomem.Message{}, }, } } @@ -153,18 +156,18 @@ func (memR *Reactor) RemovePeer(peer p2p.Peer, reason interface{}) { // Receive implements Reactor. // It adds any received transactions to the mempool. -func (memR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { - msg, err := memR.decodeMsg(msgBytes) +func (memR *Reactor) Receive(e p2p.Envelope) { + msg, err := protoToMsg(e.Message) if err != nil { - memR.Logger.Error("Error decoding message", "src", src, "chId", chID, "err", err) - memR.Switch.StopPeerForError(src, err) + memR.Logger.Error("Error decoding message", "src", e.Src, "chId", e.ChannelID, "err", err) + memR.Switch.StopPeerForError(e.Src, err) return } - memR.Logger.Debug("Receive", "src", src, "chId", chID, "msg", msg) + memR.Logger.Debug("Receive", "src", e.Src, "chId", e.ChannelID, "msg", msg) - txInfo := mempool.TxInfo{SenderID: memR.ids.GetForPeer(src)} - if src != nil { - txInfo.SenderP2PID = src.ID() + txInfo := mempool.TxInfo{SenderID: memR.ids.GetForPeer(e.Src)} + if e.Src != nil { + txInfo.SenderP2PID = e.Src.ID() } for _, tx := range msg.Txs { err = memR.mempool.CheckTx(tx, nil, txInfo) @@ -233,18 +236,14 @@ func (memR *Reactor) broadcastTxRoutine(peer p2p.Peer) { // NOTE: Transaction batching was disabled due to // https://github.com/tendermint/tendermint/issues/5796 if !memTx.HasPeer(peerID) { - msg := protomem.Message{ - Sum: &protomem.Message_Txs{ - Txs: &protomem.Txs{Txs: [][]byte{memTx.tx}}, + success := peer.Send(p2p.Envelope{ + ChannelID: mempool.MempoolChannel, + Message: &protomem.Message{ + Sum: &protomem.Message_Txs{ + Txs: &protomem.Txs{Txs: [][]byte{memTx.tx}}, + }, }, - } - - bz, err := msg.Marshal() - if err != nil { - panic(err) - } - - success := peer.Send(mempool.MempoolChannel, bz) + }) if !success { time.Sleep(mempool.PeerCatchupSleepIntervalMS * time.Millisecond) continue @@ -268,13 +267,18 @@ func (memR *Reactor) broadcastTxRoutine(peer p2p.Peer) { //----------------------------------------------------------------------------- // Messages -func (memR *Reactor) decodeMsg(bz []byte) (TxsMessage, error) { +func decodeMsg(bz []byte) (TxsMessage, error) { msg := protomem.Message{} err := msg.Unmarshal(bz) if err != nil { return TxsMessage{}, err } + return protoToMsg(&msg) +} + +func protoToMsg(m proto.Message) (TxsMessage, error) { + msg := m.(*protomem.Message) var message TxsMessage if i, ok := msg.Sum.(*protomem.Message_Txs); ok { diff --git a/p2p/base_reactor.go b/p2p/base_reactor.go index 86b0d980a..2792b6889 100644 --- a/p2p/base_reactor.go +++ b/p2p/base_reactor.go @@ -38,13 +38,13 @@ type Reactor interface { // or other reason). RemovePeer(peer Peer, reason interface{}) - // Receive is called by the switch when msgBytes is received from the peer. + // Receive is called by the switch when a message is received from the peer. // // NOTE reactor can not keep msgBytes around after Receive completes without // copying. // // CONTRACT: msgBytes are not nil. - Receive(chID byte, peer Peer, msgBytes []byte) + Receive(Envelope) } //-------------------------------------- @@ -64,8 +64,8 @@ func NewBaseReactor(name string, impl Reactor) *BaseReactor { func (br *BaseReactor) SetSwitch(sw *Switch) { br.Switch = sw } -func (*BaseReactor) GetChannels() []*conn.ChannelDescriptor { return nil } -func (*BaseReactor) AddPeer(peer Peer) {} -func (*BaseReactor) RemovePeer(peer Peer, reason interface{}) {} -func (*BaseReactor) Receive(chID byte, peer Peer, msgBytes []byte) {} -func (*BaseReactor) InitPeer(peer Peer) Peer { return peer } +func (*BaseReactor) GetChannels() []*conn.ChannelDescriptor { return nil } +func (*BaseReactor) AddPeer(peer Peer) {} +func (*BaseReactor) RemovePeer(peer Peer, reason interface{}) {} +func (*BaseReactor) Receive(e Envelope) {} +func (*BaseReactor) InitPeer(peer Peer) Peer { return peer } diff --git a/p2p/conn/connection.go b/p2p/conn/connection.go index f52fe73f7..3fd09059c 100644 --- a/p2p/conn/connection.go +++ b/p2p/conn/connection.go @@ -724,6 +724,7 @@ type ChannelDescriptor struct { SendQueueCapacity int RecvBufferCapacity int RecvMessageCapacity int + MessageType proto.Message } func (chDesc ChannelDescriptor) FillDefaults() (filled ChannelDescriptor) { diff --git a/p2p/metrics.gen.go b/p2p/metrics.gen.go index 98fb0121f..e452f1653 100644 --- a/p2p/metrics.gen.go +++ b/p2p/metrics.gen.go @@ -44,15 +44,29 @@ func PrometheusMetrics(namespace string, labelsAndValues ...string) *Metrics { Name: "num_txs", Help: "Number of transactions submitted by each peer.", }, append(labels, "peer_id")).With(labelsAndValues...), + MessageReceiveBytesTotal: prometheus.NewCounterFrom(stdprometheus.CounterOpts{ + Namespace: namespace, + Subsystem: MetricsSubsystem, + Name: "message_receive_bytes_total", + Help: "Number of bytes of each message type received.", + }, append(labels, "message_type")).With(labelsAndValues...), + MessageSendBytesTotal: prometheus.NewCounterFrom(stdprometheus.CounterOpts{ + Namespace: namespace, + Subsystem: MetricsSubsystem, + Name: "message_send_bytes_total", + Help: "Number of bytes of each message type sent.", + }, append(labels, "message_type")).With(labelsAndValues...), } } func NopMetrics() *Metrics { return &Metrics{ - Peers: discard.NewGauge(), - PeerReceiveBytesTotal: discard.NewCounter(), - PeerSendBytesTotal: discard.NewCounter(), - PeerPendingSendBytes: discard.NewGauge(), - NumTxs: discard.NewGauge(), + Peers: discard.NewGauge(), + PeerReceiveBytesTotal: discard.NewCounter(), + PeerSendBytesTotal: discard.NewCounter(), + PeerPendingSendBytes: discard.NewGauge(), + NumTxs: discard.NewGauge(), + MessageReceiveBytesTotal: discard.NewCounter(), + MessageSendBytesTotal: discard.NewCounter(), } } diff --git a/p2p/metrics.go b/p2p/metrics.go index 7e21870c7..fe33a2f41 100644 --- a/p2p/metrics.go +++ b/p2p/metrics.go @@ -24,4 +24,8 @@ type Metrics struct { PeerPendingSendBytes metrics.Gauge `metrics_labels:"peer_id"` // Number of transactions submitted by each peer. NumTxs metrics.Gauge `metrics_labels:"peer_id"` + // Number of bytes of each message type received. + MessageReceiveBytesTotal metrics.Counter `metrics_labels:"message_type"` + // Number of bytes of each message type sent. + MessageSendBytesTotal metrics.Counter `metrics_labels:"message_type"` } diff --git a/p2p/mock/peer.go b/p2p/mock/peer.go index 10254c343..47117270b 100644 --- a/p2p/mock/peer.go +++ b/p2p/mock/peer.go @@ -42,9 +42,9 @@ func NewPeer(ip net.IP) *Peer { return mp } -func (mp *Peer) FlushStop() { mp.Stop() } //nolint:errcheck //ignore error -func (mp *Peer) TrySend(chID byte, msgBytes []byte) bool { return true } -func (mp *Peer) Send(chID byte, msgBytes []byte) bool { return true } +func (mp *Peer) FlushStop() { mp.Stop() } //nolint:errcheck //ignore error +func (mp *Peer) TrySend(e p2p.Envelope) bool { return true } +func (mp *Peer) Send(e p2p.Envelope) bool { return true } func (mp *Peer) NodeInfo() p2p.NodeInfo { return p2p.DefaultNodeInfo{ DefaultNodeID: mp.addr.ID, diff --git a/p2p/mocks/peer.go b/p2p/mocks/peer.go index a9151c7d8..0850ab588 100644 --- a/p2p/mocks/peer.go +++ b/p2p/mocks/peer.go @@ -234,13 +234,13 @@ func (_m *Peer) Reset() error { return r0 } -// Send provides a mock function with given fields: _a0, _a1 -func (_m *Peer) Send(_a0 byte, _a1 []byte) bool { - ret := _m.Called(_a0, _a1) +// Send provides a mock function with given fields: _a0 +func (_m *Peer) Send(_a0 p2p.Envelope) bool { + ret := _m.Called(_a0) var r0 bool - if rf, ok := ret.Get(0).(func(byte, []byte) bool); ok { - r0 = rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(p2p.Envelope) bool); ok { + r0 = rf(_a0) } else { r0 = ret.Get(0).(bool) } @@ -335,13 +335,13 @@ func (_m *Peer) String() string { return r0 } -// TrySend provides a mock function with given fields: _a0, _a1 -func (_m *Peer) TrySend(_a0 byte, _a1 []byte) bool { - ret := _m.Called(_a0, _a1) +// TrySend provides a mock function with given fields: _a0 +func (_m *Peer) TrySend(_a0 p2p.Envelope) bool { + ret := _m.Called(_a0) var r0 bool - if rf, ok := ret.Get(0).(func(byte, []byte) bool); ok { - r0 = rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(p2p.Envelope) bool); ok { + r0 = rf(_a0) } else { r0 = ret.Get(0).(bool) } diff --git a/p2p/peer.go b/p2p/peer.go index d8d61a7a0..11742fd66 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -5,6 +5,8 @@ import ( "net" "time" + "github.com/cosmos/gogoproto/proto" + "github.com/tendermint/tendermint/libs/cmap" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/libs/service" @@ -34,8 +36,8 @@ type Peer interface { Status() tmconn.ConnectionStatus SocketAddr() *NetAddress // actual address of the socket - Send(byte, []byte) bool - TrySend(byte, []byte) bool + Send(Envelope) bool + TrySend(Envelope) bool Set(string, interface{}) Get(string) interface{} @@ -132,6 +134,7 @@ func newPeer( mConfig tmconn.MConnConfig, nodeInfo NodeInfo, reactorsByCh map[byte]Reactor, + msgTypeByChID map[byte]proto.Message, chDescs []*tmconn.ChannelDescriptor, onPeerError func(Peer, interface{}), options ...PeerOption, @@ -149,6 +152,7 @@ func newPeer( pc.conn, p, reactorsByCh, + msgTypeByChID, chDescs, onPeerError, mConfig, @@ -249,19 +253,23 @@ func (p *peer) Status() tmconn.ConnectionStatus { // Send msg bytes to the channel identified by chID byte. Returns false if the // send queue is full after timeout, specified by MConnection. -func (p *peer) Send(chID byte, msgBytes []byte) bool { +func (p *peer) Send(e Envelope) bool { if !p.IsRunning() { // see Switch#Broadcast, where we fetch the list of peers and loop over // them - while we're looping, one peer may be removed and stopped. return false - } else if !p.hasChannel(chID) { + } else if !p.hasChannel(e.ChannelID) { return false } - res := p.mconn.Send(chID, msgBytes) + msgBytes, err := proto.Marshal(e.Message) + if err != nil { + panic(err) // Q: should this panic or error? + } + res := p.mconn.Send(e.ChannelID, msgBytes) if res { labels := []string{ "peer_id", string(p.ID()), - "chID", fmt.Sprintf("%#x", chID), + "chID", fmt.Sprintf("%#x", e.ChannelID), } p.metrics.PeerSendBytesTotal.With(labels...).Add(float64(len(msgBytes))) } @@ -270,17 +278,21 @@ func (p *peer) Send(chID byte, msgBytes []byte) bool { // TrySend msg bytes to the channel identified by chID byte. Immediately returns // false if the send queue is full. -func (p *peer) TrySend(chID byte, msgBytes []byte) bool { +func (p *peer) TrySend(e Envelope) bool { if !p.IsRunning() { return false - } else if !p.hasChannel(chID) { + } else if !p.hasChannel(e.ChannelID) { return false } - res := p.mconn.TrySend(chID, msgBytes) + msgBytes, err := proto.Marshal(e.Message) + if err != nil { + panic(err) + } + res := p.mconn.TrySend(e.ChannelID, msgBytes) if res { labels := []string{ "peer_id", string(p.ID()), - "chID", fmt.Sprintf("%#x", chID), + "chID", fmt.Sprintf("%#x", e.ChannelID), } p.metrics.PeerSendBytesTotal.With(labels...).Add(float64(len(msgBytes))) } @@ -384,6 +396,7 @@ func createMConnection( conn net.Conn, p *peer, reactorsByCh map[byte]Reactor, + msgTypeByChID map[byte]proto.Message, chDescs []*tmconn.ChannelDescriptor, onPeerError func(Peer, interface{}), config tmconn.MConnConfig, @@ -396,12 +409,24 @@ func createMConnection( // which does onPeerError. panic(fmt.Sprintf("Unknown channel %X", chID)) } + mt := msgTypeByChID[chID] + msg := proto.Clone(mt) + err := proto.Unmarshal(msgBytes, msg) + if err != nil { + // TODO(williambanfield) add a log line + return + } labels := []string{ "peer_id", string(p.ID()), "chID", fmt.Sprintf("%#x", chID), } p.metrics.PeerReceiveBytesTotal.With(labels...).Add(float64(len(msgBytes))) - reactor.Receive(chID, p, msgBytes) + p.metrics.MessageReceiveBytesTotal.With("message_type", "tmp").Add(float64(len(msgBytes))) + reactor.Receive(Envelope{ + ChannelID: chID, + Src: p, + Message: msg, + }) } onError := func(r interface{}) { diff --git a/p2p/peer_set_test.go b/p2p/peer_set_test.go index db3d9261e..40a345424 100644 --- a/p2p/peer_set_test.go +++ b/p2p/peer_set_test.go @@ -18,22 +18,22 @@ type mockPeer struct { id ID } -func (mp *mockPeer) FlushStop() { mp.Stop() } //nolint:errcheck // ignore error -func (mp *mockPeer) TrySend(chID byte, msgBytes []byte) bool { return true } -func (mp *mockPeer) Send(chID byte, msgBytes []byte) bool { return true } -func (mp *mockPeer) NodeInfo() NodeInfo { return DefaultNodeInfo{} } -func (mp *mockPeer) Status() ConnectionStatus { return ConnectionStatus{} } -func (mp *mockPeer) ID() ID { return mp.id } -func (mp *mockPeer) IsOutbound() bool { return false } -func (mp *mockPeer) IsPersistent() bool { return true } -func (mp *mockPeer) Get(s string) interface{} { return s } -func (mp *mockPeer) Set(string, interface{}) {} -func (mp *mockPeer) RemoteIP() net.IP { return mp.ip } -func (mp *mockPeer) SocketAddr() *NetAddress { return nil } -func (mp *mockPeer) RemoteAddr() net.Addr { return &net.TCPAddr{IP: mp.ip, Port: 8800} } -func (mp *mockPeer) CloseConn() error { return nil } -func (mp *mockPeer) SetRemovalFailed() {} -func (mp *mockPeer) GetRemovalFailed() bool { return false } +func (mp *mockPeer) FlushStop() { mp.Stop() } //nolint:errcheck // ignore error +func (mp *mockPeer) TrySend(e Envelope) bool { return true } +func (mp *mockPeer) Send(e Envelope) bool { return true } +func (mp *mockPeer) NodeInfo() NodeInfo { return DefaultNodeInfo{} } +func (mp *mockPeer) Status() ConnectionStatus { return ConnectionStatus{} } +func (mp *mockPeer) ID() ID { return mp.id } +func (mp *mockPeer) IsOutbound() bool { return false } +func (mp *mockPeer) IsPersistent() bool { return true } +func (mp *mockPeer) Get(s string) interface{} { return s } +func (mp *mockPeer) Set(string, interface{}) {} +func (mp *mockPeer) RemoteIP() net.IP { return mp.ip } +func (mp *mockPeer) SocketAddr() *NetAddress { return nil } +func (mp *mockPeer) RemoteAddr() net.Addr { return &net.TCPAddr{IP: mp.ip, Port: 8800} } +func (mp *mockPeer) CloseConn() error { return nil } +func (mp *mockPeer) SetRemovalFailed() {} +func (mp *mockPeer) GetRemovalFailed() bool { return false } // Returns a mock peer func newMockPeer(ip net.IP) *mockPeer { diff --git a/p2p/peer_test.go b/p2p/peer_test.go index f8808f14d..4551d2dba 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/cosmos/gogoproto/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -14,6 +15,7 @@ import ( "github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/libs/bytes" "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/proto/tendermint/p2p" "github.com/tendermint/tendermint/config" tmconn "github.com/tendermint/tendermint/p2p/conn" @@ -70,7 +72,7 @@ func TestPeerSend(t *testing.T) { }) assert.True(p.CanSend(testCh)) - assert.True(p.Send(testCh, []byte("Asylum"))) + assert.True(p.Send(Envelope{ChannelID: testCh, Message: &p2p.Message{}})) } func createOutboundPeerAndPerformHandshake( @@ -82,6 +84,9 @@ func createOutboundPeerAndPerformHandshake( {ID: testCh, Priority: 1}, } reactorsByCh := map[byte]Reactor{testCh: NewTestReactor(chDescs, true)} + msgTypeByChID := map[byte]proto.Message{ + testCh: &p2p.Message{}, + } pk := ed25519.GenPrivKey() pc, err := testOutboundPeerConn(addr, config, false, pk) if err != nil { @@ -94,7 +99,7 @@ func createOutboundPeerAndPerformHandshake( return nil, err } - p := newPeer(pc, mConfig, peerNodeInfo, reactorsByCh, chDescs, func(p Peer, r interface{}) {}) + p := newPeer(pc, mConfig, peerNodeInfo, reactorsByCh, msgTypeByChID, chDescs, func(p Peer, r interface{}) {}) p.SetLogger(log.TestingLogger().With("peer", addr)) return p, nil } diff --git a/p2p/pex/pex_reactor.go b/p2p/pex/pex_reactor.go index 006f89cd7..1f21a9366 100644 --- a/p2p/pex/pex_reactor.go +++ b/p2p/pex/pex_reactor.go @@ -184,6 +184,7 @@ func (r *Reactor) GetChannels() []*conn.ChannelDescriptor { Priority: 1, SendQueueCapacity: 10, RecvMessageCapacity: maxMsgSize, + MessageType: &tmp2p.Message{}, }, } } @@ -236,14 +237,14 @@ func (r *Reactor) logErrAddrBook(err error) { } // Receive implements Reactor by handling incoming PEX messages. -func (r *Reactor) Receive(chID byte, src Peer, msgBytes []byte) { - msg, err := decodeMsg(msgBytes) +func (r *Reactor) Receive(e p2p.Envelope) { + msg, err := msgFromProto(e.Message) if err != nil { - r.Logger.Error("Error decoding message", "src", src, "chId", chID, "err", err) - r.Switch.StopPeerForError(src, err) + r.Logger.Error("Error decoding message", "src", e.Src, "chId", e.ChannelID, "err", err) + r.Switch.StopPeerForError(e.Src, err) return } - r.Logger.Debug("Received message", "src", src, "chId", chID, "msg", msg) + r.Logger.Debug("Received message", "src", e.Src, "chId", e.ChannelID, "msg", msg) switch msg := msg.(type) { case *tmp2p.PexRequest: @@ -255,8 +256,8 @@ func (r *Reactor) Receive(chID byte, src Peer, msgBytes []byte) { // If we're a seed and this is an inbound peer, // respond once and disconnect. - if r.config.SeedMode && !src.IsOutbound() { - id := string(src.ID()) + if r.config.SeedMode && !e.Src.IsOutbound() { + id := string(e.Src.ID()) v := r.lastReceivedRequests.Get(id) if v != nil { // FlushStop/StopPeer are already @@ -266,36 +267,36 @@ func (r *Reactor) Receive(chID byte, src Peer, msgBytes []byte) { r.lastReceivedRequests.Set(id, time.Now()) // Send addrs and disconnect - r.SendAddrs(src, r.book.GetSelectionWithBias(biasToSelectNewPeers)) + r.SendAddrs(e.Src, r.book.GetSelectionWithBias(biasToSelectNewPeers)) go func() { // In a go-routine so it doesn't block .Receive. - src.FlushStop() - r.Switch.StopPeerGracefully(src) + e.Src.FlushStop() + r.Switch.StopPeerGracefully(e.Src) }() } else { // Check we're not receiving requests too frequently. - if err := r.receiveRequest(src); err != nil { - r.Switch.StopPeerForError(src, err) - r.book.MarkBad(src.SocketAddr(), defaultBanTime) + if err := r.receiveRequest(e.Src); err != nil { + r.Switch.StopPeerForError(e.Src, err) + r.book.MarkBad(e.Src.SocketAddr(), defaultBanTime) return } - r.SendAddrs(src, r.book.GetSelection()) + r.SendAddrs(e.Src, r.book.GetSelection()) } case *tmp2p.PexAddrs: // If we asked for addresses, add them to the book addrs, err := p2p.NetAddressesFromProto(msg.Addrs) if err != nil { - r.Switch.StopPeerForError(src, err) - r.book.MarkBad(src.SocketAddr(), defaultBanTime) + r.Switch.StopPeerForError(e.Src, err) + r.book.MarkBad(e.Src.SocketAddr(), defaultBanTime) return } - err = r.ReceiveAddrs(addrs, src) + err = r.ReceiveAddrs(addrs, e.Src) if err != nil { - r.Switch.StopPeerForError(src, err) + r.Switch.StopPeerForError(e.Src, err) if err == ErrUnsolicitedList { - r.book.MarkBad(src.SocketAddr(), defaultBanTime) + r.book.MarkBad(e.Src.SocketAddr(), defaultBanTime) } return } @@ -348,7 +349,10 @@ func (r *Reactor) RequestAddrs(p Peer) { } r.Logger.Debug("Request addrs", "from", p) r.requestsSent.Set(id, struct{}{}) - p.Send(PexChannel, mustEncode(&tmp2p.PexRequest{})) + p.Send(p2p.Envelope{ + ChannelID: PexChannel, + Message: mustMsgToWrappedProto(&tmp2p.PexRequest{}), + }) } // ReceiveAddrs adds the given addrs to the addrbook if theres an open @@ -406,7 +410,11 @@ func (r *Reactor) ReceiveAddrs(addrs []*p2p.NetAddress, src Peer) error { // SendAddrs sends addrs to the peer. func (r *Reactor) SendAddrs(p Peer, netAddrs []*p2p.NetAddress) { - p.Send(PexChannel, mustEncode(&tmp2p.PexAddrs{Addrs: p2p.NetAddressesToProto(netAddrs)})) + e := p2p.Envelope{ + ChannelID: PexChannel, + Message: mustMsgToWrappedProto(&tmp2p.PexAddrs{Addrs: p2p.NetAddressesToProto(netAddrs)}), + } + p.Send(e) } // SetEnsurePeersPeriod sets period to ensure peers connected. @@ -769,6 +777,15 @@ func markAddrInBookBasedOnErr(addr *p2p.NetAddress, book AddrBook, err error) { // mustEncode proto encodes a tmp2p.Message func mustEncode(pb proto.Message) []byte { + msg := mustMsgToWrappedProto(pb) + bz, err := proto.Marshal(msg) + if err != nil { + panic(fmt.Errorf("unable to marshal %T: %w", pb, err)) + } + return bz +} + +func mustMsgToWrappedProto(pb proto.Message) proto.Message { msg := tmp2p.Message{} switch pb := pb.(type) { case *tmp2p.PexRequest: @@ -778,12 +795,7 @@ func mustEncode(pb proto.Message) []byte { default: panic(fmt.Sprintf("Unknown message type %T", pb)) } - - bz, err := msg.Marshal() - if err != nil { - panic(fmt.Errorf("unable to marshal %T: %w", pb, err)) - } - return bz + return &msg } func decodeMsg(bz []byte) (proto.Message, error) { @@ -793,7 +805,11 @@ func decodeMsg(bz []byte) (proto.Message, error) { if err != nil { return nil, err } + return msgFromProto(pb) +} +func msgFromProto(m proto.Message) (proto.Message, error) { + pb := m.(*tmp2p.Message) switch msg := pb.Sum.(type) { case *tmp2p.Message_PexRequest: return msg.PexRequest, nil diff --git a/p2p/pex/pex_reactor_test.go b/p2p/pex/pex_reactor_test.go index d5e052e91..8b3233a1f 100644 --- a/p2p/pex/pex_reactor_test.go +++ b/p2p/pex/pex_reactor_test.go @@ -131,12 +131,13 @@ func TestPEXReactorReceive(t *testing.T) { r.RequestAddrs(peer) size := book.Size() - msg := mustEncode(&tmp2p.PexAddrs{Addrs: []tmp2p.NetAddress{peer.SocketAddr().ToProto()}}) - r.Receive(PexChannel, peer, msg) + msg := &tmp2p.PexAddrs{Addrs: []tmp2p.NetAddress{peer.SocketAddr().ToProto()}} + r.Receive(PexChannel, peer, mustEncode(msg)) + r.NewReceive(p2p.Envelope{ChannelID: PexChannel, Src: peer, Message: mustMsgToWrappedProto(msg)}) assert.Equal(t, size+1, book.Size()) - msg = mustEncode(&tmp2p.PexRequest{}) - r.Receive(PexChannel, peer, msg) // should not panic. + r.Receive(PexChannel, peer, mustEncode(&tmp2p.PexRequest{})) // should not panic. + r.NewReceive(p2p.Envelope{ChannelID: PexChannel, Src: peer, Message: mustMsgToWrappedProto(&tmp2p.PexRequest{})}) } func TestPEXReactorRequestMessageAbuse(t *testing.T) { @@ -159,16 +160,19 @@ func TestPEXReactorRequestMessageAbuse(t *testing.T) { // first time creates the entry r.Receive(PexChannel, peer, msg) + r.NewReceive(p2p.Envelope{ChannelID: PexChannel, Src: peer, Message: mustMsgToWrappedProto(&tmp2p.PexRequest{})}) assert.True(t, r.lastReceivedRequests.Has(id)) assert.True(t, sw.Peers().Has(peer.ID())) // next time sets the last time value r.Receive(PexChannel, peer, msg) + r.NewReceive(p2p.Envelope{ChannelID: PexChannel, Src: peer, Message: mustMsgToWrappedProto(&tmp2p.PexRequest{})}) assert.True(t, r.lastReceivedRequests.Has(id)) assert.True(t, sw.Peers().Has(peer.ID())) // third time is too many too soon - peer is removed r.Receive(PexChannel, peer, msg) + r.NewReceive(p2p.Envelope{ChannelID: PexChannel, Src: peer, Message: mustMsgToWrappedProto(&tmp2p.PexRequest{})}) assert.False(t, r.lastReceivedRequests.Has(id)) assert.False(t, sw.Peers().Has(peer.ID())) assert.True(t, book.IsBanned(peerAddr)) @@ -192,15 +196,17 @@ func TestPEXReactorAddrsMessageAbuse(t *testing.T) { assert.True(t, r.requestsSent.Has(id)) assert.True(t, sw.Peers().Has(peer.ID())) - msg := mustEncode(&tmp2p.PexAddrs{Addrs: []tmp2p.NetAddress{peer.SocketAddr().ToProto()}}) + msg := &tmp2p.PexAddrs{Addrs: []tmp2p.NetAddress{peer.SocketAddr().ToProto()}} // receive some addrs. should clear the request - r.Receive(PexChannel, peer, msg) + r.Receive(PexChannel, peer, mustEncode(msg)) + r.NewReceive(p2p.Envelope{ChannelID: PexChannel, Src: peer, Message: mustMsgToWrappedProto(msg)}) assert.False(t, r.requestsSent.Has(id)) assert.True(t, sw.Peers().Has(peer.ID())) // receiving more unsolicited addrs causes a disconnect and ban - r.Receive(PexChannel, peer, msg) + r.Receive(PexChannel, peer, mustEncode(msg)) + r.NewReceive(p2p.Envelope{ChannelID: PexChannel, Src: peer, Message: mustMsgToWrappedProto(msg)}) assert.False(t, sw.Peers().Has(peer.ID())) assert.True(t, book.IsBanned(peer.SocketAddr())) } diff --git a/p2p/switch.go b/p2p/switch.go index 884fd883e..8f0ae3fe5 100644 --- a/p2p/switch.go +++ b/p2p/switch.go @@ -6,9 +6,9 @@ import ( "sync" "time" + "github.com/cosmos/gogoproto/proto" "github.com/tendermint/tendermint/config" "github.com/tendermint/tendermint/libs/cmap" - "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/libs/rand" "github.com/tendermint/tendermint/libs/service" "github.com/tendermint/tendermint/p2p/conn" @@ -69,16 +69,17 @@ type PeerFilterFunc func(IPeerSet, Peer) error type Switch struct { service.BaseService - config *config.P2PConfig - reactors map[string]Reactor - chDescs []*conn.ChannelDescriptor - reactorsByCh map[byte]Reactor - peers *PeerSet - dialing *cmap.CMap - reconnecting *cmap.CMap - nodeInfo NodeInfo // our node info - nodeKey *NodeKey // our node privkey - addrBook AddrBook + config *config.P2PConfig + reactors map[string]Reactor + chDescs []*conn.ChannelDescriptor + reactorsByCh map[byte]Reactor + msgTypeByChID map[byte]proto.Message + peers *PeerSet + dialing *cmap.CMap + reconnecting *cmap.CMap + nodeInfo NodeInfo // our node info + nodeKey *NodeKey // our node privkey + addrBook AddrBook // peers addresses with whom we'll maintain constant connection persistentPeersAddrs []*NetAddress unconditionalPeerIDs map[ID]struct{} @@ -113,6 +114,7 @@ func NewSwitch( reactors: make(map[string]Reactor), chDescs: make([]*conn.ChannelDescriptor, 0), reactorsByCh: make(map[byte]Reactor), + msgTypeByChID: make(map[byte]proto.Message), peers: NewPeerSet(), dialing: cmap.NewCMap(), reconnecting: cmap.NewCMap(), @@ -164,6 +166,7 @@ func (sw *Switch) AddReactor(name string, reactor Reactor) Reactor { } sw.chDescs = append(sw.chDescs, chDesc) sw.reactorsByCh[chID] = reactor + sw.msgTypeByChID[chID] = chDesc.MessageType } sw.reactors[name] = reactor reactor.SetSwitch(sw) @@ -182,6 +185,7 @@ func (sw *Switch) RemoveReactor(name string, reactor Reactor) { } } delete(sw.reactorsByCh, chDesc.ID) + delete(sw.msgTypeByChID, chDesc.ID) } delete(sw.reactors, name) reactor.SetSwitch(nil) @@ -261,6 +265,7 @@ func (sw *Switch) OnStop() { // closed once msg bytes are sent to all peers (or time out). // // NOTE: Broadcast uses goroutines, so order of broadcast may not be preserved. +/* func (sw *Switch) Broadcast(chID byte, msgBytes []byte) chan bool { sw.Logger.Debug("Broadcast", "channel", chID, "msgBytes", log.NewLazySprintf("%X", msgBytes)) @@ -284,6 +289,37 @@ func (sw *Switch) Broadcast(chID byte, msgBytes []byte) chan bool { return successChan } +*/ + +// NewBroadcast runs a go routine for each attempted send, which will block trying +// to send for defaultSendTimeoutSeconds. Returns a channel which receives +// success values for each attempted send (false if times out). Channel will be +// closed once msg bytes are sent to all peers (or time out). +// +// NOTE: Broadcast uses goroutines, so order of broadcast may not be preserved. +func (sw *Switch) NewBroadcast(e Envelope) chan bool { + sw.Logger.Debug("Broadcast", "channel", e.ChannelID) + + peers := sw.peers.List() + var wg sync.WaitGroup + wg.Add(len(peers)) + successChan := make(chan bool, len(peers)) + + for _, peer := range peers { + go func(p Peer) { + defer wg.Done() + success := p.Send(e) + successChan <- success + }(peer) + } + + go func() { + wg.Wait() + close(successChan) + }() + + return successChan +} // NumPeers returns the count of outbound/inbound and outbound-dialing peers. // unconditional peers are not counted here. @@ -623,11 +659,12 @@ func (sw *Switch) IsPeerPersistent(na *NetAddress) bool { func (sw *Switch) acceptRoutine() { for { p, err := sw.transport.Accept(peerConfig{ - chDescs: sw.chDescs, - onPeerError: sw.StopPeerForError, - reactorsByCh: sw.reactorsByCh, - metrics: sw.metrics, - isPersistent: sw.IsPeerPersistent, + chDescs: sw.chDescs, + onPeerError: sw.StopPeerForError, + reactorsByCh: sw.reactorsByCh, + msgTypeByChID: sw.msgTypeByChID, + metrics: sw.metrics, + isPersistent: sw.IsPeerPersistent, }) if err != nil { switch err := err.(type) { @@ -726,11 +763,12 @@ func (sw *Switch) addOutboundPeerWithConfig( } p, err := sw.transport.Dial(*addr, peerConfig{ - chDescs: sw.chDescs, - onPeerError: sw.StopPeerForError, - isPersistent: sw.IsPeerPersistent, - reactorsByCh: sw.reactorsByCh, - metrics: sw.metrics, + chDescs: sw.chDescs, + onPeerError: sw.StopPeerForError, + isPersistent: sw.IsPeerPersistent, + reactorsByCh: sw.reactorsByCh, + msgTypeByChID: sw.msgTypeByChID, + metrics: sw.metrics, }) if err != nil { if e, ok := err.(ErrRejected); ok { diff --git a/p2p/switch_test.go b/p2p/switch_test.go index 9d5466df7..5df2c4798 100644 --- a/p2p/switch_test.go +++ b/p2p/switch_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -23,6 +24,8 @@ import ( "github.com/tendermint/tendermint/libs/log" tmsync "github.com/tendermint/tendermint/libs/sync" "github.com/tendermint/tendermint/p2p/conn" + "github.com/tendermint/tendermint/proto/tendermint/p2p" + p2pproto "github.com/tendermint/tendermint/proto/tendermint/p2p" ) var ( @@ -135,24 +138,59 @@ func TestSwitches(t *testing.T) { } // Lets send some messages - ch0Msg := []byte("channel zero") - ch1Msg := []byte("channel foo") - ch2Msg := []byte("channel bar") - - s1.Broadcast(byte(0x00), ch0Msg) - s1.Broadcast(byte(0x01), ch1Msg) - s1.Broadcast(byte(0x02), ch2Msg) + ch0Msg := &p2pproto.Message{ + Sum: &p2pproto.Message_PexAddrs{ + PexAddrs: &p2pproto.PexAddrs{ + Addrs: []p2p.NetAddress{ + { + ID: "0", + }, + }, + }, + }, + } + ch1Msg := &p2pproto.Message{ + Sum: &p2pproto.Message_PexAddrs{ + PexAddrs: &p2pproto.PexAddrs{ + Addrs: []p2p.NetAddress{ + { + ID: "1", + }, + }, + }, + }, + } + ch2Msg := &p2pproto.Message{ + Sum: &p2pproto.Message_PexAddrs{ + PexAddrs: &p2pproto.PexAddrs{ + Addrs: []p2p.NetAddress{ + { + ID: "2", + }, + }, + }, + }, + } + s1.NewBroadcast(Envelope{ChannelID: byte(0x00), Message: ch0Msg}) + s1.NewBroadcast(Envelope{ChannelID: byte(0x01), Message: ch1Msg}) + s1.NewBroadcast(Envelope{ChannelID: byte(0x02), Message: ch2Msg}) + msgBytes, err := proto.Marshal(ch0Msg) + require.NoError(t, err) assertMsgReceivedWithTimeout(t, - ch0Msg, + msgBytes, byte(0x00), s2.Reactor("foo").(*TestReactor), 10*time.Millisecond, 5*time.Second) + msgBytes, err = proto.Marshal(ch1Msg) + require.NoError(t, err) assertMsgReceivedWithTimeout(t, - ch1Msg, + msgBytes, byte(0x01), s2.Reactor("foo").(*TestReactor), 10*time.Millisecond, 5*time.Second) + msgBytes, err = proto.Marshal(ch2Msg) + require.NoError(t, err) assertMsgReceivedWithTimeout(t, - ch2Msg, + msgBytes, byte(0x02), s2.Reactor("bar").(*TestReactor), 10*time.Millisecond, 5*time.Second) } @@ -429,7 +467,10 @@ func TestSwitchStopPeerForError(t *testing.T) { // send messages to the peer from sw1 p := sw1.Peers().List()[0] - p.Send(0x1, []byte("here's a message to send")) + p.Send(Envelope{ + ChannelID: 0x1, + Message: &p2p.Message{}, + }) // stop sw2. this should cause the p to fail, // which results in calling StopPeerForError internally @@ -824,7 +865,7 @@ func BenchmarkSwitchBroadcast(b *testing.B) { // Send random message from foo channel to another for i := 0; i < b.N; i++ { chID := byte(i % 4) - successChan := s1.Broadcast(chID, []byte("test data")) + successChan := s1.NewBroadcast(Envelope{ChannelID: chID}) for s := range successChan { if s { numSuccess++ diff --git a/p2p/test_util.go b/p2p/test_util.go index 4e56f0193..14af8c520 100644 --- a/p2p/test_util.go +++ b/p2p/test_util.go @@ -149,6 +149,7 @@ func (sw *Switch) addPeerWithConnection(conn net.Conn) error { MConnConfig(sw.config), ni, sw.reactorsByCh, + sw.msgTypeByChID, sw.chDescs, sw.StopPeerForError, ) diff --git a/p2p/transport.go b/p2p/transport.go index e6e19a901..c16376fe8 100644 --- a/p2p/transport.go +++ b/p2p/transport.go @@ -8,6 +8,7 @@ import ( "golang.org/x/net/netutil" + "github.com/cosmos/gogoproto/proto" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/libs/protoio" "github.com/tendermint/tendermint/p2p/conn" @@ -47,9 +48,10 @@ type peerConfig struct { // isPersistent allows you to set a function, which, given socket address // (for outbound peers) OR self-reported address (for inbound peers), tells // if the peer is persistent or not. - isPersistent func(*NetAddress) bool - reactorsByCh map[byte]Reactor - metrics *Metrics + isPersistent func(*NetAddress) bool + reactorsByCh map[byte]Reactor + msgTypeByChID map[byte]proto.Message + metrics *Metrics } // Transport emits and connects to Peers. The implementation of Peer is left to @@ -519,6 +521,7 @@ func (mt *MultiplexTransport) wrapPeer( mt.mConfig, ni, cfg.reactorsByCh, + cfg.msgTypeByChID, cfg.chDescs, cfg.onPeerError, PeerMetrics(cfg.metrics), diff --git a/p2p/types.go b/p2p/types.go index b11765bb5..b1e5266c6 100644 --- a/p2p/types.go +++ b/p2p/types.go @@ -1,8 +1,16 @@ package p2p import ( + "github.com/cosmos/gogoproto/proto" "github.com/tendermint/tendermint/p2p/conn" ) type ChannelDescriptor = conn.ChannelDescriptor type ConnectionStatus = conn.ConnectionStatus + +// Envelope contains a message with sender routing info. +type Envelope struct { + Src Peer // sender (empty if outbound) + Message proto.Message // message payload + ChannelID byte +} diff --git a/statesync/messages.go b/statesync/messages.go index 901036a7a..5ac0b8f4d 100644 --- a/statesync/messages.go +++ b/statesync/messages.go @@ -18,6 +18,15 @@ const ( // mustEncodeMsg encodes a Protobuf message, panicing on error. func mustEncodeMsg(pb proto.Message) []byte { + msg := mustWrapToProto(pb) + bz, err := proto.Marshal(msg) + if err != nil { + panic(fmt.Errorf("unable to marshal %T: %w", pb, err)) + } + return bz +} + +func mustWrapToProto(pb proto.Message) proto.Message { msg := ssproto.Message{} switch pb := pb.(type) { case *ssproto.ChunkRequest: @@ -31,11 +40,7 @@ func mustEncodeMsg(pb proto.Message) []byte { default: panic(fmt.Errorf("unknown message type %T", pb)) } - bz, err := msg.Marshal() - if err != nil { - panic(fmt.Errorf("unable to marshal %T: %w", pb, err)) - } - return bz + return &msg } // decodeMsg decodes a Protobuf message. @@ -45,6 +50,10 @@ func decodeMsg(bz []byte) (proto.Message, error) { if err != nil { return nil, err } + return msgFromProto(pb) +} + +func msgFromProto(pb *ssproto.Message) (proto.Message, error) { switch msg := pb.Sum.(type) { case *ssproto.Message_ChunkRequest: return msg.ChunkRequest, nil diff --git a/statesync/reactor.go b/statesync/reactor.go index 8434b6adf..e43b0c2a4 100644 --- a/statesync/reactor.go +++ b/statesync/reactor.go @@ -66,12 +66,14 @@ func (r *Reactor) GetChannels() []*p2p.ChannelDescriptor { Priority: 5, SendQueueCapacity: 10, RecvMessageCapacity: snapshotMsgSize, + MessageType: &ssproto.Message{}, }, { ID: ChunkChannel, Priority: 3, SendQueueCapacity: 10, RecvMessageCapacity: chunkMsgSize, + MessageType: &ssproto.Message{}, }, } } @@ -100,25 +102,25 @@ func (r *Reactor) RemovePeer(peer p2p.Peer, reason interface{}) { } // Receive implements p2p.Reactor. -func (r *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { +func (r *Reactor) Receive(e p2p.Envelope) { if !r.IsRunning() { return } - msg, err := decodeMsg(msgBytes) + msg, err := msgFromProto(e.Message.(*ssproto.Message)) if err != nil { - r.Logger.Error("Error decoding message", "src", src, "chId", chID, "err", err) - r.Switch.StopPeerForError(src, err) + r.Logger.Error("Error decoding message", "src", e.Src, "chId", e.ChannelID, "err", err) + r.Switch.StopPeerForError(e.Src, err) return } err = validateMsg(msg) if err != nil { - r.Logger.Error("Invalid message", "peer", src, "msg", msg, "err", err) - r.Switch.StopPeerForError(src, err) + r.Logger.Error("Invalid message", "peer", e.Src, "msg", msg, "err", err) + r.Switch.StopPeerForError(e.Src, err) return } - switch chID { + switch e.ChannelID { case SnapshotChannel: switch msg := msg.(type) { case *ssproto.SnapshotsRequest: @@ -129,14 +131,17 @@ func (r *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { } for _, snapshot := range snapshots { r.Logger.Debug("Advertising snapshot", "height", snapshot.Height, - "format", snapshot.Format, "peer", src.ID()) - src.Send(chID, mustEncodeMsg(&ssproto.SnapshotsResponse{ - Height: snapshot.Height, - Format: snapshot.Format, - Chunks: snapshot.Chunks, - Hash: snapshot.Hash, - Metadata: snapshot.Metadata, - })) + "format", snapshot.Format, "peer", e.Src.ID()) + e.Src.Send(p2p.Envelope{ + ChannelID: e.ChannelID, + Message: mustWrapToProto(&ssproto.SnapshotsResponse{ + Height: snapshot.Height, + Format: snapshot.Format, + Chunks: snapshot.Chunks, + Hash: snapshot.Hash, + Metadata: snapshot.Metadata, + }), + }) } case *ssproto.SnapshotsResponse: @@ -146,8 +151,8 @@ func (r *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { r.Logger.Debug("Received unexpected snapshot, no state sync in progress") return } - r.Logger.Debug("Received snapshot", "height", msg.Height, "format", msg.Format, "peer", src.ID()) - _, err := r.syncer.AddSnapshot(src, &snapshot{ + r.Logger.Debug("Received snapshot", "height", msg.Height, "format", msg.Format, "peer", e.Src.ID()) + _, err := r.syncer.AddSnapshot(e.Src, &snapshot{ Height: msg.Height, Format: msg.Format, Chunks: msg.Chunks, @@ -157,7 +162,7 @@ func (r *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { // TODO: We may want to consider punishing the peer for certain errors if err != nil { r.Logger.Error("Failed to add snapshot", "height", msg.Height, "format", msg.Format, - "peer", src.ID(), "err", err) + "peer", e.Src.ID(), "err", err) return } @@ -169,7 +174,7 @@ func (r *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { switch msg := msg.(type) { case *ssproto.ChunkRequest: r.Logger.Debug("Received chunk request", "height", msg.Height, "format", msg.Format, - "chunk", msg.Index, "peer", src.ID()) + "chunk", msg.Index, "peer", e.Src.ID()) resp, err := r.conn.LoadSnapshotChunkSync(abci.RequestLoadSnapshotChunk{ Height: msg.Height, Format: msg.Format, @@ -181,30 +186,33 @@ func (r *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { return } r.Logger.Debug("Sending chunk", "height", msg.Height, "format", msg.Format, - "chunk", msg.Index, "peer", src.ID()) - src.Send(ChunkChannel, mustEncodeMsg(&ssproto.ChunkResponse{ - Height: msg.Height, - Format: msg.Format, - Index: msg.Index, - Chunk: resp.Chunk, - Missing: resp.Chunk == nil, - })) + "chunk", msg.Index, "peer", e.Src.ID()) + e.Src.Send(p2p.Envelope{ + ChannelID: ChunkChannel, + Message: mustWrapToProto(&ssproto.ChunkResponse{ + Height: msg.Height, + Format: msg.Format, + Index: msg.Index, + Chunk: resp.Chunk, + Missing: resp.Chunk == nil, + }), + }) case *ssproto.ChunkResponse: r.mtx.RLock() defer r.mtx.RUnlock() if r.syncer == nil { - r.Logger.Debug("Received unexpected chunk, no state sync in progress", "peer", src.ID()) + r.Logger.Debug("Received unexpected chunk, no state sync in progress", "peer", e.Src.ID()) return } r.Logger.Debug("Received chunk, adding to sync", "height", msg.Height, "format", msg.Format, - "chunk", msg.Index, "peer", src.ID()) + "chunk", msg.Index, "peer", e.Src.ID()) _, err := r.syncer.AddChunk(&chunk{ Height: msg.Height, Format: msg.Format, Index: msg.Index, Chunk: msg.Chunk, - Sender: src.ID(), + Sender: e.Src.ID(), }) if err != nil { r.Logger.Error("Failed to add chunk", "height", msg.Height, "format", msg.Format, @@ -217,7 +225,7 @@ func (r *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { } default: - r.Logger.Error("Received message on invalid channel %x", chID) + r.Logger.Error("Received message on invalid channel %x", e.ChannelID) } } @@ -269,7 +277,11 @@ func (r *Reactor) Sync(stateProvider StateProvider, discoveryTime time.Duration) hook := func() { r.Logger.Debug("Requesting snapshots from known peers") // Request snapshots from all currently connected peers - r.Switch.Broadcast(SnapshotChannel, mustEncodeMsg(&ssproto.SnapshotsRequest{})) + + r.Switch.NewBroadcast(p2p.Envelope{ + ChannelID: SnapshotChannel, + Message: mustWrapToProto(&ssproto.SnapshotsRequest{}), + }) } hook() diff --git a/statesync/reactor_test.go b/statesync/reactor_test.go index 053b47ef5..01a1c97ee 100644 --- a/statesync/reactor_test.go +++ b/statesync/reactor_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/cosmos/gogoproto/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -53,10 +54,18 @@ func TestReactor_Receive_ChunkRequest(t *testing.T) { peer.On("ID").Return(p2p.ID("id")) var response *ssproto.ChunkResponse if tc.expectResponse != nil { - peer.On("Send", ChunkChannel, mock.Anything).Run(func(args mock.Arguments) { - msg, err := decodeMsg(args[1].([]byte)) + peer.On("Send", mock.MatchedBy(func(i interface{}) bool { + e, ok := i.(p2p.Envelope) + return ok && e.ChannelID == ChunkChannel + })).Run(func(args mock.Arguments) { + e := args[0].(p2p.Envelope) + + // Marshal to simulate a wire roundtrip. + bz, err := proto.Marshal(e.Message) require.NoError(t, err) - response = msg.(*ssproto.ChunkResponse) + err = proto.Unmarshal(bz, e.Message) + require.NoError(t, err) + response = e.Message.(*ssproto.Message).GetChunkResponse() }).Return(true) } @@ -71,7 +80,11 @@ func TestReactor_Receive_ChunkRequest(t *testing.T) { } }) - r.Receive(ChunkChannel, peer, mustEncodeMsg(tc.request)) + r.NewReceive(p2p.Envelope{ + ChannelID: ChunkChannel, + Src: peer, + Message: mustWrapToProto(tc.request), + }) time.Sleep(100 * time.Millisecond) assert.Equal(t, tc.expectResponse, response) @@ -131,10 +144,18 @@ func TestReactor_Receive_SnapshotsRequest(t *testing.T) { peer := &p2pmocks.Peer{} if len(tc.expectResponses) > 0 { peer.On("ID").Return(p2p.ID("id")) - peer.On("Send", SnapshotChannel, mock.Anything).Run(func(args mock.Arguments) { - msg, err := decodeMsg(args[1].([]byte)) + peer.On("Send", mock.MatchedBy(func(i interface{}) bool { + e, ok := i.(p2p.Envelope) + return ok && e.ChannelID == SnapshotChannel + })).Run(func(args mock.Arguments) { + e := args[0].(p2p.Envelope) + + // Marshal to simulate a wire roundtrip. + bz, err := proto.Marshal(e.Message) require.NoError(t, err) - responses = append(responses, msg.(*ssproto.SnapshotsResponse)) + err = proto.Unmarshal(bz, e.Message) + require.NoError(t, err) + responses = append(responses, e.Message.(*ssproto.Message).GetSnapshotsResponse()) }).Return(true) } @@ -149,6 +170,11 @@ func TestReactor_Receive_SnapshotsRequest(t *testing.T) { } }) + r.NewReceive(p2p.Envelope{ + ChannelID: SnapshotChannel, + Src: peer, + Message: mustWrapToProto(&ssproto.SnapshotsRequest{}), + }) r.Receive(SnapshotChannel, peer, mustEncodeMsg(&ssproto.SnapshotsRequest{})) time.Sleep(100 * time.Millisecond) assert.Equal(t, tc.expectResponses, responses) diff --git a/statesync/syncer.go b/statesync/syncer.go index 7cb9f2946..d1d2aef39 100644 --- a/statesync/syncer.go +++ b/statesync/syncer.go @@ -126,7 +126,11 @@ func (s *syncer) AddSnapshot(peer p2p.Peer, snapshot *snapshot) (bool, error) { // to discover snapshots, later we may want to do retries and stuff. func (s *syncer) AddPeer(peer p2p.Peer) { s.logger.Debug("Requesting snapshots from peer", "peer", peer.ID()) - peer.Send(SnapshotChannel, mustEncodeMsg(&ssproto.SnapshotsRequest{})) + e := p2p.Envelope{ + ChannelID: SnapshotChannel, + Message: mustWrapToProto(&ssproto.SnapshotsRequest{}), + } + peer.Send(e) } // RemovePeer removes a peer from the pool. @@ -467,11 +471,14 @@ func (s *syncer) requestChunk(snapshot *snapshot, chunk uint32) { } s.logger.Debug("Requesting snapshot chunk", "height", snapshot.Height, "format", snapshot.Format, "chunk", chunk, "peer", peer.ID()) - peer.Send(ChunkChannel, mustEncodeMsg(&ssproto.ChunkRequest{ - Height: snapshot.Height, - Format: snapshot.Format, - Index: chunk, - })) + peer.Send(p2p.Envelope{ + ChannelID: ChunkChannel, + Message: mustWrapToProto(&ssproto.ChunkRequest{ + Height: snapshot.Height, + Format: snapshot.Format, + Index: chunk, + }), + }) } // verifyApp verifies the sync, checking the app hash, last block height and app version diff --git a/statesync/syncer_test.go b/statesync/syncer_test.go index 4dabe7288..2e2902d92 100644 --- a/statesync/syncer_test.go +++ b/statesync/syncer_test.go @@ -98,13 +98,21 @@ func TestSyncer_SyncAny(t *testing.T) { // Adding a couple of peers should trigger snapshot discovery messages peerA := &p2pmocks.Peer{} peerA.On("ID").Return(p2p.ID("a")) - peerA.On("Send", SnapshotChannel, mustEncodeMsg(&ssproto.SnapshotsRequest{})).Return(true) + peerA.On("Send", mock.MatchedBy(func(i interface{}) bool { + e, ok := i.(p2p.Envelope) + req := e.Message.(*ssproto.Message).GetSnapshotsRequest() + return ok && e.ChannelID == SnapshotChannel && req != nil + })).Return(true) syncer.AddPeer(peerA) peerA.AssertExpectations(t) peerB := &p2pmocks.Peer{} peerB.On("ID").Return(p2p.ID("b")) - peerB.On("Send", SnapshotChannel, mustEncodeMsg(&ssproto.SnapshotsRequest{})).Return(true) + peerB.On("Send", mock.MatchedBy(func(i interface{}) bool { + e, ok := i.(p2p.Envelope) + req := e.Message.(*ssproto.Message).GetSnapshotsRequest() + return ok && e.ChannelID == SnapshotChannel && req != nil + })).Return(true) syncer.AddPeer(peerB) peerB.AssertExpectations(t) @@ -147,9 +155,9 @@ func TestSyncer_SyncAny(t *testing.T) { chunkRequests := make(map[uint32]int) chunkRequestsMtx := tmsync.Mutex{} onChunkRequest := func(args mock.Arguments) { - pb, err := decodeMsg(args[1].([]byte)) - require.NoError(t, err) - msg := pb.(*ssproto.ChunkRequest) + e, ok := args[0].(p2p.Envelope) + require.True(t, ok) + msg := e.Message.(*ssproto.Message).GetChunkRequest() require.EqualValues(t, 1, msg.Height) require.EqualValues(t, 1, msg.Format) require.LessOrEqual(t, msg.Index, uint32(len(chunks))) @@ -162,8 +170,14 @@ func TestSyncer_SyncAny(t *testing.T) { chunkRequests[msg.Index]++ chunkRequestsMtx.Unlock() } - peerA.On("Send", ChunkChannel, mock.Anything).Maybe().Run(onChunkRequest).Return(true) - peerB.On("Send", ChunkChannel, mock.Anything).Maybe().Run(onChunkRequest).Return(true) + peerA.On("Send", mock.MatchedBy(func(i interface{}) bool { + e, ok := i.(p2p.Envelope) + return ok && e.ChannelID == ChunkChannel + })).Maybe().Run(onChunkRequest).Return(true) + peerB.On("Send", mock.MatchedBy(func(i interface{}) bool { + e, ok := i.(p2p.Envelope) + return ok && e.ChannelID == ChunkChannel + })).Maybe().Run(onChunkRequest).Return(true) // The first time we're applying chunk 2 we tell it to retry the snapshot and discard chunk 1, // which should cause it to keep the existing chunk 0 and 2, and restart restoration from