blockchain v0: refactor message logic

This commit is contained in:
Aleksandr Bezobchuk
2021-01-04 16:32:40 -05:00
parent b48a7f4d7e
commit c111125b9f
4 changed files with 126 additions and 115 deletions

View File

@@ -1,12 +1,6 @@
package blockchain
import (
"errors"
"fmt"
"github.com/gogo/protobuf/proto"
bcproto "github.com/tendermint/tendermint/proto/tendermint/blockchain"
"github.com/tendermint/tendermint/types"
)
@@ -18,91 +12,3 @@ const (
BlockResponseMessagePrefixSize +
BlockResponseMessageFieldKeySize
)
// EncodeMsg encodes a Protobuf message
func EncodeMsg(pb proto.Message) ([]byte, error) {
msg := bcproto.Message{}
switch pb := pb.(type) {
case *bcproto.BlockRequest:
msg.Sum = &bcproto.Message_BlockRequest{BlockRequest: pb}
case *bcproto.BlockResponse:
msg.Sum = &bcproto.Message_BlockResponse{BlockResponse: pb}
case *bcproto.NoBlockResponse:
msg.Sum = &bcproto.Message_NoBlockResponse{NoBlockResponse: pb}
case *bcproto.StatusRequest:
msg.Sum = &bcproto.Message_StatusRequest{StatusRequest: pb}
case *bcproto.StatusResponse:
msg.Sum = &bcproto.Message_StatusResponse{StatusResponse: pb}
default:
return nil, fmt.Errorf("unknown message type %T", pb)
}
bz, err := proto.Marshal(&msg)
if err != nil {
return nil, fmt.Errorf("unable to marshal %T: %w", pb, err)
}
return bz, nil
}
// DecodeMsg decodes a Protobuf message.
func DecodeMsg(bz []byte) (proto.Message, error) {
pb := &bcproto.Message{}
err := proto.Unmarshal(bz, pb)
if err != nil {
return nil, err
}
switch msg := pb.Sum.(type) {
case *bcproto.Message_BlockRequest:
return msg.BlockRequest, nil
case *bcproto.Message_BlockResponse:
return msg.BlockResponse, nil
case *bcproto.Message_NoBlockResponse:
return msg.NoBlockResponse, nil
case *bcproto.Message_StatusRequest:
return msg.StatusRequest, nil
case *bcproto.Message_StatusResponse:
return msg.StatusResponse, nil
default:
return nil, fmt.Errorf("unknown message type %T", msg)
}
}
// ValidateMsg validates a message.
func ValidateMsg(pb proto.Message) error {
if pb == nil {
return errors.New("message cannot be nil")
}
switch msg := pb.(type) {
case *bcproto.BlockRequest:
if msg.Height < 0 {
return errors.New("negative Height")
}
case *bcproto.BlockResponse:
// validate basic is called later when converting from proto
return nil
case *bcproto.NoBlockResponse:
if msg.Height < 0 {
return errors.New("negative Height")
}
case *bcproto.StatusResponse:
if msg.Base < 0 {
return errors.New("negative Base")
}
if msg.Height < 0 {
return errors.New("negative Height")
}
if msg.Base > msg.Height {
return fmt.Errorf("base %v cannot be greater than height %v", msg.Base, msg.Height)
}
case *bcproto.StatusRequest:
return nil
default:
return fmt.Errorf("unknown message type %T", msg)
}
return nil
}

View File

@@ -0,0 +1,99 @@
package blockchain
import (
"errors"
fmt "fmt"
proto "github.com/gogo/protobuf/proto"
)
// Wrap implements the p2p Wrapper interface and wraps a blockchain messages.
func (m *Message) Wrap(pb proto.Message) error {
switch msg := pb.(type) {
case *BlockRequest:
m.Sum = &Message_BlockRequest{BlockRequest: msg}
case *BlockResponse:
m.Sum = &Message_BlockResponse{BlockResponse: msg}
case *NoBlockResponse:
m.Sum = &Message_NoBlockResponse{NoBlockResponse: msg}
case *StatusRequest:
m.Sum = &Message_StatusRequest{StatusRequest: msg}
case *StatusResponse:
m.Sum = &Message_StatusResponse{StatusResponse: msg}
default:
return fmt.Errorf("unknown message: %T", msg)
}
return nil
}
// Unwrap implements the p2p Wrapper interface and unwraps a wrapped blockchain
// message.
func (m *Message) Unwrap() (proto.Message, error) {
switch msg := m.Sum.(type) {
case *Message_BlockRequest:
return m.GetBlockRequest(), nil
case *Message_BlockResponse:
return m.GetBlockResponse(), nil
case *Message_NoBlockResponse:
return m.GetNoBlockResponse(), nil
case *Message_StatusRequest:
return m.GetStatusRequest(), nil
case *Message_StatusResponse:
return m.GetStatusResponse(), nil
default:
return nil, fmt.Errorf("unknown message: %T", msg)
}
}
// Validate validates the message returning an error upon failure.
func (m *Message) Validate() error {
if m == nil {
return errors.New("message cannot be nil")
}
switch msg := m.Sum.(type) {
case *Message_BlockRequest:
if m.GetBlockRequest().Height < 0 {
return errors.New("negative Height")
}
case *Message_BlockResponse:
// validate basic is called later when converting from proto
return nil
case *Message_NoBlockResponse:
if m.GetNoBlockResponse().Height < 0 {
return errors.New("negative Height")
}
case *Message_StatusResponse:
if m.GetStatusResponse().Base < 0 {
return errors.New("negative Base")
}
if m.GetStatusResponse().Height < 0 {
return errors.New("negative Height")
}
if m.GetStatusResponse().Base > m.GetStatusResponse().Height {
return fmt.Errorf("base %v cannot be greater than height %v", m.GetStatusResponse().Base, m.GetStatusResponse().Height)
}
case *Message_StatusRequest:
return nil
default:
return fmt.Errorf("unknown message type: %T", msg)
}
return nil
}

