diff --git a/internal/blocksync/reactor.go b/internal/blocksync/reactor.go index 144595889..1902780d6 100644 --- a/internal/blocksync/reactor.go +++ b/internal/blocksync/reactor.go @@ -557,7 +557,10 @@ func (r *Reactor) poolRoutine(ctx context.Context, stateSynced bool, blockSyncCh // validate the block before we persist it err = r.blockExec.ValidateBlock(ctx, state, first) } - + if err == nil && state.ConsensusParams.Vote.RequireExtensions(extCommit.Height) { + // if vote extensions were required at this height, ensure they exist. + err = extCommit.EnsureExtensions() + } // If either of the checks failed we log the error and request for a new block // at that height if err != nil { diff --git a/internal/consensus/msgs.go b/internal/consensus/msgs.go index c59c06a41..1024c24ae 100644 --- a/internal/consensus/msgs.go +++ b/internal/consensus/msgs.go @@ -222,11 +222,7 @@ func (*VoteMessage) TypeTag() string { return "tendermint/Vote" } // ValidateBasic checks whether the vote within the message is well-formed. func (m *VoteMessage) ValidateBasic() error { - // Here we validate votes with vote extensions, since we require vote - // extensions to be sent in precommit messages during consensus. Prevote - // messages should never have vote extensions, and this is also validated - // here. - return m.Vote.ValidateWithExtension() + return m.Vote.ValidateBasic() } // String returns a string representation. diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index c6a8869db..7fa514e3f 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -32,6 +32,7 @@ import ( "github.com/tendermint/tendermint/internal/test/factory" "github.com/tendermint/tendermint/libs/log" tmcons "github.com/tendermint/tendermint/proto/tendermint/consensus" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" "github.com/tendermint/tendermint/types" ) @@ -600,6 +601,106 @@ func TestReactorCreatesBlockWhenEmptyBlocksFalse(t *testing.T) { wg.Wait() } +func TestSwitchToConsensusVoteExtensions(t *testing.T) { + for _, testCase := range []struct { + name string + storedHeight int64 + initialRequiredHeight int64 + includeExtensions bool + shouldPanic bool + }{ + { + name: "no vote extensions but not required", + initialRequiredHeight: 0, + storedHeight: 2, + includeExtensions: false, + shouldPanic: false, + }, + { + name: "no vote extensions but required this height", + initialRequiredHeight: 2, + storedHeight: 2, + includeExtensions: false, + shouldPanic: true, + }, + { + name: "no vote extensions and required in future", + initialRequiredHeight: 3, + storedHeight: 2, + includeExtensions: false, + shouldPanic: false, + }, + { + name: "no vote extensions and required previous height", + initialRequiredHeight: 1, + storedHeight: 2, + includeExtensions: false, + shouldPanic: true, + }, + { + name: "vote extensions and required previous height", + initialRequiredHeight: 1, + storedHeight: 2, + includeExtensions: true, + shouldPanic: false, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + cs, vs := makeState(ctx, t, makeStateArgs{validators: 1}) + validator := vs[0] + validator.Height = testCase.storedHeight + + cs.state.LastBlockHeight = testCase.storedHeight + cs.state.LastValidators = cs.state.Validators.Copy() + cs.state.ConsensusParams.Vote.ExtensionRequireHeight = testCase.initialRequiredHeight + + propBlock, err := cs.createProposalBlock(ctx) + require.NoError(t, err) + + // Consensus is preparing to do the next height after the stored height. + cs.Height = testCase.storedHeight + 1 + propBlock.Height = testCase.storedHeight + blockParts, err := propBlock.MakePartSet(types.BlockPartSizeBytes) + require.NoError(t, err) + + voteSet := types.NewVoteSet(cs.state.ChainID, testCase.storedHeight, 0, tmproto.PrecommitType, cs.state.Validators, testCase.includeExtensions) + signedVote := signVote(ctx, t, validator, tmproto.PrecommitType, cs.state.ChainID, types.BlockID{ + Hash: propBlock.Hash(), + PartSetHeader: blockParts.Header(), + }) + + if !testCase.includeExtensions { + signedVote.Extension = nil + signedVote.ExtensionSignature = nil + } + + added, err := voteSet.AddVote(signedVote) + require.NoError(t, err) + require.True(t, added) + cs.blockStore.SaveBlock(propBlock, blockParts, voteSet.MakeExtendedCommit()) + reactor := NewReactor( + log.NewNopLogger(), + cs, + nil, + nil, + cs.eventBus, + true, + NopMetrics(), + ) + + if testCase.shouldPanic { + assert.Panics(t, func() { + reactor.SwitchToConsensus(ctx, cs.state, false) + }) + } else { + reactor.SwitchToConsensus(ctx, cs.state, false) + } + }) + } +} + func TestReactorRecordsVotesAndBlockParts(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() diff --git a/internal/consensus/state.go b/internal/consensus/state.go index b016e2687..407724cc0 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -692,23 +692,53 @@ func (cs *State) sendInternalMessage(ctx context.Context, mi msgInfo) { } } -// Reconstruct LastCommit from SeenCommit, which we saved along with the block, -// (which happens even before saving the state) +// Reconstruct LastCommit from either SeenCommit or the ExtendedCommit. SeenCommit +// and ExtendedCommit are saved along with the block. If VoteExtensions are required +// the method will panic on an absent ExtendedCommit or an ExtendedCommit without +// extension data. func (cs *State) reconstructLastCommit(state sm.State) { - extCommit := cs.blockStore.LoadBlockExtendedCommit(state.LastBlockHeight) - if extCommit == nil { - panic(fmt.Sprintf( - "failed to reconstruct last commit; commit for height %v not found", - state.LastBlockHeight, - )) + requireExtensions := cs.state.ConsensusParams.Vote.RequireExtensions(state.LastBlockHeight) + votes, err := cs.votesFromExtendedCommit(state, requireExtensions) + if err == nil { + cs.LastCommit = votes + return + } + if requireExtensions { + panic(fmt.Sprintf("failed to reconstruct last commit; %s", err)) + } + votes, err = cs.votesFromSeenCommit(state) + if err != nil { + panic(fmt.Sprintf("failed to reconstruct last commit; %s", err)) + } + cs.LastCommit = votes +} + +func (cs *State) votesFromExtendedCommit(state sm.State, requireExtensions bool) (*types.VoteSet, error) { + ec := cs.blockStore.LoadBlockExtendedCommit(state.LastBlockHeight) + if ec == nil { + return nil, fmt.Errorf("commit for height %v not found", state.LastBlockHeight) + } + vs := ec.ToVoteSet(state.ChainID, state.LastValidators, requireExtensions) + if !vs.HasTwoThirdsMajority() { + return nil, errors.New("extended commit does not have +2/3 majority") + } + return vs, nil +} + +func (cs *State) votesFromSeenCommit(state sm.State) (*types.VoteSet, error) { + commit := cs.blockStore.LoadSeenCommit() + if commit == nil || commit.Height != state.LastBlockHeight { + commit = cs.blockStore.LoadBlockCommit(state.LastBlockHeight) + } + if commit == nil { + return nil, fmt.Errorf("commit for height %v not found", state.LastBlockHeight) } - lastPrecommits := extCommit.ToVoteSet(state.ChainID, state.LastValidators) - if !lastPrecommits.HasTwoThirdsMajority() { - panic("failed to reconstruct last commit; does not have +2/3 maj") + vs := commit.ToVoteSet(state.ChainID, state.LastValidators) + if !vs.HasTwoThirdsMajority() { + return nil, errors.New("commit does not have +2/3 majority") } - - cs.LastCommit = lastPrecommits + return vs, nil } // Updates State and increments height to match that of state. @@ -810,7 +840,8 @@ func (cs *State) updateToState(state sm.State) { cs.ValidRound = -1 cs.ValidBlock = nil cs.ValidBlockParts = nil - cs.Votes = cstypes.NewHeightVoteSet(state.ChainID, height, validators) + requireExtensions := state.ConsensusParams.Vote.RequireExtensions(height) + cs.Votes = cstypes.NewHeightVoteSet(state.ChainID, height, validators, requireExtensions) cs.CommitRound = -1 cs.LastValidators = state.LastValidators cs.TriggeredTimeoutPrecommit = false @@ -2337,13 +2368,37 @@ func (cs *State) addVote( return } + var addr []byte + if cs.privValidatorPubKey != nil { + addr = cs.privValidatorPubKey.Address() + } // Verify VoteExtension if precommit and not nil // https://github.com/tendermint/tendermint/issues/8487 - if vote.Type == tmproto.PrecommitType && !vote.BlockID.IsNil() { - err := cs.blockExec.VerifyVoteExtension(ctx, vote) - cs.metrics.MarkVoteExtensionReceived(err == nil) - if err != nil { - return false, err + if vote.Type == tmproto.PrecommitType && !vote.BlockID.IsNil() && + !bytes.Equal(vote.ValidatorAddress, addr) { + // The core fields of the vote message were already validated in the + // consensus reactor when the vote was received. + // Here, we valdiate that the vote extension was included in the vote + // message. + // Chains that are not configured to require vote extensions + // will consider the vote valid even if the extension is absent. + // VerifyVoteExtension will not be called in this case if the extension + // is absent. + err := vote.EnsureExtension() + if err == nil { + _, val := cs.state.Validators.GetByIndex(vote.ValidatorIndex) + err = vote.VerifyExtension(cs.state.ChainID, val.PubKey) + } + if err == nil { + err := cs.blockExec.VerifyVoteExtension(ctx, vote) + cs.metrics.MarkVoteExtensionReceived(err == nil) + } else { + if !errors.Is(err, types.ErrVoteExtensionAbsent) { + return false, err + } + if cs.state.ConsensusParams.Vote.RequireExtensions(cs.Height) { + return false, err + } } } diff --git a/internal/consensus/state_test.go b/internal/consensus/state_test.go index 6fa69a1a3..eea671a58 100644 --- a/internal/consensus/state_test.go +++ b/internal/consensus/state_test.go @@ -2076,19 +2076,13 @@ func TestExtendVoteCalled(t *testing.T) { Hash: blockID.Hash, }) - m.AssertCalled(t, "VerifyVoteExtension", ctx, &abci.RequestVerifyVoteExtension{ - Hash: blockID.Hash, - ValidatorAddress: addr, - Height: height, - VoteExtension: []byte("extension"), - }) signAddVotes(ctx, t, cs1, tmproto.PrecommitType, config.ChainID(), blockID, vss[1:]...) ensureNewRound(t, newRoundCh, height+1, 0) m.AssertExpectations(t) // Only 3 of the vote extensions are seen, as consensus proceeds as soon as the +2/3 threshold // is observed by the consensus engine. - for _, pv := range vss[:3] { + for _, pv := range vss[1:3] { pv, err := pv.GetPubKey(ctx) require.NoError(t, err) addr := pv.Address() @@ -2148,13 +2142,6 @@ func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { Hash: blockID.Hash, }) - m.AssertCalled(t, "VerifyVoteExtension", mock.Anything, &abci.RequestVerifyVoteExtension{ - Hash: blockID.Hash, - ValidatorAddress: addr, - Height: height, - VoteExtension: []byte("extension"), - }) - m.On("Commit", mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() signAddVotes(ctx, t, cs1, tmproto.PrecommitType, config.ChainID(), blockID, vss[2:]...) ensureNewRound(t, newRoundCh, height+1, 0) @@ -2266,6 +2253,118 @@ func TestPrepareProposalReceivesVoteExtensions(t *testing.T) { } } +// TestVoteExtensionRequiredHeight tests that 'ExtensionRequireHeight' correctly +// enforces that vote extensions be present in consensus for heights greater than +// or equal to the configured value. +func TestVoteExtensionRequiredHeight(t *testing.T) { + for _, testCase := range []struct { + name string + initialRequiredHeight int64 + hasExtension bool + expectSuccessfulRound bool + }{ + { + name: "extension present but not required", + hasExtension: true, + initialRequiredHeight: 0, + expectSuccessfulRound: true, + }, + { + name: "extension absent but not required", + hasExtension: false, + initialRequiredHeight: 0, + expectSuccessfulRound: true, + }, + { + name: "extension present and required", + hasExtension: true, + initialRequiredHeight: 1, + expectSuccessfulRound: true, + }, + { + name: "extension absent but required", + hasExtension: false, + initialRequiredHeight: 1, + expectSuccessfulRound: false, + }, + { + name: "extension absent but required in future height", + hasExtension: false, + initialRequiredHeight: 2, + expectSuccessfulRound: true, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + numValidators := 3 + m := abcimocks.NewApplication(t) + m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{ + Status: abci.ResponseProcessProposal_ACCEPT, + }, nil) + m.On("PrepareProposal", mock.Anything, mock.Anything).Return(&abci.ResponsePrepareProposal{}, nil) + m.On("ExtendVote", mock.Anything, mock.Anything).Return(&abci.ResponseExtendVote{}, nil) + if testCase.hasExtension { + m.On("VerifyVoteExtension", mock.Anything, mock.Anything).Return(&abci.ResponseVerifyVoteExtension{ + Status: abci.ResponseVerifyVoteExtension_ACCEPT, + }, nil).Times(numValidators - 1) + } + m.On("FinalizeBlock", mock.Anything, mock.Anything).Return(&abci.ResponseFinalizeBlock{}, nil).Maybe() + m.On("Commit", mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() + cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m, validators: numValidators}) + cs1.state.ConsensusParams.Vote.ExtensionRequireHeight = testCase.initialRequiredHeight + height, round := cs1.Height, cs1.Round + + timeoutCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutPropose) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey(ctx) + require.NoError(t, err) + addr := pv1.Address() + voteCh := subscribeToVoter(ctx, t, cs1, addr) + + startTestRound(ctx, cs1, cs1.Height, round) + ensureNewRound(t, newRoundCh, height, round) + ensureNewProposal(t, proposalCh, height, round) + rs := cs1.GetRoundState() + + blockID := types.BlockID{ + Hash: rs.ProposalBlock.Hash(), + PartSetHeader: rs.ProposalBlockParts.Header(), + } + + // sign all of the votes + signAddVotes(ctx, t, cs1, tmproto.PrevoteType, config.ChainID(), blockID, vss[1:]...) + ensurePrevoteMatch(t, voteCh, height, round, rs.ProposalBlock.Hash()) + + var ext []byte + if testCase.hasExtension { + ext = []byte("extension") + } + + for _, vs := range vss[1:] { + vote, err := vs.signVote(ctx, tmproto.PrecommitType, config.ChainID(), blockID, ext) + if !testCase.hasExtension { + vote.ExtensionSignature = nil + } + require.NoError(t, err) + addVotes(cs1, vote) + } + if testCase.expectSuccessfulRound { + ensurePrecommit(t, voteCh, height, round) + height++ + ensureNewRound(t, newRoundCh, height, round) + } else { + ensureNoNewTimeout(t, timeoutCh, cs1.state.ConsensusParams.Timeout.VoteTimeout(round).Nanoseconds()) + } + + m.AssertExpectations(t) + }) + } +} + // 4 vals, 3 Nil Precommits at P0 // What we want: // P0 waits for timeoutPrecommit before starting next round diff --git a/internal/consensus/types/height_vote_set.go b/internal/consensus/types/height_vote_set.go index 661c5120e..0c432e80d 100644 --- a/internal/consensus/types/height_vote_set.go +++ b/internal/consensus/types/height_vote_set.go @@ -38,9 +38,10 @@ We let each peer provide us with up to 2 unexpected "catchup" rounds. One for their LastCommit round, and another for the official commit round. */ type HeightVoteSet struct { - chainID string - height int64 - valSet *types.ValidatorSet + chainID string + height int64 + valSet *types.ValidatorSet + requireExtensions bool mtx sync.Mutex round int32 // max tracked round @@ -48,9 +49,10 @@ type HeightVoteSet struct { peerCatchupRounds map[types.NodeID][]int32 // keys: peer.ID; values: at most 2 rounds } -func NewHeightVoteSet(chainID string, height int64, valSet *types.ValidatorSet) *HeightVoteSet { +func NewHeightVoteSet(chainID string, height int64, valSet *types.ValidatorSet, requireExtensions bool) *HeightVoteSet { hvs := &HeightVoteSet{ - chainID: chainID, + chainID: chainID, + requireExtensions: requireExtensions, } hvs.Reset(height, valSet) return hvs @@ -107,8 +109,8 @@ func (hvs *HeightVoteSet) addRound(round int32) { panic("addRound() for an existing round") } // log.Debug("addRound(round)", "round", round) - prevotes := types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrevoteType, hvs.valSet) - precommits := types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrecommitType, hvs.valSet) + prevotes := types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrevoteType, hvs.valSet, hvs.requireExtensions) + precommits := types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrecommitType, hvs.valSet, hvs.requireExtensions) hvs.roundVoteSets[round] = RoundVoteSet{ Prevotes: prevotes, Precommits: precommits, diff --git a/internal/consensus/types/height_vote_set_test.go b/internal/consensus/types/height_vote_set_test.go index acffa794c..b21895409 100644 --- a/internal/consensus/types/height_vote_set_test.go +++ b/internal/consensus/types/height_vote_set_test.go @@ -27,7 +27,7 @@ func TestPeerCatchupRounds(t *testing.T) { valSet, privVals := factory.ValidatorSet(ctx, t, 10, 1) chainID := cfg.ChainID() - hvs := NewHeightVoteSet(chainID, 1, valSet) + hvs := NewHeightVoteSet(chainID, 1, valSet, false) vote999_0 := makeVoteHR(ctx, t, 1, 0, 999, privVals, chainID) added, err := hvs.AddVote(vote999_0, "peer1") diff --git a/internal/evidence/verify_test.go b/internal/evidence/verify_test.go index 2ed84fa69..c811125e9 100644 --- a/internal/evidence/verify_test.go +++ b/internal/evidence/verify_test.go @@ -233,7 +233,7 @@ func TestVerifyLightClientAttack_Equivocation(t *testing.T) { // we are simulating a duplicate vote attack where all the validators in the conflictingVals set // except the last validator vote twice blockID := factory.MakeBlockIDWithHash(conflictingHeader.Hash()) - voteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) + voteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals, false) extCommit, err := factory.MakeExtendedCommit(ctx, blockID, 10, 1, voteSet, conflictingPrivVals[:4], defaultEvidenceTime) require.NoError(t, err) commit := extCommit.StripExtensions() @@ -253,7 +253,7 @@ func TestVerifyLightClientAttack_Equivocation(t *testing.T) { } trustedBlockID := makeBlockID(trustedHeader.Hash(), 1000, []byte("partshash")) - trustedVoteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) + trustedVoteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals, false) trustedExtCommit, err := factory.MakeExtendedCommit(ctx, trustedBlockID, 10, 1, trustedVoteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) @@ -336,7 +336,7 @@ func TestVerifyLightClientAttack_Amnesia(t *testing.T) { // we are simulating an amnesia attack where all the validators in the conflictingVals set // except the last validator vote twice. However this time the commits are of different rounds. blockID := makeBlockID(conflictingHeader.Hash(), 1000, []byte("partshash")) - voteSet := types.NewVoteSet(evidenceChainID, height, 0, tmproto.SignedMsgType(2), conflictingVals) + voteSet := types.NewVoteSet(evidenceChainID, height, 0, tmproto.SignedMsgType(2), conflictingVals, false) extCommit, err := factory.MakeExtendedCommit(ctx, blockID, height, 0, voteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) commit := extCommit.StripExtensions() @@ -356,7 +356,7 @@ func TestVerifyLightClientAttack_Amnesia(t *testing.T) { } trustedBlockID := makeBlockID(trustedHeader.Hash(), 1000, []byte("partshash")) - trustedVoteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals) + trustedVoteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals, false) trustedExtCommit, err := factory.MakeExtendedCommit(ctx, trustedBlockID, height, 1, trustedVoteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) @@ -553,7 +553,7 @@ func makeLunaticEvidence( }) blockID := factory.MakeBlockIDWithHash(conflictingHeader.Hash()) - voteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals) + voteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals, false) extCommit, err := factory.MakeExtendedCommit(ctx, blockID, height, 1, voteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) commit := extCommit.StripExtensions() @@ -582,7 +582,7 @@ func makeLunaticEvidence( } trustedBlockID := factory.MakeBlockIDWithHash(trustedHeader.Hash()) trustedVals, privVals := factory.ValidatorSet(ctx, t, totalVals, defaultVotingPower) - trustedVoteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), trustedVals) + trustedVoteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), trustedVals, false) trustedExtCommit, err := factory.MakeExtendedCommit(ctx, trustedBlockID, height, 1, trustedVoteSet, privVals, defaultEvidenceTime) require.NoError(t, err) trustedCommit := trustedExtCommit.StripExtensions() diff --git a/internal/state/execution.go b/internal/state/execution.go index 2c88c793b..68de931db 100644 --- a/internal/state/execution.go +++ b/internal/state/execution.go @@ -3,6 +3,7 @@ package state import ( "bytes" "context" + "errors" "fmt" "time" @@ -102,13 +103,12 @@ func (blockExec *BlockExecutor) CreateProposalBlock( txs := blockExec.mempool.ReapMaxBytesMaxGas(maxDataBytes, maxGas) commit := lastExtCommit.StripExtensions() block := state.MakeBlock(height, txs, commit, evidence, proposerAddr) - rpp, err := blockExec.appClient.PrepareProposal( ctx, &abci.RequestPrepareProposal{ MaxTxBytes: maxDataBytes, Txs: block.Txs.ToSliceOfBytes(), - LocalLastCommit: buildExtendedCommitInfo(lastExtCommit, blockExec.store, state.InitialHeight), + LocalLastCommit: buildExtendedCommitInfo(lastExtCommit, blockExec.store, state.InitialHeight, state.ConsensusParams.Vote), ByzantineValidators: block.Evidence.ToABCI(), Height: block.Height, Time: block.Time, @@ -321,7 +321,7 @@ func (blockExec *BlockExecutor) VerifyVoteExtension(ctx context.Context, vote *t } if !resp.IsOK() { - return types.ErrVoteInvalidExtension + return errors.New("invalid vote extension") } return nil @@ -428,7 +428,7 @@ func buildLastCommitInfo(block *types.Block, store Store, initialHeight int64) a // data, it returns an empty record. // // Assumes that the commit signatures are sorted according to validator index. -func buildExtendedCommitInfo(ec *types.ExtendedCommit, store Store, initialHeight int64) abci.ExtendedCommitInfo { +func buildExtendedCommitInfo(ec *types.ExtendedCommit, store Store, initialHeight int64, vp types.VoteParams) abci.ExtendedCommitInfo { if ec.Height < initialHeight { // There are no extended commits for heights below the initial height. return abci.ExtendedCommitInfo{} @@ -466,12 +466,12 @@ func buildExtendedCommitInfo(ec *types.ExtendedCommit, store Store, initialHeigh } var ext []byte - if ecs.BlockIDFlag == types.BlockIDFlagCommit { - // We only care about vote extensions if a validator has voted to - // commit. - ext = ecs.Extension + if err := ecs.EnsureExtension(); err != nil && vp.RequireExtensions(ec.Height) { + panic(fmt.Errorf("commit at height %d received with missing vote extensions data", ec.Height)) } + ext = ecs.Extension + votes[i] = abci.ExtendedVoteInfo{ Validator: types.TM2PB.Validator(val), SignedLastBlock: ecs.BlockIDFlag != types.BlockIDFlagAbsent, diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index 904fb2b74..38a829f8d 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -855,7 +855,7 @@ func mockLB(ctx context.Context, t *testing.T, height int64, time time.Time, las header.NextValidatorsHash = nextVals.Hash() header.ConsensusHash = types.DefaultConsensusParams().HashConsensusParams() lastBlockID = factory.MakeBlockIDWithHash(header.Hash()) - voteSet := types.NewVoteSet(factory.DefaultTestChainID, height, 0, tmproto.PrecommitType, currentVals) + voteSet := types.NewVoteSet(factory.DefaultTestChainID, height, 0, tmproto.PrecommitType, currentVals, false) extCommit, err := factory.MakeExtendedCommit(ctx, lastBlockID, height, 0, voteSet, currentPrivVals, time) require.NoError(t, err) return nextVals, nextPrivVals, &types.LightBlock{ diff --git a/node/node_test.go b/node/node_test.go index b1d7a9481..2386f6884 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -526,7 +526,7 @@ func TestMaxProposalBlockSize(t *testing.T) { } state.ChainID = maxChainID - voteSet := types.NewVoteSet(state.ChainID, math.MaxInt64-1, math.MaxInt32, tmproto.PrecommitType, state.Validators) + voteSet := types.NewVoteSet(state.ChainID, math.MaxInt64-1, math.MaxInt32, tmproto.PrecommitType, state.Validators, false) // add maximum amount of signatures to a single commit for i := 0; i < types.MaxVotesCount; i++ { diff --git a/test/e2e/runner/evidence.go b/test/e2e/runner/evidence.go index a71ea14fb..551906c1e 100644 --- a/test/e2e/runner/evidence.go +++ b/test/e2e/runner/evidence.go @@ -165,7 +165,7 @@ func generateLightClientAttackEvidence( // create a commit for the forged header blockID := makeBlockID(header.Hash(), 1000, []byte("partshash")) - voteSet := types.NewVoteSet(chainID, forgedHeight, 0, tmproto.SignedMsgType(2), conflictingVals) + voteSet := types.NewVoteSet(chainID, forgedHeight, 0, tmproto.SignedMsgType(2), conflictingVals, false) commit, err := factory.MakeExtendedCommit(ctx, blockID, forgedHeight, 0, voteSet, pv, forgedTime) if err != nil { diff --git a/types/block.go b/types/block.go index 32a4f9a0a..b99d8a12a 100644 --- a/types/block.go +++ b/types/block.go @@ -757,22 +757,25 @@ func (ecs ExtendedCommitSig) ValidateBasic() error { if len(ecs.Extension) > MaxVoteExtensionSize { return fmt.Errorf("vote extension is too big (max: %d)", MaxVoteExtensionSize) } - if len(ecs.ExtensionSignature) == 0 { - return errors.New("vote extension signature is missing") - } if len(ecs.ExtensionSignature) > MaxSignatureSize { return fmt.Errorf("vote extension signature is too big (max: %d)", MaxSignatureSize) } return nil } - // We expect there to not be any vote extension or vote extension signature - // on nil or absent votes. - if len(ecs.Extension) != 0 { - return fmt.Errorf("vote extension is present for commit sig with block ID flag %v", ecs.BlockIDFlag) + if len(ecs.ExtensionSignature) == 0 && len(ecs.Extension) != 0 { + return fmt.Errorf("vote extension signature absent on vote with extension") } - if len(ecs.ExtensionSignature) != 0 { - return fmt.Errorf("vote extension signature is present for commit sig with block ID flag %v", ecs.BlockIDFlag) + return nil +} + +// EnsureExtensions validates that a vote extensions signature is present for +// this ExtendedCommitSig. +func (ecs ExtendedCommitSig) EnsureExtension() error { + if ecs.BlockIDFlag == BlockIDFlagCommit { + if len(ecs.ExtensionSignature) == 0 { + return errors.New("vote extension signature is missing") + } } return nil } @@ -1014,16 +1017,16 @@ func (ec *ExtendedCommit) Clone() *ExtendedCommit { } // ToVoteSet constructs a VoteSet from the Commit and validator set. -// Panics if signatures from the commit can't be added to the voteset. +// Panics if signatures from the ExtendedCommit can't be added to the voteset. // Inverse of VoteSet.MakeExtendedCommit(). -func (ec *ExtendedCommit) ToVoteSet(chainID string, vals *ValidatorSet) *VoteSet { - voteSet := NewVoteSet(chainID, ec.Height, ec.Round, tmproto.PrecommitType, vals) +func (ec *ExtendedCommit) ToVoteSet(chainID string, vals *ValidatorSet, requireExtensions bool) *VoteSet { + voteSet := NewVoteSet(chainID, ec.Height, ec.Round, tmproto.PrecommitType, vals, requireExtensions) for idx, ecs := range ec.ExtendedSignatures { if ecs.BlockIDFlag == BlockIDFlagAbsent { continue // OK, some precommits can be missing. } vote := ec.GetExtendedVote(int32(idx)) - if err := vote.ValidateWithExtension(); err != nil { + if err := vote.ValidateBasic(); err != nil { panic(fmt.Errorf("failed to validate vote reconstructed from LastCommit: %w", err)) } added, err := voteSet.AddVote(vote) @@ -1034,6 +1037,38 @@ func (ec *ExtendedCommit) ToVoteSet(chainID string, vals *ValidatorSet) *VoteSet return voteSet } +// ToVoteSet constructs a VoteSet from the Commit and validator set. +// Panics if signatures from the commit can't be added to the voteset. +// Inverse of VoteSet.MakeCommit(). +func (commit *Commit) ToVoteSet(chainID string, vals *ValidatorSet) *VoteSet { + voteSet := NewVoteSet(chainID, commit.Height, commit.Round, tmproto.PrecommitType, vals, false) + for idx, cs := range commit.Signatures { + if cs.BlockIDFlag == BlockIDFlagAbsent { + continue // OK, some precommits can be missing. + } + vote := commit.GetVote(int32(idx)) + if err := vote.ValidateBasic(); err != nil { + panic(fmt.Errorf("failed to validate vote reconstructed from commit: %w", err)) + } + added, err := voteSet.AddVote(vote) + if !added || err != nil { + panic(fmt.Errorf("failed to reconstruct vote set from commit: %w", err)) + } + } + return voteSet +} + +// EnsureExtensions validates that a vote extensions signature is present for +// every ExtendedCommitSig in the ExtendedCommit. +func (ec *ExtendedCommit) EnsureExtensions() error { + for _, ecs := range ec.ExtendedSignatures { + if err := ecs.EnsureExtension(); err != nil { + return err + } + } + return nil +} + // StripExtensions converts an ExtendedCommit to a Commit by removing all vote // extension-related fields. func (ec *ExtendedCommit) StripExtensions() *Commit { diff --git a/types/block_test.go b/types/block_test.go index 09a8b602e..1ebca3ae4 100644 --- a/types/block_test.go +++ b/types/block_test.go @@ -556,33 +556,128 @@ func TestBlockMaxDataBytesNoEvidence(t *testing.T) { } } +// TestVoteSetToExtendedCommit tests that the extended commit produced from a +// vote set contains the same vote information as the vote set. The test ensures +// that the MakeExtendedCommit method behaves as expected, whether vote extensions +// are present in the original votes or not. +func TestVoteSetToExtendedCommit(t *testing.T) { + for _, testCase := range []struct { + name string + includeExtension bool + }{ + { + name: "no extensions", + includeExtension: false, + }, + { + name: "with extensions", + includeExtension: true, + }, + } { + + t.Run(testCase.name, func(t *testing.T) { + blockID := makeBlockIDRandom() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + valSet, vals := randValidatorPrivValSet(ctx, t, 10, 1) + voteSet := NewVoteSet("test_chain_id", 3, 1, tmproto.PrecommitType, valSet, testCase.includeExtension) + for i := 0; i < len(vals); i++ { + pubKey, err := vals[i].GetPubKey(ctx) + require.NoError(t, err) + vote := &Vote{ + ValidatorAddress: pubKey.Address(), + ValidatorIndex: int32(i), + Height: 3, + Round: 1, + Type: tmproto.PrecommitType, + BlockID: blockID, + Timestamp: time.Now(), + } + v := vote.ToProto() + err = vals[i].SignVote(ctx, voteSet.ChainID(), v) + require.NoError(t, err) + vote.Signature = v.Signature + if testCase.includeExtension { + vote.ExtensionSignature = v.ExtensionSignature + } + added, err := voteSet.AddVote(vote) + require.NoError(t, err) + require.True(t, added) + } + ec := voteSet.MakeExtendedCommit() + + for i := int32(0); int(i) < len(vals); i++ { + vote1 := voteSet.GetByIndex(i) + vote2 := ec.GetExtendedVote(i) + + vote1bz, err := vote1.ToProto().Marshal() + require.NoError(t, err) + vote2bz, err := vote2.ToProto().Marshal() + require.NoError(t, err) + assert.Equal(t, vote1bz, vote2bz) + } + }) + } +} + +// TestExtendedCommitToVoteSet tests that the vote set produced from an extended commit +// contains the same vote information as the extended commit. The test ensures +// that the ToVoteSet method behaves as expected, whether vote extensions +// are present in the original votes or not. func TestExtendedCommitToVoteSet(t *testing.T) { - lastID := makeBlockIDRandom() - h := int64(3) + for _, testCase := range []struct { + name string + includeExtension bool + }{ + { + name: "no extensions", + includeExtension: false, + }, + { + name: "with extensions", + includeExtension: true, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + lastID := makeBlockIDRandom() + h := int64(3) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - voteSet, valSet, vals := randVoteSet(ctx, t, h-1, 1, tmproto.PrecommitType, 10, 1) - extCommit, err := makeExtCommit(ctx, lastID, h-1, 1, voteSet, vals, time.Now()) - assert.NoError(t, err) + voteSet, valSet, vals := randVoteSet(ctx, t, h-1, 1, tmproto.PrecommitType, 10, 1) + extCommit, err := makeExtCommit(ctx, lastID, h-1, 1, voteSet, vals, time.Now()) + assert.NoError(t, err) - chainID := voteSet.ChainID() - voteSet2 := extCommit.ToVoteSet(chainID, valSet) + if !testCase.includeExtension { + for i := 0; i < len(vals); i++ { + v := voteSet.GetByIndex(int32(i)) + v.Extension = nil + v.ExtensionSignature = nil + extCommit.ExtendedSignatures[i].Extension = nil + extCommit.ExtendedSignatures[i].ExtensionSignature = nil + } + } - for i := int32(0); int(i) < len(vals); i++ { - vote1 := voteSet.GetByIndex(i) - vote2 := voteSet2.GetByIndex(i) - vote3 := extCommit.GetExtendedVote(i) + chainID := voteSet.ChainID() + voteSet2 := extCommit.ToVoteSet(chainID, valSet, testCase.includeExtension) - vote1bz, err := vote1.ToProto().Marshal() - require.NoError(t, err) - vote2bz, err := vote2.ToProto().Marshal() - require.NoError(t, err) - vote3bz, err := vote3.ToProto().Marshal() - require.NoError(t, err) - assert.Equal(t, vote1bz, vote2bz) - assert.Equal(t, vote1bz, vote3bz) + for i := int32(0); int(i) < len(vals); i++ { + vote1 := voteSet.GetByIndex(i) + vote2 := voteSet2.GetByIndex(i) + vote3 := extCommit.GetExtendedVote(i) + + vote1bz, err := vote1.ToProto().Marshal() + require.NoError(t, err) + vote2bz, err := vote2.ToProto().Marshal() + require.NoError(t, err) + vote3bz, err := vote3.ToProto().Marshal() + require.NoError(t, err) + assert.Equal(t, vote1bz, vote2bz) + assert.Equal(t, vote1bz, vote3bz) + } + }) } } diff --git a/types/params.go b/types/params.go index e8ee6fcdf..ef14db605 100644 --- a/types/params.go +++ b/types/params.go @@ -43,6 +43,7 @@ type ConsensusParams struct { Version VersionParams `json:"version"` Synchrony SynchronyParams `json:"synchrony"` Timeout TimeoutParams `json:"timeout"` + Vote VoteParams `json:"vote"` } // HashedParams is a subset of ConsensusParams. @@ -96,6 +97,20 @@ type TimeoutParams struct { BypassCommitTimeout bool `json:"bypass_commit_timeout"` } +// VoteParams configure validity rules of the votes within Tendermint consensus. +type VoteParams struct { + ExtensionRequireHeight int64 `json:"extension_require_height"` +} + +// RequireExtensions returns true if vote extensions are required at height h +// and false otherwise. +func (v VoteParams) RequireExtensions(h int64) bool { + if v.ExtensionRequireHeight == 0 { + return false + } + return v.ExtensionRequireHeight <= h +} + // DefaultConsensusParams returns a default ConsensusParams. func DefaultConsensusParams() *ConsensusParams { return &ConsensusParams{ @@ -105,6 +120,7 @@ func DefaultConsensusParams() *ConsensusParams { Version: DefaultVersionParams(), Synchrony: DefaultSynchronyParams(), Timeout: DefaultTimeoutParams(), + Vote: DefaultVoteParams(), } } @@ -176,6 +192,13 @@ func DefaultTimeoutParams() TimeoutParams { } } +func DefaultVoteParams() VoteParams { + return VoteParams{ + // When set to 0, vote extensions are not required. + ExtensionRequireHeight: 0, + } +} + // TimeoutParamsOrDefaults returns the SynchronyParams, filling in any zero values // with the Tendermint defined default values. func (t TimeoutParams) TimeoutParamsOrDefaults() TimeoutParams { diff --git a/types/vote.go b/types/vote.go index 446de130a..83666294d 100644 --- a/types/vote.go +++ b/types/vote.go @@ -27,7 +27,7 @@ var ( ErrVoteInvalidBlockHash = errors.New("invalid block hash") ErrVoteNonDeterministicSignature = errors.New("non-deterministic signature") ErrVoteNil = errors.New("nil vote") - ErrVoteInvalidExtension = errors.New("invalid vote extension") + ErrVoteExtensionAbsent = errors.New("vote extension absent") ) type ErrVoteConflictingVotes struct { @@ -120,13 +120,8 @@ func (vote *Vote) ExtendedCommitSig() ExtendedCommitSig { return NewExtendedCommitSigAbsent() } - cs := vote.CommitSig() - if vote.BlockID.IsComplete() && len(vote.ExtensionSignature) == 0 { - panic(fmt.Sprintf("Invalid vote %v - BlockID is complete but missing vote extension signature", vote)) - } - return ExtendedCommitSig{ - CommitSig: cs, + CommitSig: vote.CommitSig(), Extension: vote.Extension, ExtensionSignature: vote.ExtensionSignature, } @@ -230,11 +225,11 @@ func (vote *Vote) Verify(chainID string, pubKey crypto.PubKey) error { return err } -// VerifyWithExtension performs the same verification as Verify, but +// VerifyVoteAndExtension performs the same verification as Verify, but // additionally checks whether the vote extension signature corresponds to the // given chain ID and public key. We only verify vote extension signatures for // precommits. -func (vote *Vote) VerifyWithExtension(chainID string, pubKey crypto.PubKey) error { +func (vote *Vote) VerifyVoteAndExtension(chainID string, pubKey crypto.PubKey) error { v, err := vote.verifyAndReturnProto(chainID, pubKey) if err != nil { return err @@ -249,6 +244,20 @@ func (vote *Vote) VerifyWithExtension(chainID string, pubKey crypto.PubKey) erro return nil } +// VerifyExtension checks whether the vote extension signature corresponds to the +// given chain ID and public key. +func (vote *Vote) VerifyExtension(chainID string, pubKey crypto.PubKey) error { + if vote.Type != tmproto.PrecommitType || vote.BlockID.IsNil() { + return nil + } + v := vote.ToProto() + extSignBytes := VoteExtensionSignBytes(chainID, v) + if !pubKey.VerifySignature(extSignBytes, vote.ExtensionSignature) { + return ErrVoteInvalidSignature + } + return nil +} + // ValidateBasic checks whether the vote is well-formed. It does not, however, // check vote extensions - for vote validation with vote extension validation, // use ValidateWithExtension. @@ -306,30 +315,34 @@ func (vote *Vote) ValidateBasic() error { } } - return nil -} - -// ValidateWithExtension performs the same validations as ValidateBasic, but -// additionally checks whether a vote extension signature is present. This -// function is used in places where vote extension signatures are expected. -func (vote *Vote) ValidateWithExtension() error { - if err := vote.ValidateBasic(); err != nil { - return err - } - - // We should always see vote extension signatures in non-nil precommits if vote.Type == tmproto.PrecommitType && !vote.BlockID.IsNil() { - if len(vote.ExtensionSignature) == 0 { - return errors.New("vote extension signature is missing") - } if len(vote.ExtensionSignature) > MaxSignatureSize { return fmt.Errorf("vote extension signature is too big (max: %d)", MaxSignatureSize) } + if len(vote.ExtensionSignature) == 0 && len(vote.Extension) != 0 { + return fmt.Errorf("vote extension signature absent on vote with extension") + } } return nil } +// EnsureExtension checks for the presence of extensions signature data +// on precommit vote types. +func (vote *Vote) EnsureExtension() error { + // We should always see vote extension signatures in non-nil precommits + if vote.Type != tmproto.PrecommitType { + return nil + } + if vote.BlockID.IsNil() { + return nil + } + if len(vote.ExtensionSignature) > 0 { + return nil + } + return ErrVoteExtensionAbsent +} + // ToProto converts the handwritten type to proto generated type // return type, nil if everything converts safely, otherwise nil, error func (vote *Vote) ToProto() *tmproto.Vote { diff --git a/types/vote_set.go b/types/vote_set.go index 224d4e4f8..0905d651d 100644 --- a/types/vote_set.go +++ b/types/vote_set.go @@ -53,11 +53,12 @@ const ( NOTE: Assumes that the sum total of voting power does not exceed MaxUInt64. */ type VoteSet struct { - chainID string - height int64 - round int32 - signedMsgType tmproto.SignedMsgType - valSet *ValidatorSet + chainID string + height int64 + round int32 + signedMsgType tmproto.SignedMsgType + valSet *ValidatorSet + requireExtensions bool mtx sync.Mutex votesBitArray *bits.BitArray @@ -70,22 +71,23 @@ type VoteSet struct { // Constructs a new VoteSet struct used to accumulate votes for given height/round. func NewVoteSet(chainID string, height int64, round int32, - signedMsgType tmproto.SignedMsgType, valSet *ValidatorSet) *VoteSet { + signedMsgType tmproto.SignedMsgType, valSet *ValidatorSet, requireExtensions bool) *VoteSet { if height == 0 { panic("Cannot make VoteSet for height == 0, doesn't make sense.") } return &VoteSet{ - chainID: chainID, - height: height, - round: round, - signedMsgType: signedMsgType, - valSet: valSet, - votesBitArray: bits.NewBitArray(valSet.Size()), - votes: make([]*Vote, valSet.Size()), - sum: 0, - maj23: nil, - votesByBlock: make(map[string]*blockVotes, valSet.Size()), - peerMaj23s: make(map[string]BlockID), + chainID: chainID, + height: height, + round: round, + signedMsgType: signedMsgType, + valSet: valSet, + requireExtensions: requireExtensions, + votesBitArray: bits.NewBitArray(valSet.Size()), + votes: make([]*Vote, valSet.Size()), + sum: 0, + maj23: nil, + votesByBlock: make(map[string]*blockVotes, valSet.Size()), + peerMaj23s: make(map[string]BlockID), } } @@ -194,8 +196,14 @@ func (voteSet *VoteSet) addVote(vote *Vote) (added bool, err error) { } // Check signature. - if err := vote.VerifyWithExtension(voteSet.chainID, val.PubKey); err != nil { - return false, fmt.Errorf("failed to verify vote with ChainID %s and PubKey %s: %w", voteSet.chainID, val.PubKey, err) + if voteSet.requireExtensions || len(vote.ExtensionSignature) > 0 { + if err := vote.VerifyVoteAndExtension(voteSet.chainID, val.PubKey); err != nil { + return false, fmt.Errorf("failed to verify vote with ChainID %s and PubKey %s: %w", voteSet.chainID, val.PubKey, err) + } + } else { + if err := vote.Verify(voteSet.chainID, val.PubKey); err != nil { + return false, fmt.Errorf("failed to verify vote with ChainID %s and PubKey %s: %w", voteSet.chainID, val.PubKey, err) + } } // Add vote and get conflicting vote if any. diff --git a/types/vote_set_test.go b/types/vote_set_test.go index 8d166d508..b80251413 100644 --- a/types/vote_set_test.go +++ b/types/vote_set_test.go @@ -498,6 +498,81 @@ func TestVoteSet_MakeCommit(t *testing.T) { } } +// TestVoteSet_RequireExtensions tests that the vote set correctly validates +// vote extensions data when either required or not required. +func TestVoteSet_RequireExtensions(t *testing.T) { + for _, tc := range []struct { + name string + requireExtensions bool + addExtension bool + exepectError bool + }{ + { + name: "no extension but expected", + requireExtensions: true, + addExtension: false, + exepectError: true, + }, + { + name: "invalid extensions but not expected", + requireExtensions: true, + addExtension: false, + exepectError: true, + }, + { + name: "no extension and not expected", + requireExtensions: false, + addExtension: false, + exepectError: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + height, round := int64(1), int32(0) + valSet, privValidators := randValidatorPrivValSet(ctx, t, 5, 10) + voteSet := NewVoteSet("test_chain_id", height, round, tmproto.PrecommitType, valSet, tc.requireExtensions) + + val0 := privValidators[0] + + val0p, err := val0.GetPubKey(ctx) + require.NoError(t, err) + val0Addr := val0p.Address() + blockHash := crypto.CRandBytes(32) + blockPartsTotal := uint32(123) + blockPartSetHeader := PartSetHeader{blockPartsTotal, crypto.CRandBytes(32)} + + vote := &Vote{ + ValidatorAddress: val0Addr, + ValidatorIndex: 0, + Height: height, + Round: round, + Type: tmproto.PrecommitType, + Timestamp: tmtime.Now(), + BlockID: BlockID{blockHash, blockPartSetHeader}, + } + v := vote.ToProto() + err = val0.SignVote(ctx, voteSet.ChainID(), v) + require.NoError(t, err) + vote.Signature = v.Signature + + if tc.addExtension { + vote.ExtensionSignature = v.ExtensionSignature + } + + added, err := voteSet.AddVote(vote) + if tc.exepectError { + require.Error(t, err) + require.False(t, added) + } else { + require.NoError(t, err) + require.True(t, added) + } + }) + } +} + // NOTE: privValidators are in order func randVoteSet( ctx context.Context, @@ -510,7 +585,7 @@ func randVoteSet( ) (*VoteSet, *ValidatorSet, []PrivValidator) { t.Helper() valSet, privValidators := randValidatorPrivValSet(ctx, t, numValidators, votingPower) - return NewVoteSet("test_chain_id", height, round, signedMsgType, valSet), valSet, privValidators + return NewVoteSet("test_chain_id", height, round, signedMsgType, valSet, true), valSet, privValidators } func deterministicVoteSet( @@ -523,7 +598,7 @@ func deterministicVoteSet( ) (*VoteSet, *ValidatorSet, []PrivValidator) { t.Helper() valSet, privValidators := deterministicValidatorSet(ctx, t) - return NewVoteSet("test_chain_id", height, round, signedMsgType, valSet), valSet, privValidators + return NewVoteSet("test_chain_id", height, round, signedMsgType, valSet, true), valSet, privValidators } func randValidatorPrivValSet(ctx context.Context, t testing.TB, numValidators int, votingPower int64) (*ValidatorSet, []PrivValidator) { diff --git a/types/vote_test.go b/types/vote_test.go index 70cd91381..1b7bb3f4f 100644 --- a/types/vote_test.go +++ b/types/vote_test.go @@ -267,7 +267,7 @@ func TestVoteExtension(t *testing.T) { if tc.includeSignature { vote.ExtensionSignature = v.ExtensionSignature } - err = vote.VerifyWithExtension("test_chain_id", pk) + err = vote.VerifyExtension("test_chain_id", pk) if tc.expectError { require.Error(t, err) } else { @@ -361,7 +361,7 @@ func TestValidVotes(t *testing.T) { signVote(ctx, t, privVal, "test_chain_id", tc.vote) tc.malleateVote(tc.vote) require.NoError(t, tc.vote.ValidateBasic(), "ValidateBasic for %s", tc.name) - require.NoError(t, tc.vote.ValidateWithExtension(), "ValidateWithExtension for %s", tc.name) + require.NoError(t, tc.vote.EnsureExtension(), "EnsureExtension for %s", tc.name) } } @@ -387,13 +387,13 @@ func TestInvalidVotes(t *testing.T) { signVote(ctx, t, privVal, "test_chain_id", prevote) tc.malleateVote(prevote) require.Error(t, prevote.ValidateBasic(), "ValidateBasic for %s in invalid prevote", tc.name) - require.Error(t, prevote.ValidateWithExtension(), "ValidateWithExtension for %s in invalid prevote", tc.name) + require.NoError(t, prevote.EnsureExtension(), "EnsureExtension for %s in invalid prevote", tc.name) precommit := examplePrecommit(t) signVote(ctx, t, privVal, "test_chain_id", precommit) tc.malleateVote(precommit) require.Error(t, precommit.ValidateBasic(), "ValidateBasic for %s in invalid precommit", tc.name) - require.Error(t, precommit.ValidateWithExtension(), "ValidateWithExtension for %s in invalid precommit", tc.name) + require.NoError(t, precommit.EnsureExtension(), "EnsureExtension for %s in invalid precommit", tc.name) } } @@ -414,7 +414,7 @@ func TestInvalidPrevotes(t *testing.T) { signVote(ctx, t, privVal, "test_chain_id", prevote) tc.malleateVote(prevote) require.Error(t, prevote.ValidateBasic(), "ValidateBasic for %s", tc.name) - require.Error(t, prevote.ValidateWithExtension(), "ValidateWithExtension for %s", tc.name) + require.NoError(t, prevote.EnsureExtension(), "EnsureExtension for %s", tc.name) } } @@ -431,18 +431,44 @@ func TestInvalidPrecommitExtensions(t *testing.T) { v.Extension = []byte("extension") v.ExtensionSignature = nil }}, - // TODO(thane): Re-enable once https://github.com/tendermint/tendermint/issues/8272 is resolved - //{"missing vote extension signature", func(v *Vote) { v.ExtensionSignature = nil }}, {"oversized vote extension signature", func(v *Vote) { v.ExtensionSignature = make([]byte, MaxSignatureSize+1) }}, } for _, tc := range testCases { precommit := examplePrecommit(t) signVote(ctx, t, privVal, "test_chain_id", precommit) tc.malleateVote(precommit) - // We don't expect an error from ValidateBasic, because it doesn't - // handle vote extensions. - require.NoError(t, precommit.ValidateBasic(), "ValidateBasic for %s", tc.name) - require.Error(t, precommit.ValidateWithExtension(), "ValidateWithExtension for %s", tc.name) + // ValidateBasic ensures that vote extensions, if present, are well formed + require.Error(t, precommit.ValidateBasic(), "ValidateBasic for %s", tc.name) + } +} + +func TestEnsureVoteExtension(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + privVal := NewMockPV() + + testCases := []struct { + name string + malleateVote func(*Vote) + expectError bool + }{ + {"vote extension signature absent", func(v *Vote) { + v.Extension = nil + v.ExtensionSignature = nil + }, true}, + {"vote extension signature present", func(v *Vote) { + v.ExtensionSignature = []byte("extension signature") + }, false}, + } + for _, tc := range testCases { + precommit := examplePrecommit(t) + signVote(ctx, t, privVal, "test_chain_id", precommit) + tc.malleateVote(precommit) + if tc.expectError { + require.Error(t, precommit.EnsureExtension(), "EnsureExtension for %s", tc.name) + } else { + require.NoError(t, precommit.EnsureExtension(), "EnsureExtension for %s", tc.name) + } } }