statesync reactor uses wrapper

This commit is contained in:
William Banfield
2022-10-21 12:59:12 -04:00
parent 36decbb4c8
commit 5b86c5562a
7 changed files with 33 additions and 83 deletions

View File

@@ -284,7 +284,7 @@ func TestDontExhaustMaxActiveIDs(t *testing.T) {
ChannelID: mempool.MempoolChannel,
Src: peer,
Message: &protomem.Txs{
Txs: [][]byte{[]byte{0x01, 0x02, 0x03}},
Txs: [][]byte{{0x01, 0x02, 0x03}},
}})
reactor.AddPeer(peer)
}

View File

@@ -16,58 +16,6 @@ const (
chunkMsgSize = int(16e6)
)
// mustEncodeMsg encodes a Protobuf message, panicing on error.
func mustEncodeMsg(pb proto.Message) []byte {
msg := mustWrapToProto(pb)
bz, err := proto.Marshal(msg)
if err != nil {
panic(fmt.Errorf("unable to marshal %T: %w", pb, err))
}
return bz
}
func mustWrapToProto(pb proto.Message) proto.Message {
msg := ssproto.Message{}
switch pb := pb.(type) {
case *ssproto.ChunkRequest:
msg.Sum = &ssproto.Message_ChunkRequest{ChunkRequest: pb}
case *ssproto.ChunkResponse:
msg.Sum = &ssproto.Message_ChunkResponse{ChunkResponse: pb}
case *ssproto.SnapshotsRequest:
msg.Sum = &ssproto.Message_SnapshotsRequest{SnapshotsRequest: pb}
case *ssproto.SnapshotsResponse:
msg.Sum = &ssproto.Message_SnapshotsResponse{SnapshotsResponse: pb}
default:
panic(fmt.Errorf("unknown message type %T", pb))
}
return &msg
}
// decodeMsg decodes a Protobuf message.
func decodeMsg(bz []byte) (proto.Message, error) {
pb := &ssproto.Message{}
err := proto.Unmarshal(bz, pb)
if err != nil {
return nil, err
}
return msgFromProto(pb)
}
func msgFromProto(pb *ssproto.Message) (proto.Message, error) {
switch msg := pb.Sum.(type) {
case *ssproto.Message_ChunkRequest:
return msg.ChunkRequest, nil
case *ssproto.Message_ChunkResponse:
return msg.ChunkResponse, nil
case *ssproto.Message_SnapshotsRequest:
return msg.SnapshotsRequest, nil
case *ssproto.Message_SnapshotsResponse:
return msg.SnapshotsResponse, nil
default:
return nil, fmt.Errorf("unknown message type %T", msg)
}
}
// validateMsg validates a message.
func validateMsg(pb proto.Message) error {
if pb == nil {

View File

@@ -7,6 +7,7 @@ import (
"github.com/cosmos/gogoproto/proto"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/p2p"
ssproto "github.com/tendermint/tendermint/proto/tendermint/statesync"
tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
)
@@ -99,8 +100,10 @@ func TestStateSyncVectors(t *testing.T) {
for _, tc := range testCases {
tc := tc
bz := mustEncodeMsg(tc.msg)
w, err := tc.msg.(p2p.Wrapper).Wrap()
require.NoError(t, err)
bz, err := proto.Marshal(w)
require.NoError(t, err)
require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName)
}

View File

@@ -107,22 +107,16 @@ func (r *Reactor) Receive(e p2p.Envelope) {
return
}
msg, err := msgFromProto(e.Message.(*ssproto.Message))
err := validateMsg(e.Message)
if err != nil {
r.Logger.Error("Error decoding message", "src", e.Src, "chId", e.ChannelID, "err", err)
r.Switch.StopPeerForError(e.Src, err)
return
}
err = validateMsg(msg)
if err != nil {
r.Logger.Error("Invalid message", "peer", e.Src, "msg", msg, "err", err)
r.Logger.Error("Invalid message", "peer", e.Src, "msg", e.Message, "err", err)
r.Switch.StopPeerForError(e.Src, err)
return
}
switch e.ChannelID {
case SnapshotChannel:
switch msg := msg.(type) {
switch msg := e.Message.(type) {
case *ssproto.SnapshotsRequest:
snapshots, err := r.recentSnapshots(recentSnapshots)
if err != nil {
@@ -134,13 +128,13 @@ func (r *Reactor) Receive(e p2p.Envelope) {
"format", snapshot.Format, "peer", e.Src.ID())
e.Src.Send(p2p.Envelope{
ChannelID: e.ChannelID,
Message: mustWrapToProto(&ssproto.SnapshotsResponse{
Message: &ssproto.SnapshotsResponse{
Height: snapshot.Height,
Format: snapshot.Format,
Chunks: snapshot.Chunks,
Hash: snapshot.Hash,
Metadata: snapshot.Metadata,
}),
},
})
}
@@ -171,7 +165,7 @@ func (r *Reactor) Receive(e p2p.Envelope) {
}
case ChunkChannel:
switch msg := msg.(type) {
switch msg := e.Message.(type) {
case *ssproto.ChunkRequest:
r.Logger.Debug("Received chunk request", "height", msg.Height, "format", msg.Format,
"chunk", msg.Index, "peer", e.Src.ID())
@@ -189,13 +183,13 @@ func (r *Reactor) Receive(e p2p.Envelope) {
"chunk", msg.Index, "peer", e.Src.ID())
e.Src.Send(p2p.Envelope{
ChannelID: ChunkChannel,
Message: mustWrapToProto(&ssproto.ChunkResponse{
Message: &ssproto.ChunkResponse{
Height: msg.Height,
Format: msg.Format,
Index: msg.Index,
Chunk: resp.Chunk,
Missing: resp.Chunk == nil,
}),
},
})
case *ssproto.ChunkResponse:
@@ -280,7 +274,7 @@ func (r *Reactor) Sync(stateProvider StateProvider, discoveryTime time.Duration)
r.Switch.NewBroadcast(p2p.Envelope{
ChannelID: SnapshotChannel,
Message: mustWrapToProto(&ssproto.SnapshotsRequest{}),
Message: &ssproto.SnapshotsRequest{},
})
}

View File

@@ -65,7 +65,7 @@ func TestReactor_Receive_ChunkRequest(t *testing.T) {
require.NoError(t, err)
err = proto.Unmarshal(bz, e.Message)
require.NoError(t, err)
response = e.Message.(*ssproto.Message).GetChunkResponse()
response = e.Message.(*ssproto.ChunkResponse)
}).Return(true)
}
@@ -80,10 +80,10 @@ func TestReactor_Receive_ChunkRequest(t *testing.T) {
}
})
r.NewReceive(p2p.Envelope{
r.Receive(p2p.Envelope{
ChannelID: ChunkChannel,
Src: peer,
Message: mustWrapToProto(tc.request),
Message: tc.request,
})
time.Sleep(100 * time.Millisecond)
assert.Equal(t, tc.expectResponse, response)
@@ -155,7 +155,7 @@ func TestReactor_Receive_SnapshotsRequest(t *testing.T) {
require.NoError(t, err)
err = proto.Unmarshal(bz, e.Message)
require.NoError(t, err)
responses = append(responses, e.Message.(*ssproto.Message).GetSnapshotsResponse())
responses = append(responses, e.Message.(*ssproto.SnapshotsResponse))
}).Return(true)
}
@@ -170,12 +170,11 @@ func TestReactor_Receive_SnapshotsRequest(t *testing.T) {
}
})
r.NewReceive(p2p.Envelope{
r.Receive(p2p.Envelope{
ChannelID: SnapshotChannel,
Src: peer,
Message: mustWrapToProto(&ssproto.SnapshotsRequest{}),
Message: &ssproto.SnapshotsRequest{},
})
r.Receive(SnapshotChannel, peer, mustEncodeMsg(&ssproto.SnapshotsRequest{}))
time.Sleep(100 * time.Millisecond)
assert.Equal(t, tc.expectResponses, responses)

View File

@@ -128,7 +128,7 @@ func (s *syncer) AddPeer(peer p2p.Peer) {
s.logger.Debug("Requesting snapshots from peer", "peer", peer.ID())
e := p2p.Envelope{
ChannelID: SnapshotChannel,
Message: mustWrapToProto(&ssproto.SnapshotsRequest{}),
Message: &ssproto.SnapshotsRequest{},
}
peer.Send(e)
}
@@ -473,11 +473,11 @@ func (s *syncer) requestChunk(snapshot *snapshot, chunk uint32) {
"format", snapshot.Format, "chunk", chunk, "peer", peer.ID())
peer.Send(p2p.Envelope{
ChannelID: ChunkChannel,
Message: mustWrapToProto(&ssproto.ChunkRequest{
Message: &ssproto.ChunkRequest{
Height: snapshot.Height,
Format: snapshot.Format,
Index: chunk,
}),
},
})
}

View File

@@ -100,7 +100,10 @@ func TestSyncer_SyncAny(t *testing.T) {
peerA.On("ID").Return(p2p.ID("a"))
peerA.On("Send", mock.MatchedBy(func(i interface{}) bool {
e, ok := i.(p2p.Envelope)
req := e.Message.(*ssproto.Message).GetSnapshotsRequest()
if !ok {
return false
}
req, ok := e.Message.(*ssproto.SnapshotsRequest)
return ok && e.ChannelID == SnapshotChannel && req != nil
})).Return(true)
syncer.AddPeer(peerA)
@@ -110,7 +113,10 @@ func TestSyncer_SyncAny(t *testing.T) {
peerB.On("ID").Return(p2p.ID("b"))
peerB.On("Send", mock.MatchedBy(func(i interface{}) bool {
e, ok := i.(p2p.Envelope)
req := e.Message.(*ssproto.Message).GetSnapshotsRequest()
if !ok {
return false
}
req, ok := e.Message.(*ssproto.SnapshotsRequest)
return ok && e.ChannelID == SnapshotChannel && req != nil
})).Return(true)
syncer.AddPeer(peerB)
@@ -157,7 +163,7 @@ func TestSyncer_SyncAny(t *testing.T) {
onChunkRequest := func(args mock.Arguments) {
e, ok := args[0].(p2p.Envelope)
require.True(t, ok)
msg := e.Message.(*ssproto.Message).GetChunkRequest()
msg := e.Message.(*ssproto.ChunkRequest)
require.EqualValues(t, 1, msg.Height)
require.EqualValues(t, 1, msg.Format)
require.LessOrEqual(t, msg.Index, uint32(len(chunks)))