View File

@@ -1,19 +1,18 @@
package blockchain
package blockchain_test
import (
"encoding/hex"
"math"
math "math"
"testing"
"github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/assert"
proto "github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/require"
bcproto "github.com/tendermint/tendermint/proto/tendermint/blockchain"
"github.com/tendermint/tendermint/types"
)
func TestBcBlockRequestMessageValidateBasic(t *testing.T) {
func TestBlockRequest_Validate(t *testing.T) {
testCases := []struct {
testName string
requestHeight int64
@@ -27,13 +26,15 @@ func TestBcBlockRequestMessageValidateBasic(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(tc.testName, func(t *testing.T) {
request := bcproto.BlockRequest{Height: tc.requestHeight}
assert.Equal(t, tc.expectErr, ValidateMsg(&request) != nil, "Validate Basic had an unexpected result")
msg := &bcproto.Message{}
require.NoError(t, msg.Wrap(&bcproto.BlockRequest{Height: tc.requestHeight}))
require.Equal(t, tc.expectErr, msg.Validate() != nil)
})
}
}
func TestBcNoBlockResponseMessageValidateBasic(t *testing.T) {
func TestNoBlockResponse_Validate(t *testing.T) {
testCases := []struct {
testName string
nonResponseHeight int64
@@ -47,18 +48,21 @@ func TestBcNoBlockResponseMessageValidateBasic(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(tc.testName, func(t *testing.T) {
nonResponse := bcproto.NoBlockResponse{Height: tc.nonResponseHeight}
assert.Equal(t, tc.expectErr, ValidateMsg(&nonResponse) != nil, "Validate Basic had an unexpected result")
msg := &bcproto.Message{}
require.NoError(t, msg.Wrap(&bcproto.NoBlockResponse{Height: tc.nonResponseHeight}))
require.Equal(t, tc.expectErr, msg.Validate() != nil)
})
}
}
func TestBcStatusRequestMessageValidateBasic(t *testing.T) {
request := bcproto.StatusRequest{}
assert.NoError(t, ValidateMsg(&request))
func TestStatusRequest_Validate(t *testing.T) {
msg := &bcproto.Message{}
require.NoError(t, msg.Wrap(&bcproto.StatusRequest{}))
require.NoError(t, msg.Validate())
}
func TestBcStatusResponseMessageValidateBasic(t *testing.T) {
func TestStatusResponse_Validate(t *testing.T) {
testCases := []struct {
testName string
responseHeight int64
@@ -72,13 +76,15 @@ func TestBcStatusResponseMessageValidateBasic(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(tc.testName, func(t *testing.T) {
response := bcproto.StatusResponse{Height: tc.responseHeight}
assert.Equal(t, tc.expectErr, ValidateMsg(&response) != nil, "Validate Basic had an unexpected result")
msg := &bcproto.Message{}
require.NoError(t, msg.Wrap(&bcproto.StatusResponse{Height: tc.responseHeight}))
require.Equal(t, tc.expectErr, msg.Validate() != nil)
})
}
}
// nolint:lll // ignore line length in tests
// nolint:lll
func TestBlockchainMessageVectors(t *testing.T) {
block := types.MakeBlock(int64(3), []types.Tx{types.Tx("Hello World")}, nil, nil)
block.Version.Block = 11 // overwrite updated protocol version
@@ -117,8 +123,8 @@ func TestBlockchainMessageVectors(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(tc.testName, func(t *testing.T) {
bz, _ := proto.Marshal(tc.bmsg)
bz, err := proto.Marshal(tc.bmsg)
require.NoError(t, err)
require.Equal(t, tc.expBytes, hex.EncodeToString(bz))
})
}

View File

@@ -8,8 +8,8 @@ import (
)
// Wrap implements the p2p Wrapper interface and wraps a state sync messages.
func (m *Message) Wrap(msg proto.Message) error {
switch msg := msg.(type) {
func (m *Message) Wrap(pb proto.Message) error {
switch msg := pb.(type) {
case *ChunkRequest:
m.Sum = &Message_ChunkRequest{ChunkRequest: msg}