diff --git a/blockchain/msgs.go b/blockchain/msgs.go index cd5ef977f..aa2c8988d 100644 --- a/blockchain/msgs.go +++ b/blockchain/msgs.go @@ -1,12 +1,6 @@ package blockchain import ( - "errors" - "fmt" - - "github.com/gogo/protobuf/proto" - - bcproto "github.com/tendermint/tendermint/proto/tendermint/blockchain" "github.com/tendermint/tendermint/types" ) @@ -18,91 +12,3 @@ const ( BlockResponseMessagePrefixSize + BlockResponseMessageFieldKeySize ) - -// EncodeMsg encodes a Protobuf message -func EncodeMsg(pb proto.Message) ([]byte, error) { - msg := bcproto.Message{} - - switch pb := pb.(type) { - case *bcproto.BlockRequest: - msg.Sum = &bcproto.Message_BlockRequest{BlockRequest: pb} - case *bcproto.BlockResponse: - msg.Sum = &bcproto.Message_BlockResponse{BlockResponse: pb} - case *bcproto.NoBlockResponse: - msg.Sum = &bcproto.Message_NoBlockResponse{NoBlockResponse: pb} - case *bcproto.StatusRequest: - msg.Sum = &bcproto.Message_StatusRequest{StatusRequest: pb} - case *bcproto.StatusResponse: - msg.Sum = &bcproto.Message_StatusResponse{StatusResponse: pb} - default: - return nil, fmt.Errorf("unknown message type %T", pb) - } - - bz, err := proto.Marshal(&msg) - if err != nil { - return nil, fmt.Errorf("unable to marshal %T: %w", pb, err) - } - - return bz, nil -} - -// DecodeMsg decodes a Protobuf message. -func DecodeMsg(bz []byte) (proto.Message, error) { - pb := &bcproto.Message{} - - err := proto.Unmarshal(bz, pb) - if err != nil { - return nil, err - } - - switch msg := pb.Sum.(type) { - case *bcproto.Message_BlockRequest: - return msg.BlockRequest, nil - case *bcproto.Message_BlockResponse: - return msg.BlockResponse, nil - case *bcproto.Message_NoBlockResponse: - return msg.NoBlockResponse, nil - case *bcproto.Message_StatusRequest: - return msg.StatusRequest, nil - case *bcproto.Message_StatusResponse: - return msg.StatusResponse, nil - default: - return nil, fmt.Errorf("unknown message type %T", msg) - } -} - -// ValidateMsg validates a message. -func ValidateMsg(pb proto.Message) error { - if pb == nil { - return errors.New("message cannot be nil") - } - - switch msg := pb.(type) { - case *bcproto.BlockRequest: - if msg.Height < 0 { - return errors.New("negative Height") - } - case *bcproto.BlockResponse: - // validate basic is called later when converting from proto - return nil - case *bcproto.NoBlockResponse: - if msg.Height < 0 { - return errors.New("negative Height") - } - case *bcproto.StatusResponse: - if msg.Base < 0 { - return errors.New("negative Base") - } - if msg.Height < 0 { - return errors.New("negative Height") - } - if msg.Base > msg.Height { - return fmt.Errorf("base %v cannot be greater than height %v", msg.Base, msg.Height) - } - case *bcproto.StatusRequest: - return nil - default: - return fmt.Errorf("unknown message type %T", msg) - } - return nil -} diff --git a/proto/tendermint/blockchain/message.go b/proto/tendermint/blockchain/message.go new file mode 100644 index 000000000..2832999d0 --- /dev/null +++ b/proto/tendermint/blockchain/message.go @@ -0,0 +1,99 @@ +package blockchain + +import ( + "errors" + fmt "fmt" + + proto "github.com/gogo/protobuf/proto" +) + +// Wrap implements the p2p Wrapper interface and wraps a blockchain messages. +func (m *Message) Wrap(pb proto.Message) error { + switch msg := pb.(type) { + case *BlockRequest: + m.Sum = &Message_BlockRequest{BlockRequest: msg} + + case *BlockResponse: + m.Sum = &Message_BlockResponse{BlockResponse: msg} + + case *NoBlockResponse: + m.Sum = &Message_NoBlockResponse{NoBlockResponse: msg} + + case *StatusRequest: + m.Sum = &Message_StatusRequest{StatusRequest: msg} + + case *StatusResponse: + m.Sum = &Message_StatusResponse{StatusResponse: msg} + + default: + return fmt.Errorf("unknown message: %T", msg) + } + + return nil +} + +// Unwrap implements the p2p Wrapper interface and unwraps a wrapped blockchain +// message. +func (m *Message) Unwrap() (proto.Message, error) { + switch msg := m.Sum.(type) { + case *Message_BlockRequest: + return m.GetBlockRequest(), nil + + case *Message_BlockResponse: + return m.GetBlockResponse(), nil + + case *Message_NoBlockResponse: + return m.GetNoBlockResponse(), nil + + case *Message_StatusRequest: + return m.GetStatusRequest(), nil + + case *Message_StatusResponse: + return m.GetStatusResponse(), nil + + default: + return nil, fmt.Errorf("unknown message: %T", msg) + } +} + +// Validate validates the message returning an error upon failure. +func (m *Message) Validate() error { + if m == nil { + return errors.New("message cannot be nil") + } + + switch msg := m.Sum.(type) { + case *Message_BlockRequest: + if m.GetBlockRequest().Height < 0 { + return errors.New("negative Height") + } + + case *Message_BlockResponse: + // validate basic is called later when converting from proto + return nil + + case *Message_NoBlockResponse: + if m.GetNoBlockResponse().Height < 0 { + return errors.New("negative Height") + } + + case *Message_StatusResponse: + if m.GetStatusResponse().Base < 0 { + return errors.New("negative Base") + } + if m.GetStatusResponse().Height < 0 { + return errors.New("negative Height") + } + if m.GetStatusResponse().Base > m.GetStatusResponse().Height { + return fmt.Errorf("base %v cannot be greater than height %v", m.GetStatusResponse().Base, m.GetStatusResponse().Height) + } + + case *Message_StatusRequest: + return nil + + default: + return fmt.Errorf("unknown message type: %T", msg) + } + + return nil +} diff --git a/blockchain/msgs_test.go b/proto/tendermint/blockchain/message_test.go similarity index 76% rename from blockchain/msgs_test.go rename to proto/tendermint/blockchain/message_test.go index df8efca14..37a0df217 100644 --- a/blockchain/msgs_test.go +++ b/proto/tendermint/blockchain/message_test.go @@ -1,19 +1,18 @@ -package blockchain +package blockchain_test import ( "encoding/hex" - "math" + math "math" "testing" - "github.com/gogo/protobuf/proto" - "github.com/stretchr/testify/assert" + proto "github.com/gogo/protobuf/proto" "github.com/stretchr/testify/require" bcproto "github.com/tendermint/tendermint/proto/tendermint/blockchain" "github.com/tendermint/tendermint/types" ) -func TestBcBlockRequestMessageValidateBasic(t *testing.T) { +func TestBlockRequest_Validate(t *testing.T) { testCases := []struct { testName string requestHeight int64 @@ -27,13 +26,15 @@ func TestBcBlockRequestMessageValidateBasic(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.testName, func(t *testing.T) { - request := bcproto.BlockRequest{Height: tc.requestHeight} - assert.Equal(t, tc.expectErr, ValidateMsg(&request) != nil, "Validate Basic had an unexpected result") + msg := &bcproto.Message{} + require.NoError(t, msg.Wrap(&bcproto.BlockRequest{Height: tc.requestHeight})) + + require.Equal(t, tc.expectErr, msg.Validate() != nil) }) } } -func TestBcNoBlockResponseMessageValidateBasic(t *testing.T) { +func TestNoBlockResponse_Validate(t *testing.T) { testCases := []struct { testName string nonResponseHeight int64 @@ -47,18 +48,21 @@ func TestBcNoBlockResponseMessageValidateBasic(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.testName, func(t *testing.T) { - nonResponse := bcproto.NoBlockResponse{Height: tc.nonResponseHeight} - assert.Equal(t, tc.expectErr, ValidateMsg(&nonResponse) != nil, "Validate Basic had an unexpected result") + msg := &bcproto.Message{} + require.NoError(t, msg.Wrap(&bcproto.NoBlockResponse{Height: tc.nonResponseHeight})) + + require.Equal(t, tc.expectErr, msg.Validate() != nil) }) } } -func TestBcStatusRequestMessageValidateBasic(t *testing.T) { - request := bcproto.StatusRequest{} - assert.NoError(t, ValidateMsg(&request)) +func TestStatusRequest_Validate(t *testing.T) { + msg := &bcproto.Message{} + require.NoError(t, msg.Wrap(&bcproto.StatusRequest{})) + require.NoError(t, msg.Validate()) } -func TestBcStatusResponseMessageValidateBasic(t *testing.T) { +func TestStatusResponse_Validate(t *testing.T) { testCases := []struct { testName string responseHeight int64 @@ -72,13 +76,15 @@ func TestBcStatusResponseMessageValidateBasic(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.testName, func(t *testing.T) { - response := bcproto.StatusResponse{Height: tc.responseHeight} - assert.Equal(t, tc.expectErr, ValidateMsg(&response) != nil, "Validate Basic had an unexpected result") + msg := &bcproto.Message{} + require.NoError(t, msg.Wrap(&bcproto.StatusResponse{Height: tc.responseHeight})) + + require.Equal(t, tc.expectErr, msg.Validate() != nil) }) } } -// nolint:lll // ignore line length in tests +// nolint:lll func TestBlockchainMessageVectors(t *testing.T) { block := types.MakeBlock(int64(3), []types.Tx{types.Tx("Hello World")}, nil, nil) block.Version.Block = 11 // overwrite updated protocol version @@ -117,8 +123,8 @@ func TestBlockchainMessageVectors(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.testName, func(t *testing.T) { - bz, _ := proto.Marshal(tc.bmsg) - + bz, err := proto.Marshal(tc.bmsg) + require.NoError(t, err) require.Equal(t, tc.expBytes, hex.EncodeToString(bz)) }) } diff --git a/proto/tendermint/statesync/message.go b/proto/tendermint/statesync/message.go index 792e7f64c..fe38bda51 100644 --- a/proto/tendermint/statesync/message.go +++ b/proto/tendermint/statesync/message.go @@ -8,8 +8,8 @@ import ( ) // Wrap implements the p2p Wrapper interface and wraps a state sync messages. -func (m *Message) Wrap(msg proto.Message) error { - switch msg := msg.(type) { +func (m *Message) Wrap(pb proto.Message) error { + switch msg := pb.(type) { case *ChunkRequest: m.Sum = &Message_ChunkRequest{ChunkRequest: msg}