diff --git a/mempool/v0/reactor_test.go b/mempool/v0/reactor_test.go index e93880b4e..0719b4166 100644 --- a/mempool/v0/reactor_test.go +++ b/mempool/v0/reactor_test.go @@ -284,7 +284,7 @@ func TestDontExhaustMaxActiveIDs(t *testing.T) { ChannelID: mempool.MempoolChannel, Src: peer, Message: &protomem.Txs{ - Txs: [][]byte{[]byte{0x01, 0x02, 0x03}}, + Txs: [][]byte{{0x01, 0x02, 0x03}}, }}) reactor.AddPeer(peer) } diff --git a/statesync/messages.go b/statesync/messages.go index 5ac0b8f4d..1de79f2e5 100644 --- a/statesync/messages.go +++ b/statesync/messages.go @@ -16,58 +16,6 @@ const ( chunkMsgSize = int(16e6) ) -// mustEncodeMsg encodes a Protobuf message, panicing on error. -func mustEncodeMsg(pb proto.Message) []byte { - msg := mustWrapToProto(pb) - bz, err := proto.Marshal(msg) - if err != nil { - panic(fmt.Errorf("unable to marshal %T: %w", pb, err)) - } - return bz -} - -func mustWrapToProto(pb proto.Message) proto.Message { - msg := ssproto.Message{} - switch pb := pb.(type) { - case *ssproto.ChunkRequest: - msg.Sum = &ssproto.Message_ChunkRequest{ChunkRequest: pb} - case *ssproto.ChunkResponse: - msg.Sum = &ssproto.Message_ChunkResponse{ChunkResponse: pb} - case *ssproto.SnapshotsRequest: - msg.Sum = &ssproto.Message_SnapshotsRequest{SnapshotsRequest: pb} - case *ssproto.SnapshotsResponse: - msg.Sum = &ssproto.Message_SnapshotsResponse{SnapshotsResponse: pb} - default: - panic(fmt.Errorf("unknown message type %T", pb)) - } - return &msg -} - -// decodeMsg decodes a Protobuf message. -func decodeMsg(bz []byte) (proto.Message, error) { - pb := &ssproto.Message{} - err := proto.Unmarshal(bz, pb) - if err != nil { - return nil, err - } - return msgFromProto(pb) -} - -func msgFromProto(pb *ssproto.Message) (proto.Message, error) { - switch msg := pb.Sum.(type) { - case *ssproto.Message_ChunkRequest: - return msg.ChunkRequest, nil - case *ssproto.Message_ChunkResponse: - return msg.ChunkResponse, nil - case *ssproto.Message_SnapshotsRequest: - return msg.SnapshotsRequest, nil - case *ssproto.Message_SnapshotsResponse: - return msg.SnapshotsResponse, nil - default: - return nil, fmt.Errorf("unknown message type %T", msg) - } -} - // validateMsg validates a message. func validateMsg(pb proto.Message) error { if pb == nil { diff --git a/statesync/messages_test.go b/statesync/messages_test.go index 7bfdcb6ac..18dcaf748 100644 --- a/statesync/messages_test.go +++ b/statesync/messages_test.go @@ -7,6 +7,7 @@ import ( "github.com/cosmos/gogoproto/proto" "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/p2p" ssproto "github.com/tendermint/tendermint/proto/tendermint/statesync" tmproto "github.com/tendermint/tendermint/proto/tendermint/types" ) @@ -99,8 +100,10 @@ func TestStateSyncVectors(t *testing.T) { for _, tc := range testCases { tc := tc - - bz := mustEncodeMsg(tc.msg) + w, err := tc.msg.(p2p.Wrapper).Wrap() + require.NoError(t, err) + bz, err := proto.Marshal(w) + require.NoError(t, err) require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName) } diff --git a/statesync/reactor.go b/statesync/reactor.go index e43b0c2a4..ec40ee6bd 100644 --- a/statesync/reactor.go +++ b/statesync/reactor.go @@ -107,22 +107,16 @@ func (r *Reactor) Receive(e p2p.Envelope) { return } - msg, err := msgFromProto(e.Message.(*ssproto.Message)) + err := validateMsg(e.Message) if err != nil { - r.Logger.Error("Error decoding message", "src", e.Src, "chId", e.ChannelID, "err", err) - r.Switch.StopPeerForError(e.Src, err) - return - } - err = validateMsg(msg) - if err != nil { - r.Logger.Error("Invalid message", "peer", e.Src, "msg", msg, "err", err) + r.Logger.Error("Invalid message", "peer", e.Src, "msg", e.Message, "err", err) r.Switch.StopPeerForError(e.Src, err) return } switch e.ChannelID { case SnapshotChannel: - switch msg := msg.(type) { + switch msg := e.Message.(type) { case *ssproto.SnapshotsRequest: snapshots, err := r.recentSnapshots(recentSnapshots) if err != nil { @@ -134,13 +128,13 @@ func (r *Reactor) Receive(e p2p.Envelope) { "format", snapshot.Format, "peer", e.Src.ID()) e.Src.Send(p2p.Envelope{ ChannelID: e.ChannelID, - Message: mustWrapToProto(&ssproto.SnapshotsResponse{ + Message: &ssproto.SnapshotsResponse{ Height: snapshot.Height, Format: snapshot.Format, Chunks: snapshot.Chunks, Hash: snapshot.Hash, Metadata: snapshot.Metadata, - }), + }, }) } @@ -171,7 +165,7 @@ func (r *Reactor) Receive(e p2p.Envelope) { } case ChunkChannel: - switch msg := msg.(type) { + switch msg := e.Message.(type) { case *ssproto.ChunkRequest: r.Logger.Debug("Received chunk request", "height", msg.Height, "format", msg.Format, "chunk", msg.Index, "peer", e.Src.ID()) @@ -189,13 +183,13 @@ func (r *Reactor) Receive(e p2p.Envelope) { "chunk", msg.Index, "peer", e.Src.ID()) e.Src.Send(p2p.Envelope{ ChannelID: ChunkChannel, - Message: mustWrapToProto(&ssproto.ChunkResponse{ + Message: &ssproto.ChunkResponse{ Height: msg.Height, Format: msg.Format, Index: msg.Index, Chunk: resp.Chunk, Missing: resp.Chunk == nil, - }), + }, }) case *ssproto.ChunkResponse: @@ -280,7 +274,7 @@ func (r *Reactor) Sync(stateProvider StateProvider, discoveryTime time.Duration) r.Switch.NewBroadcast(p2p.Envelope{ ChannelID: SnapshotChannel, - Message: mustWrapToProto(&ssproto.SnapshotsRequest{}), + Message: &ssproto.SnapshotsRequest{}, }) } diff --git a/statesync/reactor_test.go b/statesync/reactor_test.go index 01a1c97ee..8d06c7c2d 100644 --- a/statesync/reactor_test.go +++ b/statesync/reactor_test.go @@ -65,7 +65,7 @@ func TestReactor_Receive_ChunkRequest(t *testing.T) { require.NoError(t, err) err = proto.Unmarshal(bz, e.Message) require.NoError(t, err) - response = e.Message.(*ssproto.Message).GetChunkResponse() + response = e.Message.(*ssproto.ChunkResponse) }).Return(true) } @@ -80,10 +80,10 @@ func TestReactor_Receive_ChunkRequest(t *testing.T) { } }) - r.NewReceive(p2p.Envelope{ + r.Receive(p2p.Envelope{ ChannelID: ChunkChannel, Src: peer, - Message: mustWrapToProto(tc.request), + Message: tc.request, }) time.Sleep(100 * time.Millisecond) assert.Equal(t, tc.expectResponse, response) @@ -155,7 +155,7 @@ func TestReactor_Receive_SnapshotsRequest(t *testing.T) { require.NoError(t, err) err = proto.Unmarshal(bz, e.Message) require.NoError(t, err) - responses = append(responses, e.Message.(*ssproto.Message).GetSnapshotsResponse()) + responses = append(responses, e.Message.(*ssproto.SnapshotsResponse)) }).Return(true) } @@ -170,12 +170,11 @@ func TestReactor_Receive_SnapshotsRequest(t *testing.T) { } }) - r.NewReceive(p2p.Envelope{ + r.Receive(p2p.Envelope{ ChannelID: SnapshotChannel, Src: peer, - Message: mustWrapToProto(&ssproto.SnapshotsRequest{}), + Message: &ssproto.SnapshotsRequest{}, }) - r.Receive(SnapshotChannel, peer, mustEncodeMsg(&ssproto.SnapshotsRequest{})) time.Sleep(100 * time.Millisecond) assert.Equal(t, tc.expectResponses, responses) diff --git a/statesync/syncer.go b/statesync/syncer.go index d1d2aef39..6be091886 100644 --- a/statesync/syncer.go +++ b/statesync/syncer.go @@ -128,7 +128,7 @@ func (s *syncer) AddPeer(peer p2p.Peer) { s.logger.Debug("Requesting snapshots from peer", "peer", peer.ID()) e := p2p.Envelope{ ChannelID: SnapshotChannel, - Message: mustWrapToProto(&ssproto.SnapshotsRequest{}), + Message: &ssproto.SnapshotsRequest{}, } peer.Send(e) } @@ -473,11 +473,11 @@ func (s *syncer) requestChunk(snapshot *snapshot, chunk uint32) { "format", snapshot.Format, "chunk", chunk, "peer", peer.ID()) peer.Send(p2p.Envelope{ ChannelID: ChunkChannel, - Message: mustWrapToProto(&ssproto.ChunkRequest{ + Message: &ssproto.ChunkRequest{ Height: snapshot.Height, Format: snapshot.Format, Index: chunk, - }), + }, }) } diff --git a/statesync/syncer_test.go b/statesync/syncer_test.go index 2e2902d92..100349eb3 100644 --- a/statesync/syncer_test.go +++ b/statesync/syncer_test.go @@ -100,7 +100,10 @@ func TestSyncer_SyncAny(t *testing.T) { peerA.On("ID").Return(p2p.ID("a")) peerA.On("Send", mock.MatchedBy(func(i interface{}) bool { e, ok := i.(p2p.Envelope) - req := e.Message.(*ssproto.Message).GetSnapshotsRequest() + if !ok { + return false + } + req, ok := e.Message.(*ssproto.SnapshotsRequest) return ok && e.ChannelID == SnapshotChannel && req != nil })).Return(true) syncer.AddPeer(peerA) @@ -110,7 +113,10 @@ func TestSyncer_SyncAny(t *testing.T) { peerB.On("ID").Return(p2p.ID("b")) peerB.On("Send", mock.MatchedBy(func(i interface{}) bool { e, ok := i.(p2p.Envelope) - req := e.Message.(*ssproto.Message).GetSnapshotsRequest() + if !ok { + return false + } + req, ok := e.Message.(*ssproto.SnapshotsRequest) return ok && e.ChannelID == SnapshotChannel && req != nil })).Return(true) syncer.AddPeer(peerB) @@ -157,7 +163,7 @@ func TestSyncer_SyncAny(t *testing.T) { onChunkRequest := func(args mock.Arguments) { e, ok := args[0].(p2p.Envelope) require.True(t, ok) - msg := e.Message.(*ssproto.Message).GetChunkRequest() + msg := e.Message.(*ssproto.ChunkRequest) require.EqualValues(t, 1, msg.Height) require.EqualValues(t, 1, msg.Format) require.LessOrEqual(t, msg.Index, uint32(len(chunks)))