From bd266494554dc376cb7572fb0d1794eca14b9d34 Mon Sep 17 00:00:00 2001 From: William Banfield Date: Mon, 16 May 2022 11:24:43 -0400 Subject: [PATCH] change voteset constructor to strict vs non strict --- internal/consensus/reactor_test.go | 7 +++- internal/consensus/state.go | 7 +++- internal/consensus/types/height_vote_set.go | 9 +++-- internal/evidence/verify_test.go | 12 +++---- internal/statesync/reactor_test.go | 2 +- node/node_test.go | 2 +- test/e2e/runner/evidence.go | 2 +- types/block.go | 23 +++++++++--- types/block_test.go | 14 ++++++-- types/vote_set.go | 40 +++++++++++++-------- types/vote_set_test.go | 11 ++++-- 11 files changed, 93 insertions(+), 36 deletions(-) diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index a473f0c28..f300a0719 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -667,7 +667,12 @@ func TestSwitchToConsensusVoteExtensions(t *testing.T) { blockParts, err := propBlock.MakePartSet(types.BlockPartSizeBytes) require.NoError(t, err) - voteSet := types.NewVoteSet(cs.state.ChainID, testCase.storedHeight, 0, tmproto.PrecommitType, cs.state.Validators, testCase.includeExtensions) + var voteSet *types.VoteSet + if testCase.includeExtensions { + voteSet = types.NewStrictVoteSet(cs.state.ChainID, testCase.storedHeight, 0, tmproto.PrecommitType, cs.state.Validators) + } else { + voteSet = types.NewVoteSet(cs.state.ChainID, testCase.storedHeight, 0, tmproto.PrecommitType, cs.state.Validators) + } signedVote := signVote(ctx, t, validator, tmproto.PrecommitType, cs.state.ChainID, types.BlockID{ Hash: propBlock.Hash(), PartSetHeader: blockParts.Header(), diff --git a/internal/consensus/state.go b/internal/consensus/state.go index 407724cc0..ed04deb09 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -718,7 +718,12 @@ func (cs *State) votesFromExtendedCommit(state sm.State, requireExtensions bool) if ec == nil { return nil, fmt.Errorf("commit for height %v not found", state.LastBlockHeight) } - vs := ec.ToVoteSet(state.ChainID, state.LastValidators, requireExtensions) + var vs *types.VoteSet + if requireExtensions { + vs = ec.ToStrictVoteSet(state.ChainID, state.LastValidators) + } else { + vs = ec.ToVoteSet(state.ChainID, state.LastValidators) + } if !vs.HasTwoThirdsMajority() { return nil, errors.New("extended commit does not have +2/3 majority") } diff --git a/internal/consensus/types/height_vote_set.go b/internal/consensus/types/height_vote_set.go index 0c432e80d..854d9b1f1 100644 --- a/internal/consensus/types/height_vote_set.go +++ b/internal/consensus/types/height_vote_set.go @@ -109,8 +109,13 @@ func (hvs *HeightVoteSet) addRound(round int32) { panic("addRound() for an existing round") } // log.Debug("addRound(round)", "round", round) - prevotes := types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrevoteType, hvs.valSet, hvs.requireExtensions) - precommits := types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrecommitType, hvs.valSet, hvs.requireExtensions) + prevotes := types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrevoteType, hvs.valSet) + var precommits *types.VoteSet + if hvs.requireExtensions { + precommits = types.NewStrictVoteSet(hvs.chainID, hvs.height, round, tmproto.PrecommitType, hvs.valSet) + } else { + precommits = types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrecommitType, hvs.valSet) + } hvs.roundVoteSets[round] = RoundVoteSet{ Prevotes: prevotes, Precommits: precommits, diff --git a/internal/evidence/verify_test.go b/internal/evidence/verify_test.go index c811125e9..2ed84fa69 100644 --- a/internal/evidence/verify_test.go +++ b/internal/evidence/verify_test.go @@ -233,7 +233,7 @@ func TestVerifyLightClientAttack_Equivocation(t *testing.T) { // we are simulating a duplicate vote attack where all the validators in the conflictingVals set // except the last validator vote twice blockID := factory.MakeBlockIDWithHash(conflictingHeader.Hash()) - voteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals, false) + voteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) extCommit, err := factory.MakeExtendedCommit(ctx, blockID, 10, 1, voteSet, conflictingPrivVals[:4], defaultEvidenceTime) require.NoError(t, err) commit := extCommit.StripExtensions() @@ -253,7 +253,7 @@ func TestVerifyLightClientAttack_Equivocation(t *testing.T) { } trustedBlockID := makeBlockID(trustedHeader.Hash(), 1000, []byte("partshash")) - trustedVoteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals, false) + trustedVoteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) trustedExtCommit, err := factory.MakeExtendedCommit(ctx, trustedBlockID, 10, 1, trustedVoteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) @@ -336,7 +336,7 @@ func TestVerifyLightClientAttack_Amnesia(t *testing.T) { // we are simulating an amnesia attack where all the validators in the conflictingVals set // except the last validator vote twice. However this time the commits are of different rounds. blockID := makeBlockID(conflictingHeader.Hash(), 1000, []byte("partshash")) - voteSet := types.NewVoteSet(evidenceChainID, height, 0, tmproto.SignedMsgType(2), conflictingVals, false) + voteSet := types.NewVoteSet(evidenceChainID, height, 0, tmproto.SignedMsgType(2), conflictingVals) extCommit, err := factory.MakeExtendedCommit(ctx, blockID, height, 0, voteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) commit := extCommit.StripExtensions() @@ -356,7 +356,7 @@ func TestVerifyLightClientAttack_Amnesia(t *testing.T) { } trustedBlockID := makeBlockID(trustedHeader.Hash(), 1000, []byte("partshash")) - trustedVoteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals, false) + trustedVoteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals) trustedExtCommit, err := factory.MakeExtendedCommit(ctx, trustedBlockID, height, 1, trustedVoteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) @@ -553,7 +553,7 @@ func makeLunaticEvidence( }) blockID := factory.MakeBlockIDWithHash(conflictingHeader.Hash()) - voteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals, false) + voteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals) extCommit, err := factory.MakeExtendedCommit(ctx, blockID, height, 1, voteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) commit := extCommit.StripExtensions() @@ -582,7 +582,7 @@ func makeLunaticEvidence( } trustedBlockID := factory.MakeBlockIDWithHash(trustedHeader.Hash()) trustedVals, privVals := factory.ValidatorSet(ctx, t, totalVals, defaultVotingPower) - trustedVoteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), trustedVals, false) + trustedVoteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), trustedVals) trustedExtCommit, err := factory.MakeExtendedCommit(ctx, trustedBlockID, height, 1, trustedVoteSet, privVals, defaultEvidenceTime) require.NoError(t, err) trustedCommit := trustedExtCommit.StripExtensions() diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index 38a829f8d..904fb2b74 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -855,7 +855,7 @@ func mockLB(ctx context.Context, t *testing.T, height int64, time time.Time, las header.NextValidatorsHash = nextVals.Hash() header.ConsensusHash = types.DefaultConsensusParams().HashConsensusParams() lastBlockID = factory.MakeBlockIDWithHash(header.Hash()) - voteSet := types.NewVoteSet(factory.DefaultTestChainID, height, 0, tmproto.PrecommitType, currentVals, false) + voteSet := types.NewVoteSet(factory.DefaultTestChainID, height, 0, tmproto.PrecommitType, currentVals) extCommit, err := factory.MakeExtendedCommit(ctx, lastBlockID, height, 0, voteSet, currentPrivVals, time) require.NoError(t, err) return nextVals, nextPrivVals, &types.LightBlock{ diff --git a/node/node_test.go b/node/node_test.go index 2386f6884..b1d7a9481 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -526,7 +526,7 @@ func TestMaxProposalBlockSize(t *testing.T) { } state.ChainID = maxChainID - voteSet := types.NewVoteSet(state.ChainID, math.MaxInt64-1, math.MaxInt32, tmproto.PrecommitType, state.Validators, false) + voteSet := types.NewVoteSet(state.ChainID, math.MaxInt64-1, math.MaxInt32, tmproto.PrecommitType, state.Validators) // add maximum amount of signatures to a single commit for i := 0; i < types.MaxVotesCount; i++ { diff --git a/test/e2e/runner/evidence.go b/test/e2e/runner/evidence.go index 551906c1e..a71ea14fb 100644 --- a/test/e2e/runner/evidence.go +++ b/test/e2e/runner/evidence.go @@ -165,7 +165,7 @@ func generateLightClientAttackEvidence( // create a commit for the forged header blockID := makeBlockID(header.Hash(), 1000, []byte("partshash")) - voteSet := types.NewVoteSet(chainID, forgedHeight, 0, tmproto.SignedMsgType(2), conflictingVals, false) + voteSet := types.NewVoteSet(chainID, forgedHeight, 0, tmproto.SignedMsgType(2), conflictingVals) commit, err := factory.MakeExtendedCommit(ctx, blockID, forgedHeight, 0, voteSet, pv, forgedTime) if err != nil { diff --git a/types/block.go b/types/block.go index b99d8a12a..ab79d59ff 100644 --- a/types/block.go +++ b/types/block.go @@ -1016,11 +1016,27 @@ func (ec *ExtendedCommit) Clone() *ExtendedCommit { return &ecc } +// ToStrictVoteSet constructs a VoteSet from the Commit and validator set. +// Panics if signatures from the ExtendedCommit can't be added to the voteset. +// Panics if any of the votes have invalid or absent vote extension data. +// Inverse of VoteSet.MakeExtendedCommit(). +func (ec *ExtendedCommit) ToStrictVoteSet(chainID string, vals *ValidatorSet) *VoteSet { + voteSet := NewStrictVoteSet(chainID, ec.Height, ec.Round, tmproto.PrecommitType, vals) + ec.addSigsToVoteSet(voteSet) + return voteSet +} + // ToVoteSet constructs a VoteSet from the Commit and validator set. // Panics if signatures from the ExtendedCommit can't be added to the voteset. // Inverse of VoteSet.MakeExtendedCommit(). -func (ec *ExtendedCommit) ToVoteSet(chainID string, vals *ValidatorSet, requireExtensions bool) *VoteSet { - voteSet := NewVoteSet(chainID, ec.Height, ec.Round, tmproto.PrecommitType, vals, requireExtensions) +func (ec *ExtendedCommit) ToVoteSet(chainID string, vals *ValidatorSet) *VoteSet { + voteSet := NewVoteSet(chainID, ec.Height, ec.Round, tmproto.PrecommitType, vals) + ec.addSigsToVoteSet(voteSet) + return voteSet +} + +// addSigsToVoteSet adds all of the signature to voteSet. +func (ec *ExtendedCommit) addSigsToVoteSet(voteSet *VoteSet) { for idx, ecs := range ec.ExtendedSignatures { if ecs.BlockIDFlag == BlockIDFlagAbsent { continue // OK, some precommits can be missing. @@ -1034,14 +1050,13 @@ func (ec *ExtendedCommit) ToVoteSet(chainID string, vals *ValidatorSet, requireE panic(fmt.Errorf("failed to reconstruct vote set from extended commit: %w", err)) } } - return voteSet } // ToVoteSet constructs a VoteSet from the Commit and validator set. // Panics if signatures from the commit can't be added to the voteset. // Inverse of VoteSet.MakeCommit(). func (commit *Commit) ToVoteSet(chainID string, vals *ValidatorSet) *VoteSet { - voteSet := NewVoteSet(chainID, commit.Height, commit.Round, tmproto.PrecommitType, vals, false) + voteSet := NewVoteSet(chainID, commit.Height, commit.Round, tmproto.PrecommitType, vals) for idx, cs := range commit.Signatures { if cs.BlockIDFlag == BlockIDFlagAbsent { continue // OK, some precommits can be missing. diff --git a/types/block_test.go b/types/block_test.go index 1ebca3ae4..365189049 100644 --- a/types/block_test.go +++ b/types/block_test.go @@ -581,7 +581,12 @@ func TestVoteSetToExtendedCommit(t *testing.T) { defer cancel() valSet, vals := randValidatorPrivValSet(ctx, t, 10, 1) - voteSet := NewVoteSet("test_chain_id", 3, 1, tmproto.PrecommitType, valSet, testCase.includeExtension) + var voteSet *VoteSet + if testCase.includeExtension { + voteSet = NewStrictVoteSet("test_chain_id", 3, 1, tmproto.PrecommitType, valSet) + } else { + voteSet = NewVoteSet("test_chain_id", 3, 1, tmproto.PrecommitType, valSet) + } for i := 0; i < len(vals); i++ { pubKey, err := vals[i].GetPubKey(ctx) require.NoError(t, err) @@ -661,7 +666,12 @@ func TestExtendedCommitToVoteSet(t *testing.T) { } chainID := voteSet.ChainID() - voteSet2 := extCommit.ToVoteSet(chainID, valSet, testCase.includeExtension) + var voteSet2 *VoteSet + if testCase.includeExtension { + voteSet2 = extCommit.ToStrictVoteSet(chainID, valSet) + } else { + voteSet2 = extCommit.ToVoteSet(chainID, valSet) + } for i := int32(0); int(i) < len(vals); i++ { vote1 := voteSet.GetByIndex(i) diff --git a/types/vote_set.go b/types/vote_set.go index 0905d651d..d919360c5 100644 --- a/types/vote_set.go +++ b/types/vote_set.go @@ -69,28 +69,40 @@ type VoteSet struct { peerMaj23s map[string]BlockID // Maj23 for each peer } -// Constructs a new VoteSet struct used to accumulate votes for given height/round. +// NewVoteSet instantiates all fields of a new vote set. This constructor does +// not require that the votes added to the set contain vote extension data, but +// if votes are added that contain extension data, the extension data will be +// verified. func NewVoteSet(chainID string, height int64, round int32, - signedMsgType tmproto.SignedMsgType, valSet *ValidatorSet, requireExtensions bool) *VoteSet { + signedMsgType tmproto.SignedMsgType, valSet *ValidatorSet) *VoteSet { if height == 0 { panic("Cannot make VoteSet for height == 0, doesn't make sense.") } return &VoteSet{ - chainID: chainID, - height: height, - round: round, - signedMsgType: signedMsgType, - valSet: valSet, - requireExtensions: requireExtensions, - votesBitArray: bits.NewBitArray(valSet.Size()), - votes: make([]*Vote, valSet.Size()), - sum: 0, - maj23: nil, - votesByBlock: make(map[string]*blockVotes, valSet.Size()), - peerMaj23s: make(map[string]BlockID), + chainID: chainID, + height: height, + round: round, + signedMsgType: signedMsgType, + valSet: valSet, + votesBitArray: bits.NewBitArray(valSet.Size()), + votes: make([]*Vote, valSet.Size()), + sum: 0, + maj23: nil, + votesByBlock: make(map[string]*blockVotes, valSet.Size()), + peerMaj23s: make(map[string]BlockID), } } +// NewStrictVoteSet constructs a vote set with additional vote verification logic. +// The VoteSet constructed with NewStrictVoteSet verifies the vote extension +// data for every vote added to the set. +func NewStrictVoteSet(chainID string, height int64, round int32, + signedMsgType tmproto.SignedMsgType, valSet *ValidatorSet) *VoteSet { + vs := NewStrictVoteSet(chainID, height, round, signedMsgType, valSet) + vs.requireExtensions = true + return vs +} + func (voteSet *VoteSet) ChainID() string { return voteSet.chainID } diff --git a/types/vote_set_test.go b/types/vote_set_test.go index b80251413..69c7f9d9c 100644 --- a/types/vote_set_test.go +++ b/types/vote_set_test.go @@ -532,7 +532,12 @@ func TestVoteSet_RequireExtensions(t *testing.T) { height, round := int64(1), int32(0) valSet, privValidators := randValidatorPrivValSet(ctx, t, 5, 10) - voteSet := NewVoteSet("test_chain_id", height, round, tmproto.PrecommitType, valSet, tc.requireExtensions) + var voteSet *VoteSet + if tc.requireExtensions { + voteSet = NewStrictVoteSet("test_chain_id", height, round, tmproto.PrecommitType, valSet) + } else { + voteSet = NewVoteSet("test_chain_id", height, round, tmproto.PrecommitType, valSet) + } val0 := privValidators[0] @@ -585,7 +590,7 @@ func randVoteSet( ) (*VoteSet, *ValidatorSet, []PrivValidator) { t.Helper() valSet, privValidators := randValidatorPrivValSet(ctx, t, numValidators, votingPower) - return NewVoteSet("test_chain_id", height, round, signedMsgType, valSet, true), valSet, privValidators + return NewStrictVoteSet("test_chain_id", height, round, signedMsgType, valSet), valSet, privValidators } func deterministicVoteSet( @@ -598,7 +603,7 @@ func deterministicVoteSet( ) (*VoteSet, *ValidatorSet, []PrivValidator) { t.Helper() valSet, privValidators := deterministicValidatorSet(ctx, t) - return NewVoteSet("test_chain_id", height, round, signedMsgType, valSet, true), valSet, privValidators + return NewStrictVoteSet("test_chain_id", height, round, signedMsgType, valSet), valSet, privValidators } func randValidatorPrivValSet(ctx context.Context, t testing.TB, numValidators int, votingPower int64) (*ValidatorSet, []PrivValidator) {