From 397f0840f6e3fb82fb75f340962a2fc8f86bc3a2 Mon Sep 17 00:00:00 2001 From: William Banfield Date: Mon, 16 May 2022 17:13:14 -0400 Subject: [PATCH] do not call extension methods when not enabled --- internal/consensus/state.go | 16 ++-- internal/consensus/state_test.go | 138 +++++++++++++++++++------------ 2 files changed, 91 insertions(+), 63 deletions(-) diff --git a/internal/consensus/state.go b/internal/consensus/state.go index 78d6ee733..d3a91d5a9 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -2553,18 +2553,18 @@ func (cs *State) signVote( // If the signedMessageType is for precommit, // use our local precommit Timeout as the max wait time for getting a singed commit. The same goes for prevote. - timeout := cs.voteTimeout(cs.Round) - + timeout := time.Second if msgType == tmproto.PrecommitType && !vote.BlockID.IsNil() { + timeout = cs.voteTimeout(cs.Round) // if the signedMessage type is for a non-nil precommit, add // VoteExtension - ext, err := cs.blockExec.ExtendVote(ctx, vote) - if err != nil { - return nil, err + if cs.state.ConsensusParams.ABCI.VoteExtensionsEnabled(cs.Height) { + ext, err := cs.blockExec.ExtendVote(ctx, vote) + if err != nil { + return nil, err + } + vote.Extension = ext } - vote.Extension = ext - } else { - timeout = time.Second } v := vote.ToProto() diff --git a/internal/consensus/state_test.go b/internal/consensus/state_test.go index 267a0fdb9..6292b4305 100644 --- a/internal/consensus/state_test.go +++ b/internal/consensus/state_test.go @@ -2026,71 +2026,98 @@ func TestFinalizeBlockCalled(t *testing.T) { } } -// TestExtendVoteCalled tests that the vote extension methods are called at the -// correct point in the consensus algorithm. -func TestExtendVoteCalled(t *testing.T) { - config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +// TestExtendVoteCalledWhenEnabled tests that the vote extension methods are called at the +// correct point in the consensus algorithm when vote extensions are enabled. +func TestExtendVoteCalledWhenEnabled(t *testing.T) { + for _, testCase := range []struct { + name string + enabled bool + }{ + { + enabled: true, + }, + { + enabled: false, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - m := abcimocks.NewApplication(t) - m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_ACCEPT}, nil) - m.On("PrepareProposal", mock.Anything, mock.Anything).Return(&abci.ResponsePrepareProposal{}, nil) - m.On("ExtendVote", mock.Anything, mock.Anything).Return(&abci.ResponseExtendVote{ - VoteExtension: []byte("extension"), - }, nil) - m.On("VerifyVoteExtension", mock.Anything, mock.Anything).Return(&abci.ResponseVerifyVoteExtension{ - Status: abci.ResponseVerifyVoteExtension_ACCEPT, - }, nil) - m.On("Commit", mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() - m.On("FinalizeBlock", mock.Anything, mock.Anything).Return(&abci.ResponseFinalizeBlock{}, nil).Maybe() - cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m}) - height, round := cs1.Height, cs1.Round + m := abcimocks.NewApplication(t) + m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_ACCEPT}, nil) + m.On("PrepareProposal", mock.Anything, mock.Anything).Return(&abci.ResponsePrepareProposal{}, nil) + if testCase.enabled { + m.On("ExtendVote", mock.Anything, mock.Anything).Return(&abci.ResponseExtendVote{ + VoteExtension: []byte("extension"), + }, nil) + m.On("VerifyVoteExtension", mock.Anything, mock.Anything).Return(&abci.ResponseVerifyVoteExtension{ + Status: abci.ResponseVerifyVoteExtension_ACCEPT, + }, nil) + } + m.On("Commit", mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() + m.On("FinalizeBlock", mock.Anything, mock.Anything).Return(&abci.ResponseFinalizeBlock{}, nil).Maybe() + cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m}) + height, round := cs1.Height, cs1.Round + if testCase.enabled { + cs1.state.ConsensusParams.ABCI.VoteExtensionsEnableHeight = 1 + } - proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) - pv1, err := cs1.privValidator.GetPubKey(ctx) - require.NoError(t, err) - addr := pv1.Address() - voteCh := subscribeToVoter(ctx, t, cs1, addr) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey(ctx) + require.NoError(t, err) + addr := pv1.Address() + voteCh := subscribeToVoter(ctx, t, cs1, addr) - startTestRound(ctx, cs1, cs1.Height, round) - ensureNewRound(t, newRoundCh, height, round) - ensureNewProposal(t, proposalCh, height, round) + startTestRound(ctx, cs1, cs1.Height, round) + ensureNewRound(t, newRoundCh, height, round) + ensureNewProposal(t, proposalCh, height, round) - m.AssertNotCalled(t, "ExtendVote", mock.Anything, mock.Anything) + m.AssertNotCalled(t, "ExtendVote", mock.Anything, mock.Anything) - rs := cs1.GetRoundState() + rs := cs1.GetRoundState() - blockID := types.BlockID{ - Hash: rs.ProposalBlock.Hash(), - PartSetHeader: rs.ProposalBlockParts.Header(), - } - signAddVotes(ctx, t, cs1, tmproto.PrevoteType, config.ChainID(), blockID, vss[1:]...) - ensurePrevoteMatch(t, voteCh, height, round, blockID.Hash) + blockID := types.BlockID{ + Hash: rs.ProposalBlock.Hash(), + PartSetHeader: rs.ProposalBlockParts.Header(), + } + signAddVotes(ctx, t, cs1, tmproto.PrevoteType, config.ChainID(), blockID, vss[1:]...) + ensurePrevoteMatch(t, voteCh, height, round, blockID.Hash) - ensurePrecommit(t, voteCh, height, round) + ensurePrecommit(t, voteCh, height, round) - m.AssertCalled(t, "ExtendVote", ctx, &abci.RequestExtendVote{ - Height: height, - Hash: blockID.Hash, - }) + if testCase.enabled { + m.AssertCalled(t, "ExtendVote", ctx, &abci.RequestExtendVote{ + Height: height, + Hash: blockID.Hash, + }) + } else { + m.AssertNotCalled(t, "ExtendVote", mock.Anything, mock.Anything) + } - signAddVotes(ctx, t, cs1, tmproto.PrecommitType, config.ChainID(), blockID, vss[1:]...) - ensureNewRound(t, newRoundCh, height+1, 0) - m.AssertExpectations(t) + signAddVotes(ctx, t, cs1, tmproto.PrecommitType, config.ChainID(), blockID, vss[1:]...) + ensureNewRound(t, newRoundCh, height+1, 0) + m.AssertExpectations(t) - // Only 3 of the vote extensions are seen, as consensus proceeds as soon as the +2/3 threshold - // is observed by the consensus engine. - for _, pv := range vss[1:3] { - pv, err := pv.GetPubKey(ctx) - require.NoError(t, err) - addr := pv.Address() - m.AssertCalled(t, "VerifyVoteExtension", ctx, &abci.RequestVerifyVoteExtension{ - Hash: blockID.Hash, - ValidatorAddress: addr, - Height: height, - VoteExtension: []byte("extension"), + // Only 3 of the vote extensions are seen, as consensus proceeds as soon as the +2/3 threshold + // is observed by the consensus engine. + for _, pv := range vss[1:3] { + pv, err := pv.GetPubKey(ctx) + require.NoError(t, err) + addr := pv.Address() + if testCase.enabled { + m.AssertCalled(t, "VerifyVoteExtension", ctx, &abci.RequestVerifyVoteExtension{ + Hash: blockID.Hash, + ValidatorAddress: addr, + Height: height, + VoteExtension: []byte("extension"), + }) + } else { + m.AssertNotCalled(t, "VerifyVoteExtension", mock.Anything, mock.Anything) + } + } }) } @@ -2204,6 +2231,7 @@ func TestPrepareProposalReceivesVoteExtensions(t *testing.T) { cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m}) height, round := cs1.Height, cs1.Round + cs1.state.ConsensusParams.ABCI.VoteExtensionsEnableHeight = 1 newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal)