From 431c85be6025bbac5b4d3fc06882c4122b62d0a2 Mon Sep 17 00:00:00 2001 From: William Banfield Date: Thu, 31 Mar 2022 15:13:37 -0400 Subject: [PATCH] add tests for vote extension cases --- internal/consensus/common_test.go | 7 +- internal/consensus/state.go | 1 - internal/consensus/state_test.go | 243 ++++++++++++++++++++++++++++++ types/test_util.go | 1 + types/vote_set_test.go | 1 + types/vote_test.go | 73 +++++++++ 6 files changed, 322 insertions(+), 4 deletions(-) diff --git a/internal/consensus/common_test.go b/internal/consensus/common_test.go index dac14bd43..756fa9eef 100644 --- a/internal/consensus/common_test.go +++ b/internal/consensus/common_test.go @@ -112,7 +112,8 @@ func (vs *validatorStub) signVote( ctx context.Context, voteType tmproto.SignedMsgType, chainID string, - blockID types.BlockID) (*types.Vote, error) { + blockID types.BlockID, + voteExtension []byte) (*types.Vote, error) { pubKey, err := vs.PrivValidator.GetPubKey(ctx) if err != nil { @@ -127,7 +128,7 @@ func (vs *validatorStub) signVote( Timestamp: vs.clock.Now(), ValidatorAddress: pubKey.Address(), ValidatorIndex: vs.Index, - Extension: []byte("extension"), + Extension: voteExtension, } v := vote.ToProto() if err = vs.PrivValidator.SignVote(ctx, chainID, v); err != nil { @@ -157,7 +158,7 @@ func signVote( chainID string, blockID types.BlockID) *types.Vote { - v, err := vs.signVote(ctx, voteType, chainID, blockID) + v, err := vs.signVote(ctx, voteType, chainID, blockID, []byte("extension")) require.NoError(t, err, "failed to sign vote") vs.lastVote = v diff --git a/internal/consensus/state.go b/internal/consensus/state.go index f72293aba..edb8473ee 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -911,7 +911,6 @@ func (cs *State) receiveRoutine(ctx context.Context, maxSteps int) { if err := cs.wal.Write(mi); err != nil { cs.logger.Error("failed writing to WAL", "err", err) } - // handles proposals, block parts, votes // may generate internal events (votes, complete proposals, 2/3 majorities) cs.handleMsg(ctx, mi) diff --git a/internal/consensus/state_test.go b/internal/consensus/state_test.go index 03bb85cf2..2cecc421f 100644 --- a/internal/consensus/state_test.go +++ b/internal/consensus/state_test.go @@ -2015,6 +2015,237 @@ func TestFinalizeBlockCalled(t *testing.T) { } } +// TestExtendVoteCalled tests that the vote extension methods are called at the +// correct point in the consensus algorithm. +func TestExtendVoteCalled(t *testing.T) { + config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + m := abcimocks.NewBaseMock() + m.On("ProcessProposal", mock.Anything).Return(abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_ACCEPT}) + m.On("ExtendVote", mock.Anything).Return(abci.ResponseExtendVote{ + VoteExtension: []byte("extension"), + }) + m.On("VerifyVoteExtension", mock.Anything).Return(abci.ResponseVerifyVoteExtension{ + Status: abci.ResponseVerifyVoteExtension_ACCEPT, + }) + m.On("FinalizeBlock", mock.Anything).Return(abci.ResponseFinalizeBlock{}).Maybe() + cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m}) + height, round := cs1.Height, cs1.Round + + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey(ctx) + require.NoError(t, err) + addr := pv1.Address() + voteCh := subscribeToVoter(ctx, t, cs1, addr) + + startTestRound(ctx, cs1, cs1.Height, round) + ensureNewRound(t, newRoundCh, height, round) + ensureNewProposal(t, proposalCh, height, round) + + m.AssertNotCalled(t, "ExtendVote", mock.Anything) + + rs := cs1.GetRoundState() + + blockID := types.BlockID{ + Hash: rs.ProposalBlock.Hash(), + PartSetHeader: rs.ProposalBlockParts.Header(), + } + signAddVotes(ctx, t, cs1, tmproto.PrevoteType, config.ChainID(), blockID, vss[1:]...) + ensurePrevoteMatch(t, voteCh, height, round, blockID.Hash) + + ensurePrecommit(t, voteCh, height, round) + + m.AssertCalled(t, "ExtendVote", abci.RequestExtendVote{ + Height: height, + Hash: blockID.Hash, + }) + + m.AssertCalled(t, "VerifyVoteExtension", abci.RequestVerifyVoteExtension{ + Hash: blockID.Hash, + ValidatorAddress: addr, + Height: height, + VoteExtension: []byte("extension"), + }) + + signAddVotes(ctx, t, cs1, tmproto.PrecommitType, config.ChainID(), blockID, vss[1:]...) + ensureNewRound(t, newRoundCh, height+1, 0) + m.AssertExpectations(t) + + // Only 3 of the vote extensions are seen, as consensus proceeds as soon as the +2/3 threshold + // is observed by the consensus engine. + for _, pv := range vss[:3] { + pv, err := pv.GetPubKey(ctx) + require.NoError(t, err) + addr := pv.Address() + m.AssertCalled(t, "VerifyVoteExtension", abci.RequestVerifyVoteExtension{ + Hash: blockID.Hash, + ValidatorAddress: addr, + Height: height, + VoteExtension: []byte("extension"), + }) + } + +} + +// TestVerifyVoteExtensionNotCalledOnAbsentPrecommit tests that the VerifyVoteExtension +// method is not called for a validator's vote that is never delivered. +func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { + config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + m := abcimocks.NewBaseMock() + m.On("ProcessProposal", mock.Anything).Return(abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_ACCEPT}) + m.On("ExtendVote", mock.Anything).Return(abci.ResponseExtendVote{ + VoteExtension: []byte("extension"), + }) + m.On("VerifyVoteExtension", mock.Anything).Return(abci.ResponseVerifyVoteExtension{ + Status: abci.ResponseVerifyVoteExtension_ACCEPT, + }) + m.On("FinalizeBlock", mock.Anything).Return(abci.ResponseFinalizeBlock{}).Maybe() + cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m}) + height, round := cs1.Height, cs1.Round + + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey(ctx) + require.NoError(t, err) + addr := pv1.Address() + voteCh := subscribeToVoter(ctx, t, cs1, addr) + + startTestRound(ctx, cs1, cs1.Height, round) + ensureNewRound(t, newRoundCh, height, round) + ensureNewProposal(t, proposalCh, height, round) + rs := cs1.GetRoundState() + + blockID := types.BlockID{ + Hash: rs.ProposalBlock.Hash(), + PartSetHeader: rs.ProposalBlockParts.Header(), + } + signAddVotes(ctx, t, cs1, tmproto.PrevoteType, config.ChainID(), blockID, vss[2:]...) + ensurePrevoteMatch(t, voteCh, height, round, blockID.Hash) + + ensurePrecommit(t, voteCh, height, round) + + m.AssertCalled(t, "ExtendVote", abci.RequestExtendVote{ + Height: height, + Hash: blockID.Hash, + }) + + m.AssertCalled(t, "VerifyVoteExtension", abci.RequestVerifyVoteExtension{ + Hash: blockID.Hash, + ValidatorAddress: addr, + Height: height, + VoteExtension: []byte("extension"), + }) + + signAddVotes(ctx, t, cs1, tmproto.PrecommitType, config.ChainID(), blockID, vss[2:]...) + ensureNewRound(t, newRoundCh, height+1, 0) + m.AssertExpectations(t) + + // vss[1] did not issue a precommit for the block, ensure that a vote extension + // for its address was not sent to the application. + pv, err := vss[1].GetPubKey(ctx) + require.NoError(t, err) + addr = pv.Address() + + m.AssertNotCalled(t, "VerifyVoteExtension", abci.RequestVerifyVoteExtension{ + Hash: blockID.Hash, + ValidatorAddress: addr, + Height: height, + VoteExtension: []byte("extension"), + }) + +} + +// TestPrepareProposalReceivesVoteExtensions tests that the PrepareProposal method +// is called with the vote extensions from the previous height. The test functions +// be completing a consensus height with a mock application as the proposer. The +// test then proceeds to fail sever rounds of consensus until the mock application +// is the proposer again and ensures that the mock application receives the set of +// vote extensions from the previous consensus instance. +func TestPrepareProposalReceivesVoteExtensions(t *testing.T) { + config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // create a list of vote extensions, one for each validator. + voteExtensions := [][]byte{ + []byte("extension 0"), + []byte("extension 1"), + []byte("extension 2"), + []byte("extension 3"), + } + + m := abcimocks.NewBaseMock() + m.On("ExtendVote", mock.Anything).Return(abci.ResponseExtendVote{ + VoteExtension: voteExtensions[0], + }) + m.On("PrepareProposal", mock.Anything).Return(abci.ResponsePrepareProposal{ + ModifiedTxStatus: abci.ResponsePrepareProposal_UNMODIFIED, + }).Once() + + cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m}) + height, round := cs1.Height, cs1.Round + + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + + startTestRound(ctx, cs1, height, round) + ensureNewRound(t, newRoundCh, height, round) + ensureNewProposal(t, proposalCh, height, round) + + rs := cs1.GetRoundState() + blockID := types.BlockID{ + Hash: rs.ProposalBlock.Hash(), + PartSetHeader: rs.ProposalBlockParts.Header(), + } + signAddVotes(ctx, t, cs1, tmproto.PrevoteType, config.ChainID(), blockID, vss[1:]...) + + // create a precommit for each validator with the associated vote extension. + for i, vs := range vss[1:] { + signAddPrecommitWithExtension(ctx, t, cs1, config.ChainID(), blockID, voteExtensions[i+1], vs) + } + + pv1, err := cs1.privValidator.GetPubKey(ctx) + require.NoError(t, err) + addr := pv1.Address() + voteCh := subscribeToVoter(ctx, t, cs1, addr) + + // ensure that the height is commited. + ensurePrecommit(t, voteCh, height, round) + validatePrecommit(ctx, t, cs1, round, round, vss[0], blockID.Hash, blockID.Hash) + incrementHeight(vss[1:]...) + + height++ + round = 0 + ensureNewRound(t, newRoundCh, height, round) + incrementRound(vss[1:]...) + incrementRound(vss[1:]...) + incrementRound(vss[1:]...) + round = 3 + + // capture the prepare proposal request. + rpp := abci.RequestPrepareProposal{} + m.On("PrepareProposal", mock.MatchedBy(func(r abci.RequestPrepareProposal) bool { + rpp = r + return true + })).Return(abci.ResponsePrepareProposal{ModifiedTxStatus: abci.ResponsePrepareProposal_UNMODIFIED}) + + signAddVotes(ctx, t, cs1, tmproto.PrecommitType, config.ChainID(), types.BlockID{}, vss[1:]...) + ensureNewRound(t, newRoundCh, height, round) + ensureNewProposal(t, proposalCh, height, round) + + // ensure that the proposer received the list of vote extensions from the + // previous height. + for i := range vss { + require.Equal(t, rpp.LocalLastCommit.Votes[i].VoteExtension, voteExtensions[i]) + } +} + // 4 vals, 3 Nil Precommits at P0 // What we want: // P0 waits for timeoutPrecommit before starting next round @@ -2719,3 +2950,15 @@ func subscribe( }() return ch } + +func signAddPrecommitWithExtension(ctx context.Context, + t *testing.T, + cs *State, + chainID string, + blockID types.BlockID, + extension []byte, + stub *validatorStub) { + v, err := stub.signVote(ctx, tmproto.PrecommitType, chainID, blockID, extension) + require.NoError(t, err, "failed to sign vote") + addVotes(cs, v) +} diff --git a/types/test_util.go b/types/test_util.go index dbd3f81ec..55f56ebf5 100644 --- a/types/test_util.go +++ b/types/test_util.go @@ -43,5 +43,6 @@ func signAddVote(ctx context.Context, privVal PrivValidator, vote *Vote, voteSet return false, err } vote.Signature = v.Signature + vote.ExtensionSignature = v.ExtensionSignature return voteSet.AddVote(vote) } diff --git a/types/vote_set_test.go b/types/vote_set_test.go index 4de9b1837..d92bd9f43 100644 --- a/types/vote_set_test.go +++ b/types/vote_set_test.go @@ -127,6 +127,7 @@ func TestVoteSet_AddVote_Bad(t *testing.T) { t.Errorf("expected VoteSet.Add to fail, wrong type") } } + } func TestVoteSet_2_3Majority(t *testing.T) { diff --git a/types/vote_test.go b/types/vote_test.go index 29f324aa8..1cf5d806e 100644 --- a/types/vote_test.go +++ b/types/vote_test.go @@ -13,6 +13,7 @@ import ( "github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/crypto/tmhash" "github.com/tendermint/tendermint/internal/libs/protoio" + tmtime "github.com/tendermint/tendermint/libs/time" tmproto "github.com/tendermint/tendermint/proto/tendermint/types" ) @@ -202,6 +203,78 @@ func TestVoteVerifySignature(t *testing.T) { require.True(t, valid) } +// TestVoteExtension tests that the vote verification behaves correctly in each case +// of vote extension being set on the vote. +func TestVoteExtension(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCases := []struct { + name string + extension []byte + includeSignature bool + expectError bool + }{ + { + name: "all fields present", + extension: []byte("extension"), + includeSignature: true, + expectError: false, + }, + { + name: "no extension signature", + extension: []byte("extension"), + includeSignature: false, + expectError: true, + }, + { + name: "empty extension", + includeSignature: true, + expectError: false, + }, + { + name: "no extension and no signature", + includeSignature: false, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + height, round := int64(1), int32(0) + privVal := NewMockPV() + pk, err := privVal.GetPubKey(ctx) + require.NoError(t, err) + blk := Block{} + ps, err := blk.MakePartSet(BlockPartSizeBytes) + require.NoError(t, err) + vote := &Vote{ + ValidatorAddress: pk.Address(), + ValidatorIndex: 0, + Height: height, + Round: round, + Timestamp: tmtime.Now(), + Type: tmproto.PrecommitType, + BlockID: BlockID{blk.Hash(), ps.Header()}, + } + + v := vote.ToProto() + err = privVal.SignVote(ctx, "test_chain_id", v) + require.NoError(t, err) + vote.Signature = v.Signature + if tc.includeSignature { + vote.ExtensionSignature = v.ExtensionSignature + } + err = vote.Verify("test_chain_id", pk) + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + func TestIsVoteTypeValid(t *testing.T) { tc := []struct { name string