diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index ef2586f6a..10b297be8 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -18,6 +18,7 @@ Friendly reminder, we have a [bug bounty program](https://hackerone.com/tendermi - Go API - [abci/client, proxy] \#5673 `Async` funcs return an error, `Sync` and `Async` funcs accept `context.Context` (@melekes) - [p2p] Removed unused function `MakePoWTarget`. (@erikgrinaker) + - [libs/bits] \#5720 Validate `BitArray` in `FromProto`, which now returns an error (@melekes) - [libs/os] Kill() and {Must,}{Read,Write}File() functions have been removed. (@alessio) diff --git a/consensus/msgs.go b/consensus/msgs.go index 4de96b5f4..91c091ba1 100644 --- a/consensus/msgs.go +++ b/consensus/msgs.go @@ -167,11 +167,14 @@ func MsgFromProto(msg *tmcons.Message) (Message, error) { case *tmcons.Message_NewValidBlock: pbPartSetHeader, err := types.PartSetHeaderFromProto(&msg.NewValidBlock.BlockPartSetHeader) if err != nil { - return nil, fmt.Errorf("parts to proto error: %w", err) + return nil, fmt.Errorf("parts header to proto error: %w", err) } pbBits := new(bits.BitArray) - pbBits.FromProto(msg.NewValidBlock.BlockParts) + err = pbBits.FromProto(msg.NewValidBlock.BlockParts) + if err != nil { + return nil, fmt.Errorf("parts to proto error: %w", err) + } pb = &NewValidBlockMessage{ Height: msg.NewValidBlock.Height, @@ -191,7 +194,10 @@ func MsgFromProto(msg *tmcons.Message) (Message, error) { } case *tmcons.Message_ProposalPol: pbBits := new(bits.BitArray) - pbBits.FromProto(&msg.ProposalPol.ProposalPol) + err := pbBits.FromProto(&msg.ProposalPol.ProposalPol) + if err != nil { + return nil, fmt.Errorf("proposal PoL to proto error: %w", err) + } pb = &ProposalPOLMessage{ Height: msg.ProposalPol.Height, ProposalPOLRound: msg.ProposalPol.ProposalPolRound, @@ -237,10 +243,13 @@ func MsgFromProto(msg *tmcons.Message) (Message, error) { case *tmcons.Message_VoteSetBits: bi, err := types.BlockIDFromProto(&msg.VoteSetBits.BlockID) if err != nil { - return nil, fmt.Errorf("voteSetBits msg to proto error: %w", err) + return nil, fmt.Errorf("block ID to proto error: %w", err) } bits := new(bits.BitArray) - bits.FromProto(&msg.VoteSetBits.Votes) + err = bits.FromProto(&msg.VoteSetBits.Votes) + if err != nil { + return nil, fmt.Errorf("votes to proto error: %w", err) + } pb = &VoteSetBitsMessage{ Height: msg.VoteSetBits.Height, diff --git a/libs/bits/bit_array.go b/libs/bits/bit_array.go index 9d6901460..1a41d87f9 100644 --- a/libs/bits/bit_array.go +++ b/libs/bits/bit_array.go @@ -2,7 +2,9 @@ package bits import ( "encoding/binary" + "errors" "fmt" + "math" "regexp" "strings" "sync" @@ -27,7 +29,7 @@ func NewBitArray(bits int) *BitArray { } return &BitArray{ Bits: bits, - Elems: make([]uint64, (bits+63)/64), + Elems: make([]uint64, numElems(bits)), } } @@ -100,7 +102,7 @@ func (bA *BitArray) copy() *BitArray { } func (bA *BitArray) copyBits(bits int) *BitArray { - c := make([]uint64, (bits+63)/64) + c := make([]uint64, numElems(bits)) copy(c, bA.Elems) return &BitArray{ Bits: bits, @@ -418,27 +420,45 @@ func (bA *BitArray) UnmarshalJSON(bz []byte) error { return nil } -// ToProto converts BitArray to protobuf +// ToProto converts BitArray to protobuf. It returns nil if BitArray is +// nil/empty. +// +// XXX: It does not copy the array. func (bA *BitArray) ToProto() *tmprotobits.BitArray { - if bA == nil || len(bA.Elems) == 0 { + if bA == nil || + (len(bA.Elems) == 0 && bA.Bits == 0) { // empty return nil } - return &tmprotobits.BitArray{ - Bits: int64(bA.Bits), - Elems: bA.Elems, - } + return &tmprotobits.BitArray{Bits: int64(bA.Bits), Elems: bA.Elems} } -// FromProto sets a protobuf BitArray to the given pointer. -func (bA *BitArray) FromProto(protoBitArray *tmprotobits.BitArray) { +// FromProto sets BitArray to the given protoBitArray. It returns an error if +// protoBitArray is invalid. +// +// XXX: It does not copy the array. +func (bA *BitArray) FromProto(protoBitArray *tmprotobits.BitArray) error { if protoBitArray == nil { - bA = nil - return + return nil + } + + // Validate protoBitArray. + if protoBitArray.Bits < 0 { + return errors.New("negative Bits") + } + // #[32bit] + if protoBitArray.Bits > math.MaxInt32 { // prevent overflow on 32bit systems + return errors.New("too many Bits") + } + if got, exp := len(protoBitArray.Elems), numElems(int(protoBitArray.Bits)); got != exp { + return fmt.Errorf("invalid number of Elems: got %d, but exp %d", got, exp) } bA.Bits = int(protoBitArray.Bits) - if len(protoBitArray.Elems) > 0 { - bA.Elems = protoBitArray.Elems - } + bA.Elems = protoBitArray.Elems + return nil +} + +func numElems(bits int) int { + return (bits + 63) / 64 } diff --git a/libs/bits/bit_array_test.go b/libs/bits/bit_array_test.go index e4306ecf2..10d607ef2 100644 --- a/libs/bits/bit_array_test.go +++ b/libs/bits/bit_array_test.go @@ -4,12 +4,14 @@ import ( "bytes" "encoding/json" "fmt" + "math" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" tmrand "github.com/tendermint/tendermint/libs/rand" + tmprotobits "github.com/tendermint/tendermint/proto/tendermint/libs/bits" ) func randBitArray(bits int) (*BitArray, []byte) { @@ -266,7 +268,7 @@ func TestJSONMarshalUnmarshal(t *testing.T) { } } -func TestBitArrayProtoBuf(t *testing.T) { +func TestBitArrayToFromProto(t *testing.T) { testCases := []struct { msg string bA1 *BitArray @@ -280,11 +282,41 @@ func TestBitArrayProtoBuf(t *testing.T) { for _, tc := range testCases { protoBA := tc.bA1.ToProto() ba := new(BitArray) - ba.FromProto(protoBA) + err := ba.FromProto(protoBA) if tc.expPass { + assert.NoError(t, err) require.Equal(t, tc.bA1, ba, tc.msg) } else { require.NotEqual(t, tc.bA1, ba, tc.msg) } } } + +func TestBitArrayFromProto(t *testing.T) { + testCases := []struct { + pbA *tmprotobits.BitArray + resA *BitArray + expErr bool + }{ + 0: {nil, &BitArray{}, false}, + 1: {&tmprotobits.BitArray{}, &BitArray{}, false}, + + 2: {&tmprotobits.BitArray{Bits: 1, Elems: make([]uint64, 1)}, &BitArray{Bits: 1, Elems: make([]uint64, 1)}, false}, + + 3: {&tmprotobits.BitArray{Bits: -1, Elems: make([]uint64, 1)}, &BitArray{}, true}, + 4: {&tmprotobits.BitArray{Bits: math.MaxInt32 + 1, Elems: make([]uint64, 1)}, &BitArray{}, true}, + 5: {&tmprotobits.BitArray{Bits: 1, Elems: make([]uint64, 2)}, &BitArray{}, true}, + } + + for i, tc := range testCases { + bA := new(BitArray) + err := bA.FromProto(tc.pbA) + if tc.expErr { + assert.Error(t, err, "#%d", i) + assert.Equal(t, tc.resA, bA, "#%d", i) + } else { + assert.NoError(t, err, "#%d", i) + assert.Equal(t, tc.resA, bA, "#%d", i) + } + } +} diff --git a/test/maverick/consensus/msgs.go b/test/maverick/consensus/msgs.go index 4de96b5f4..a1ac7c1a4 100644 --- a/test/maverick/consensus/msgs.go +++ b/test/maverick/consensus/msgs.go @@ -167,11 +167,14 @@ func MsgFromProto(msg *tmcons.Message) (Message, error) { case *tmcons.Message_NewValidBlock: pbPartSetHeader, err := types.PartSetHeaderFromProto(&msg.NewValidBlock.BlockPartSetHeader) if err != nil { - return nil, fmt.Errorf("parts to proto error: %w", err) + return nil, fmt.Errorf("parts header to proto error: %w", err) } pbBits := new(bits.BitArray) - pbBits.FromProto(msg.NewValidBlock.BlockParts) + err = pbBits.FromProto(msg.NewValidBlock.BlockParts) + if err != nil { + return nil, fmt.Errorf("parts to proto error: %w", err) + } pb = &NewValidBlockMessage{ Height: msg.NewValidBlock.Height, @@ -191,7 +194,11 @@ func MsgFromProto(msg *tmcons.Message) (Message, error) { } case *tmcons.Message_ProposalPol: pbBits := new(bits.BitArray) - pbBits.FromProto(&msg.ProposalPol.ProposalPol) + err := pbBits.FromProto(&msg.ProposalPol.ProposalPol) + if err != nil { + return nil, fmt.Errorf("proposal PoL to proto error: %w", err) + } + pb = &ProposalPOLMessage{ Height: msg.ProposalPol.Height, ProposalPOLRound: msg.ProposalPol.ProposalPolRound, @@ -237,10 +244,13 @@ func MsgFromProto(msg *tmcons.Message) (Message, error) { case *tmcons.Message_VoteSetBits: bi, err := types.BlockIDFromProto(&msg.VoteSetBits.BlockID) if err != nil { - return nil, fmt.Errorf("voteSetBits msg to proto error: %w", err) + return nil, fmt.Errorf("block ID msg to proto error: %w", err) } bits := new(bits.BitArray) - bits.FromProto(&msg.VoteSetBits.Votes) + err = bits.FromProto(&msg.VoteSetBits.Votes) + if err != nil { + return nil, fmt.Errorf("votes to proto error: %w", err) + } pb = &VoteSetBitsMessage{ Height: msg.VoteSetBits.Height,