libs/bits: validate BitArray in FromProto (#5720)

Closes #5705
This commit is contained in:
Anton Kaliaev
2020-12-01 16:44:56 +04:00
committed by GitHub
parent 141d9c814d
commit b1bbd37519
5 changed files with 99 additions and 27 deletions

View File

@@ -2,7 +2,9 @@ package bits
import (
"encoding/binary"
"errors"
"fmt"
"math"
"regexp"
"strings"
"sync"
@@ -27,7 +29,7 @@ func NewBitArray(bits int) *BitArray {
}
return &BitArray{
Bits: bits,
Elems: make([]uint64, (bits+63)/64),
Elems: make([]uint64, numElems(bits)),
}
}
@@ -100,7 +102,7 @@ func (bA *BitArray) copy() *BitArray {
}
func (bA *BitArray) copyBits(bits int) *BitArray {
c := make([]uint64, (bits+63)/64)
c := make([]uint64, numElems(bits))
copy(c, bA.Elems)
return &BitArray{
Bits: bits,
@@ -418,27 +420,45 @@ func (bA *BitArray) UnmarshalJSON(bz []byte) error {
return nil
}
// ToProto converts BitArray to protobuf
// ToProto converts BitArray to protobuf. It returns nil if BitArray is
// nil/empty.
//
// XXX: It does not copy the array.
func (bA *BitArray) ToProto() *tmprotobits.BitArray {
if bA == nil || len(bA.Elems) == 0 {
if bA == nil ||
(len(bA.Elems) == 0 && bA.Bits == 0) { // empty
return nil
}
return &tmprotobits.BitArray{
Bits: int64(bA.Bits),
Elems: bA.Elems,
}
return &tmprotobits.BitArray{Bits: int64(bA.Bits), Elems: bA.Elems}
}
// FromProto sets a protobuf BitArray to the given pointer.
func (bA *BitArray) FromProto(protoBitArray *tmprotobits.BitArray) {
// FromProto sets BitArray to the given protoBitArray. It returns an error if
// protoBitArray is invalid.
//
// XXX: It does not copy the array.
func (bA *BitArray) FromProto(protoBitArray *tmprotobits.BitArray) error {
if protoBitArray == nil {
bA = nil
return
return nil
}
// Validate protoBitArray.
if protoBitArray.Bits < 0 {
return errors.New("negative Bits")
}
// #[32bit]
if protoBitArray.Bits > math.MaxInt32 { // prevent overflow on 32bit systems
return errors.New("too many Bits")
}
if got, exp := len(protoBitArray.Elems), numElems(int(protoBitArray.Bits)); got != exp {
return fmt.Errorf("invalid number of Elems: got %d, but exp %d", got, exp)
}
bA.Bits = int(protoBitArray.Bits)
if len(protoBitArray.Elems) > 0 {
bA.Elems = protoBitArray.Elems
}
bA.Elems = protoBitArray.Elems
return nil
}
func numElems(bits int) int {
return (bits + 63) / 64
}