diff --git a/internal/consensus/state_test.go b/internal/consensus/state_test.go index 8fc5bff9d..be1522cbf 100644 --- a/internal/consensus/state_test.go +++ b/internal/consensus/state_test.go @@ -2260,19 +2260,36 @@ func TestPrepareProposalReceivesVoteExtensions(t *testing.T) { } } -func TestVerifyVoteExtensionCalled(t *testing.T) { +func TestVoteExtensionRequiredHeight(t *testing.T) { for _, testCase := range []struct { - name string - hasExtension bool - initialRequireHeight int64 + name string + initialRequiredHeight int64 + hasExtension bool + expectSuccessfulRound bool }{ { - name: "called when extension present", - hasExtension: true, + name: "extension present but not required", + hasExtension: true, + initialRequiredHeight: 0, + expectSuccessfulRound: true, }, { - name: "not called when extension absent", - hasExtension: false, + name: "extension absent but not required", + hasExtension: false, + initialRequiredHeight: 0, + expectSuccessfulRound: true, + }, + { + name: "extension present and required", + hasExtension: true, + initialRequiredHeight: 1, + expectSuccessfulRound: true, + }, + { + name: "extension absent but required", + hasExtension: false, + initialRequiredHeight: 1, + expectSuccessfulRound: false, }, } { t.Run(testCase.name, func(t *testing.T) { @@ -2294,10 +2311,11 @@ func TestVerifyVoteExtensionCalled(t *testing.T) { } m.On("FinalizeBlock", mock.Anything, mock.Anything).Return(&abci.ResponseFinalizeBlock{}, nil).Maybe() m.On("Commit", mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() - l := log.NewTestingLogger(t) - cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m, logger: l, validators: numValidators}) + cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m, validators: numValidators}) + cs1.state.ConsensusParams.Vote.ExtensionRequireHeight = testCase.initialRequiredHeight height, round := cs1.Height, cs1.Round + timeoutCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutPropose) proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) pv1, err := cs1.privValidator.GetPubKey(ctx) @@ -2323,7 +2341,7 @@ func TestVerifyVoteExtensionCalled(t *testing.T) { if testCase.hasExtension { ext = []byte("extension") } - // sign all of the precommits + for _, vs := range vss[1:] { vote, err := vs.signVote(ctx, tmproto.PrecommitType, config.ChainID(), blockID, ext) if !testCase.hasExtension { @@ -2332,9 +2350,14 @@ func TestVerifyVoteExtensionCalled(t *testing.T) { require.NoError(t, err) addVotes(cs1, vote) } - ensurePrecommit(t, voteCh, height, round) + if testCase.expectSuccessfulRound { + ensurePrecommit(t, voteCh, height, round) + height++ + } else { + ensureNewTimeout(t, timeoutCh, height, round, cs1.state.ConsensusParams.Timeout.VoteTimeout(round).Nanoseconds()) + round++ + } - height++ ensureNewRound(t, newRoundCh, height, round) m.AssertExpectations(t) })