From 0cceadf4d4c5acf9d390f22616b80070fece52d4 Mon Sep 17 00:00:00 2001 From: William Banfield <4561443+williambanfield@users.noreply.github.com> Date: Fri, 20 May 2022 17:46:52 -0400 Subject: [PATCH] abci++: add consensus parameter logic to control vote extension require height (#8547) This PR makes vote extensions optional within Tendermint. A new ConsensusParams field, called ABCIParams.VoteExtensionsEnableHeight, has been added to toggle whether or not extensions should be enabled or disabled depending on the current height of the consensus engine. Related to: #8453 --- internal/blocksync/pool.go | 6 +- internal/blocksync/reactor.go | 121 +++++--- internal/blocksync/reactor_test.go | 181 +++++++---- internal/consensus/common_test.go | 15 +- internal/consensus/msgs.go | 6 +- internal/consensus/reactor.go | 21 +- internal/consensus/reactor_test.go | 113 +++++++ internal/consensus/replay_test.go | 8 +- internal/consensus/state.go | 148 ++++++--- internal/consensus/state_test.go | 285 +++++++++++++----- internal/consensus/types/height_vote_set.go | 26 +- .../consensus/types/height_vote_set_test.go | 2 +- internal/evidence/pool_test.go | 7 +- internal/evidence/verify_test.go | 24 +- internal/state/execution.go | 22 +- internal/state/execution_test.go | 112 ++++++- internal/state/mocks/block_store.go | 7 +- internal/state/services.go | 3 +- internal/state/store.go | 19 ++ internal/state/validation_test.go | 6 +- internal/statesync/reactor_test.go | 4 +- internal/store/store.go | 90 ++++-- internal/store/store_test.go | 116 ++++++- internal/test/factory/params.go | 1 + node/node_test.go | 2 +- test/e2e/runner/evidence.go | 6 +- test/e2e/tests/app_test.go | 1 + types/block.go | 109 ++++++- types/block_test.go | 159 ++++++++-- types/evidence_test.go | 6 +- types/params.go | 24 ++ types/validation_test.go | 10 +- types/validator_set_test.go | 6 +- types/vote.go | 71 +++-- types/vote_set.go | 38 ++- types/vote_set_test.go | 90 +++++- types/vote_test.go | 50 ++- 37 files changed, 1509 insertions(+), 406 deletions(-) diff --git a/internal/blocksync/pool.go b/internal/blocksync/pool.go index 30bb6962e..64ce54dc6 100644 --- a/internal/blocksync/pool.go +++ b/internal/blocksync/pool.go @@ -280,7 +280,7 @@ func (pool *BlockPool) AddBlock(peerID types.NodeID, block *types.Block, extComm pool.mtx.Lock() defer pool.mtx.Unlock() - if block.Height != extCommit.Height { + if extCommit != nil && block.Height != extCommit.Height { return fmt.Errorf("heights don't match, not adding block (block height: %d, commit height: %d)", block.Height, extCommit.Height) } @@ -597,7 +597,9 @@ func (bpr *bpRequester) setBlock(block *types.Block, extCommit *types.ExtendedCo return false } bpr.block = block - bpr.extCommit = extCommit + if extCommit != nil { + bpr.extCommit = extCommit + } bpr.mtx.Unlock() select { diff --git a/internal/blocksync/reactor.go b/internal/blocksync/reactor.go index 144595889..6c1c060e7 100644 --- a/internal/blocksync/reactor.go +++ b/internal/blocksync/reactor.go @@ -185,31 +185,39 @@ func (r *Reactor) OnStop() { // Otherwise, we'll respond saying we do not have it. func (r *Reactor) respondToPeer(ctx context.Context, msg *bcproto.BlockRequest, peerID types.NodeID, blockSyncCh *p2p.Channel) error { block := r.store.LoadBlock(msg.Height) - if block != nil { - extCommit := r.store.LoadBlockExtendedCommit(msg.Height) - if extCommit == nil { - return fmt.Errorf("found block in store without extended commit: %v", block) - } - blockProto, err := block.ToProto() - if err != nil { - return fmt.Errorf("failed to convert block to protobuf: %w", err) - } - + if block == nil { + r.logger.Info("peer requesting a block we do not have", "peer", peerID, "height", msg.Height) return blockSyncCh.Send(ctx, p2p.Envelope{ - To: peerID, - Message: &bcproto.BlockResponse{ - Block: blockProto, - ExtCommit: extCommit.ToProto(), - }, + To: peerID, + Message: &bcproto.NoBlockResponse{Height: msg.Height}, }) } - r.logger.Info("peer requesting a block we do not have", "peer", peerID, "height", msg.Height) + state, err := r.stateStore.Load() + if err != nil { + return fmt.Errorf("loading state: %w", err) + } + var extCommit *types.ExtendedCommit + if state.ConsensusParams.ABCI.VoteExtensionsEnabled(msg.Height) { + extCommit = r.store.LoadBlockExtendedCommit(msg.Height) + if extCommit == nil { + return fmt.Errorf("found block in store with no extended commit: %v", block) + } + } + + blockProto, err := block.ToProto() + if err != nil { + return fmt.Errorf("failed to convert block to protobuf: %w", err) + } return blockSyncCh.Send(ctx, p2p.Envelope{ - To: peerID, - Message: &bcproto.NoBlockResponse{Height: msg.Height}, + To: peerID, + Message: &bcproto.BlockResponse{ + Block: blockProto, + ExtCommit: extCommit.ToProto(), + }, }) + } // handleMessage handles an Envelope sent from a peer on a specific p2p Channel. @@ -242,12 +250,16 @@ func (r *Reactor) handleMessage(ctx context.Context, envelope *p2p.Envelope, blo "err", err) return err } - extCommit, err := types.ExtendedCommitFromProto(msg.ExtCommit) - if err != nil { - r.logger.Error("failed to convert extended commit from proto", - "peer", envelope.From, - "err", err) - return err + var extCommit *types.ExtendedCommit + if msg.ExtCommit != nil { + var err error + extCommit, err = types.ExtendedCommitFromProto(msg.ExtCommit) + if err != nil { + r.logger.Error("failed to convert extended commit from proto", + "peer", envelope.From, + "err", err) + return err + } } if err := r.pool.AddBlock(envelope.From, block, extCommit, block.Size()); err != nil { @@ -440,6 +452,8 @@ func (r *Reactor) poolRoutine(ctx context.Context, stateSynced bool, blockSyncCh lastRate = 0.0 didProcessCh = make(chan struct{}, 1) + + initialCommitHasExtensions = (r.initialState.LastBlockHeight > 0 && r.store.LoadBlockExtendedCommit(r.initialState.LastBlockHeight) != nil) ) defer trySyncTicker.Stop() @@ -463,12 +477,27 @@ func (r *Reactor) poolRoutine(ctx context.Context, stateSynced bool, blockSyncCh ) switch { - // TODO(sergio) Might be needed for implementing the upgrading solution. Remove after that - //case state.LastBlockHeight > 0 && r.store.LoadBlockExtCommit(state.LastBlockHeight) == nil: - case state.LastBlockHeight > 0 && blocksSynced == 0: - // Having state-synced, we need to blocksync at least one block + + // The case statement below is a bit confusing, so here is a breakdown + // of its logic and purpose: + // + // If VoteExtensions are enabled we cannot switch to consensus without + // the vote extension data for the previous height, i.e. state.LastBlockHeight. + // + // If extensions were required during state.LastBlockHeight and we have + // sync'd at least one block, then we are guaranteed to have extensions. + // BlockSync requires that the blocks it fetches have extensions if + // extensions were enabled during the height. + // + // If extensions were required during state.LastBlockHeight and we have + // not sync'd any blocks, then we can only transition to Consensus + // if we already had extensions for the initial height. + // If any of these conditions is not met, we continue the loop, looking + // for extensions. + case state.ConsensusParams.ABCI.VoteExtensionsEnabled(state.LastBlockHeight) && + (blocksSynced == 0 && !initialCommitHasExtensions): r.logger.Info( - "no seen commit yet", + "no extended commit yet", "height", height, "last_block_height", state.LastBlockHeight, "initial_height", state.InitialHeight, @@ -520,18 +549,19 @@ func (r *Reactor) poolRoutine(ctx context.Context, stateSynced bool, blockSyncCh // see if there are any blocks to sync first, second, extCommit := r.pool.PeekTwoBlocks() - if first == nil || second == nil || extCommit == nil { - if first != nil && extCommit == nil { - // See https://github.com/tendermint/tendermint/pull/8433#discussion_r866790631 - panic(fmt.Errorf("peeked first block without extended commit at height %d - possible node store corruption", first.Height)) - } - // we need all to sync the first block + if first != nil && extCommit == nil && + state.ConsensusParams.ABCI.VoteExtensionsEnabled(first.Height) { + // See https://github.com/tendermint/tendermint/pull/8433#discussion_r866790631 + panic(fmt.Errorf("peeked first block without extended commit at height %d - possible node store corruption", first.Height)) + } else if first == nil || second == nil { + // we need to have fetched two consecutive blocks in order to + // perform blocksync verification continue - } else { - // try again quickly next loop - didProcessCh <- struct{}{} } + // try again quickly next loop + didProcessCh <- struct{}{} + firstParts, err := first.MakePartSet(types.BlockPartSizeBytes) if err != nil { r.logger.Error("failed to make ", @@ -557,7 +587,10 @@ func (r *Reactor) poolRoutine(ctx context.Context, stateSynced bool, blockSyncCh // validate the block before we persist it err = r.blockExec.ValidateBlock(ctx, state, first) } - + if err == nil && state.ConsensusParams.ABCI.VoteExtensionsEnabled(first.Height) { + // if vote extensions were required at this height, ensure they exist. + err = extCommit.EnsureExtensions() + } // If either of the checks failed we log the error and request for a new block // at that height if err != nil { @@ -593,7 +626,15 @@ func (r *Reactor) poolRoutine(ctx context.Context, stateSynced bool, blockSyncCh r.pool.PopRequest() // TODO: batch saves so we do not persist to disk every block - r.store.SaveBlock(first, firstParts, extCommit) + if state.ConsensusParams.ABCI.VoteExtensionsEnabled(first.Height) { + r.store.SaveBlockWithExtendedCommit(first, firstParts, extCommit) + } else { + // We use LastCommit here instead of extCommit. extCommit is not + // guaranteed to be populated by the peer if extensions are not enabled. + // Currently, the peer should provide an extCommit even if the vote extension data are absent + // but this may change so using second.LastCommit is safer. + r.store.SaveBlock(first, firstParts, second.LastCommit) + } // TODO: Same thing for app - but we would need a way to get the hash // without persisting the state. diff --git a/internal/blocksync/reactor_test.go b/internal/blocksync/reactor_test.go index 1d4d7d4d6..0477eb45d 100644 --- a/internal/blocksync/reactor_test.go +++ b/internal/blocksync/reactor_test.go @@ -40,8 +40,6 @@ type reactorTestSuite struct { blockSyncChannels map[types.NodeID]*p2p.Channel peerChans map[types.NodeID]chan p2p.PeerUpdate peerUpdates map[types.NodeID]*p2p.PeerUpdates - - blockSync bool } func setup( @@ -69,7 +67,6 @@ func setup( blockSyncChannels: make(map[types.NodeID]*p2p.Channel, numNodes), peerChans: make(map[types.NodeID]chan p2p.PeerUpdate, numNodes), peerUpdates: make(map[types.NodeID]*p2p.PeerUpdates, numNodes), - blockSync: true, } chDesc := &p2p.ChannelDescriptor{ID: BlockSyncChannel, MessageType: new(bcproto.Message)} @@ -97,21 +94,19 @@ func setup( return rts } -func (rts *reactorTestSuite) addNode( +func makeReactor( ctx context.Context, t *testing.T, nodeID types.NodeID, genDoc *types.GenesisDoc, privVal types.PrivValidator, - maxBlockHeight int64, -) { - t.Helper() + channelCreator p2p.ChannelCreator, + peerEvents p2p.PeerEventSubscriber) *Reactor { logger := log.NewNopLogger() - rts.nodes = append(rts.nodes, nodeID) - rts.app[nodeID] = proxy.New(abciclient.NewLocalClient(logger, &abci.BaseApplication{}), logger, proxy.NopMetrics()) - require.NoError(t, rts.app[nodeID].Start(ctx)) + app := proxy.New(abciclient.NewLocalClient(logger, &abci.BaseApplication{}), logger, proxy.NopMetrics()) + require.NoError(t, app.Start(ctx)) blockDB := dbm.NewMemDB() stateDB := dbm.NewMemDB() @@ -139,7 +134,7 @@ func (rts *reactorTestSuite) addNode( blockExec := sm.NewBlockExecutor( stateStore, log.NewNopLogger(), - rts.app[nodeID], + app, mp, sm.EmptyEvidencePool{}, blockStore, @@ -147,44 +142,35 @@ func (rts *reactorTestSuite) addNode( sm.NopMetrics(), ) - var lastExtCommit *types.ExtendedCommit + return NewReactor( + logger, + stateStore, + blockExec, + blockStore, + nil, + channelCreator, + peerEvents, + true, + consensus.NopMetrics(), + nil, // eventbus, can be nil + ) +} - // The commit we are building for the current height. - seenExtCommit := &types.ExtendedCommit{} +func (rts *reactorTestSuite) addNode( + ctx context.Context, + t *testing.T, + nodeID types.NodeID, + genDoc *types.GenesisDoc, + privVal types.PrivValidator, + maxBlockHeight int64, +) { + t.Helper() - for blockHeight := int64(1); blockHeight <= maxBlockHeight; blockHeight++ { - lastExtCommit = seenExtCommit.Clone() + logger := log.NewNopLogger() - thisBlock := sf.MakeBlock(state, blockHeight, lastExtCommit.StripExtensions()) - thisParts, err := thisBlock.MakePartSet(types.BlockPartSizeBytes) - require.NoError(t, err) - blockID := types.BlockID{Hash: thisBlock.Hash(), PartSetHeader: thisParts.Header()} - - // Simulate a commit for the current height - vote, err := factory.MakeVote( - ctx, - privVal, - thisBlock.Header.ChainID, - 0, - thisBlock.Header.Height, - 0, - 2, - blockID, - time.Now(), - ) - require.NoError(t, err) - seenExtCommit = &types.ExtendedCommit{ - Height: vote.Height, - Round: vote.Round, - BlockID: blockID, - ExtendedSignatures: []types.ExtendedCommitSig{vote.ExtendedCommitSig()}, - } - - state, err = blockExec.ApplyBlock(ctx, state, blockID, thisBlock) - require.NoError(t, err) - - blockStore.SaveBlock(thisBlock, thisParts, seenExtCommit) - } + rts.nodes = append(rts.nodes, nodeID) + rts.app[nodeID] = proxy.New(abciclient.NewLocalClient(logger, &abci.BaseApplication{}), logger, proxy.NopMetrics()) + require.NoError(t, rts.app[nodeID].Start(ctx)) rts.peerChans[nodeID] = make(chan p2p.PeerUpdate) rts.peerUpdates[nodeID] = p2p.NewPeerUpdates(rts.peerChans[nodeID], 1) @@ -193,21 +179,64 @@ func (rts *reactorTestSuite) addNode( chCreator := func(ctx context.Context, chdesc *p2p.ChannelDescriptor) (*p2p.Channel, error) { return rts.blockSyncChannels[nodeID], nil } - rts.reactors[nodeID] = NewReactor( - rts.logger.With("nodeID", nodeID), - stateStore, - blockExec, - blockStore, - nil, - chCreator, - func(ctx context.Context) *p2p.PeerUpdates { return rts.peerUpdates[nodeID] }, - rts.blockSync, - consensus.NopMetrics(), - nil, // eventbus, can be nil - ) - require.NoError(t, rts.reactors[nodeID].Start(ctx)) - require.True(t, rts.reactors[nodeID].IsRunning()) + peerEvents := func(ctx context.Context) *p2p.PeerUpdates { return rts.peerUpdates[nodeID] } + reactor := makeReactor(ctx, t, nodeID, genDoc, privVal, chCreator, peerEvents) + + lastExtCommit := &types.ExtendedCommit{} + + state, err := reactor.stateStore.Load() + require.NoError(t, err) + for blockHeight := int64(1); blockHeight <= maxBlockHeight; blockHeight++ { + block, blockID, partSet, seenExtCommit := makeNextBlock(ctx, t, state, privVal, blockHeight, lastExtCommit) + + state, err = reactor.blockExec.ApplyBlock(ctx, state, blockID, block) + require.NoError(t, err) + + reactor.store.SaveBlockWithExtendedCommit(block, partSet, seenExtCommit) + lastExtCommit = seenExtCommit + } + + rts.reactors[nodeID] = reactor + require.NoError(t, reactor.Start(ctx)) + require.True(t, reactor.IsRunning()) +} + +func makeNextBlock(ctx context.Context, + t *testing.T, + state sm.State, + signer types.PrivValidator, + height int64, + lc *types.ExtendedCommit) (*types.Block, types.BlockID, *types.PartSet, *types.ExtendedCommit) { + + lastExtCommit := lc.Clone() + + block := sf.MakeBlock(state, height, lastExtCommit.ToCommit()) + partSet, err := block.MakePartSet(types.BlockPartSizeBytes) + require.NoError(t, err) + blockID := types.BlockID{Hash: block.Hash(), PartSetHeader: partSet.Header()} + + // Simulate a commit for the current height + vote, err := factory.MakeVote( + ctx, + signer, + block.Header.ChainID, + 0, + block.Header.Height, + 0, + 2, + blockID, + time.Now(), + ) + require.NoError(t, err) + seenExtCommit := &types.ExtendedCommit{ + Height: vote.Height, + Round: vote.Round, + BlockID: blockID, + ExtendedSignatures: []types.ExtendedCommitSig{vote.ExtendedCommitSig()}, + } + return block, blockID, partSet, seenExtCommit + } func (rts *reactorTestSuite) start(ctx context.Context, t *testing.T) { @@ -415,3 +444,35 @@ func TestReactor_BadBlockStopsPeer(t *testing.T) { len(rts.reactors[newNode.NodeID].pool.peers), ) } + +/* +func TestReactorReceivesNoExtendedCommit(t *testing.T) { + blockDB := dbm.NewMemDB() + stateDB := dbm.NewMemDB() + stateStore := sm.NewStore(stateDB) + blockStore := store.NewBlockStore(blockDB) + blockExec := sm.NewBlockExecutor( + stateStore, + log.NewNopLogger(), + rts.app[nodeID], + mp, + sm.EmptyEvidencePool{}, + blockStore, + eventbus, + sm.NopMetrics(), + ) + NewReactor( + log.NewNopLogger(), + stateStore, + blockExec, + blockStore, + nil, + chCreator, + func(ctx context.Context) *p2p.PeerUpdates { return rts.peerUpdates[nodeID] }, + rts.blockSync, + consensus.NopMetrics(), + nil, // eventbus, can be nil + ) + +} +*/ diff --git a/internal/consensus/common_test.go b/internal/consensus/common_test.go index 1dc92b33c..27fc39d6b 100644 --- a/internal/consensus/common_test.go +++ b/internal/consensus/common_test.go @@ -527,10 +527,11 @@ func loadPrivValidator(t *testing.T, cfg *config.Config) *privval.FilePV { } type makeStateArgs struct { - config *config.Config - logger log.Logger - validators int - application abci.Application + config *config.Config + consensusParams *types.ConsensusParams + logger log.Logger + validators int + application abci.Application } func makeState(ctx context.Context, t *testing.T, args makeStateArgs) (*State, []*validatorStub) { @@ -551,9 +552,13 @@ func makeState(ctx context.Context, t *testing.T, args makeStateArgs) (*State, [ if args.logger == nil { args.logger = log.NewNopLogger() } + c := factory.ConsensusParams() + if args.consensusParams != nil { + c = args.consensusParams + } state, privVals := makeGenesisState(ctx, t, args.config, genesisStateArgs{ - Params: factory.ConsensusParams(), + Params: c, Validators: validators, }) diff --git a/internal/consensus/msgs.go b/internal/consensus/msgs.go index c59c06a41..1024c24ae 100644 --- a/internal/consensus/msgs.go +++ b/internal/consensus/msgs.go @@ -222,11 +222,7 @@ func (*VoteMessage) TypeTag() string { return "tendermint/Vote" } // ValidateBasic checks whether the vote within the message is well-formed. func (m *VoteMessage) ValidateBasic() error { - // Here we validate votes with vote extensions, since we require vote - // extensions to be sent in precommit messages during consensus. Prevote - // messages should never have vote extensions, and this is also validated - // here. - return m.Vote.ValidateWithExtension() + return m.Vote.ValidateBasic() } // String returns a string representation. diff --git a/internal/consensus/reactor.go b/internal/consensus/reactor.go index 501523339..18d5851a4 100644 --- a/internal/consensus/reactor.go +++ b/internal/consensus/reactor.go @@ -798,13 +798,20 @@ func (r *Reactor) gossipVotesRoutine(ctx context.Context, ps *PeerState, voteCh if blockStoreBase > 0 && prs.Height != 0 && rs.Height >= prs.Height+2 && prs.Height >= blockStoreBase { // Load the block's extended commit for prs.Height, which contains precommit // signatures for prs.Height. - if ec := r.state.blockStore.LoadBlockExtendedCommit(prs.Height); ec != nil { - if ok, err := r.pickSendVote(ctx, ps, ec, voteCh); err != nil { - return - } else if ok { - logger.Debug("picked Catchup commit to send", "height", prs.Height) - continue - } + var ec *types.ExtendedCommit + if r.state.state.ConsensusParams.ABCI.VoteExtensionsEnabled(prs.Height) { + ec = r.state.blockStore.LoadBlockExtendedCommit(prs.Height) + } else { + ec = r.state.blockStore.LoadBlockCommit(prs.Height).WrappedExtendedCommit() + } + if ec == nil { + continue + } + if ok, err := r.pickSendVote(ctx, ps, ec, voteCh); err != nil { + return + } else if ok { + logger.Debug("picked Catchup commit to send", "height", prs.Height) + continue } } diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index c6a8869db..96cf800bd 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -32,6 +32,7 @@ import ( "github.com/tendermint/tendermint/internal/test/factory" "github.com/tendermint/tendermint/libs/log" tmcons "github.com/tendermint/tendermint/proto/tendermint/consensus" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" "github.com/tendermint/tendermint/types" ) @@ -600,6 +601,118 @@ func TestReactorCreatesBlockWhenEmptyBlocksFalse(t *testing.T) { wg.Wait() } +// TestSwitchToConsensusVoteExtensions tests that the SwitchToConsensus correctly +// checks for vote extension data when required. +func TestSwitchToConsensusVoteExtensions(t *testing.T) { + for _, testCase := range []struct { + name string + storedHeight int64 + initialRequiredHeight int64 + includeExtensions bool + shouldPanic bool + }{ + { + name: "no vote extensions but not required", + initialRequiredHeight: 0, + storedHeight: 2, + includeExtensions: false, + shouldPanic: false, + }, + { + name: "no vote extensions but required this height", + initialRequiredHeight: 2, + storedHeight: 2, + includeExtensions: false, + shouldPanic: true, + }, + { + name: "no vote extensions and required in future", + initialRequiredHeight: 3, + storedHeight: 2, + includeExtensions: false, + shouldPanic: false, + }, + { + name: "no vote extensions and required previous height", + initialRequiredHeight: 1, + storedHeight: 2, + includeExtensions: false, + shouldPanic: true, + }, + { + name: "vote extensions and required previous height", + initialRequiredHeight: 1, + storedHeight: 2, + includeExtensions: true, + shouldPanic: false, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + cs, vs := makeState(ctx, t, makeStateArgs{validators: 1}) + validator := vs[0] + validator.Height = testCase.storedHeight + + cs.state.LastBlockHeight = testCase.storedHeight + cs.state.LastValidators = cs.state.Validators.Copy() + cs.state.ConsensusParams.ABCI.VoteExtensionsEnableHeight = testCase.initialRequiredHeight + + propBlock, err := cs.createProposalBlock(ctx) + require.NoError(t, err) + + // Consensus is preparing to do the next height after the stored height. + cs.Height = testCase.storedHeight + 1 + propBlock.Height = testCase.storedHeight + blockParts, err := propBlock.MakePartSet(types.BlockPartSizeBytes) + require.NoError(t, err) + + var voteSet *types.VoteSet + if testCase.includeExtensions { + voteSet = types.NewExtendedVoteSet(cs.state.ChainID, testCase.storedHeight, 0, tmproto.PrecommitType, cs.state.Validators) + } else { + voteSet = types.NewVoteSet(cs.state.ChainID, testCase.storedHeight, 0, tmproto.PrecommitType, cs.state.Validators) + } + signedVote := signVote(ctx, t, validator, tmproto.PrecommitType, cs.state.ChainID, types.BlockID{ + Hash: propBlock.Hash(), + PartSetHeader: blockParts.Header(), + }) + + if !testCase.includeExtensions { + signedVote.Extension = nil + signedVote.ExtensionSignature = nil + } + + added, err := voteSet.AddVote(signedVote) + require.NoError(t, err) + require.True(t, added) + + if testCase.includeExtensions { + cs.blockStore.SaveBlockWithExtendedCommit(propBlock, blockParts, voteSet.MakeExtendedCommit()) + } else { + cs.blockStore.SaveBlock(propBlock, blockParts, voteSet.MakeExtendedCommit().ToCommit()) + } + reactor := NewReactor( + log.NewNopLogger(), + cs, + nil, + nil, + cs.eventBus, + true, + NopMetrics(), + ) + + if testCase.shouldPanic { + assert.Panics(t, func() { + reactor.SwitchToConsensus(ctx, cs.state, false) + }) + } else { + reactor.SwitchToConsensus(ctx, cs.state, false) + } + }) + } +} + func TestReactorRecordsVotesAndBlockParts(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() diff --git a/internal/consensus/replay_test.go b/internal/consensus/replay_test.go index c8f04655b..99d3c17a1 100644 --- a/internal/consensus/replay_test.go +++ b/internal/consensus/replay_test.go @@ -1204,14 +1204,16 @@ func (bs *mockBlockStore) LoadBlockMeta(height int64) *types.BlockMeta { } } func (bs *mockBlockStore) LoadBlockPart(height int64, index int) *types.Part { return nil } -func (bs *mockBlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) { +func (bs *mockBlockStore) SaveBlockWithExtendedCommit(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) { +} +func (bs *mockBlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.Commit) { } func (bs *mockBlockStore) LoadBlockCommit(height int64) *types.Commit { - return bs.extCommits[height-1].StripExtensions() + return bs.extCommits[height-1].ToCommit() } func (bs *mockBlockStore) LoadSeenCommit() *types.Commit { - return bs.extCommits[len(bs.extCommits)-1].StripExtensions() + return bs.extCommits[len(bs.extCommits)-1].ToCommit() } func (bs *mockBlockStore) LoadBlockExtendedCommit(height int64) *types.ExtendedCommit { return bs.extCommits[height-1] diff --git a/internal/consensus/state.go b/internal/consensus/state.go index 320042d30..3af775bb6 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -696,23 +696,54 @@ func (cs *State) sendInternalMessage(ctx context.Context, mi msgInfo) { } } -// Reconstruct LastCommit from SeenCommit, which we saved along with the block, -// (which happens even before saving the state) +// Reconstruct the LastCommit from either SeenCommit or the ExtendedCommit. SeenCommit +// and ExtendedCommit are saved along with the block. If VoteExtensions are required +// the method will panic on an absent ExtendedCommit or an ExtendedCommit without +// extension data. func (cs *State) reconstructLastCommit(state sm.State) { - extCommit := cs.blockStore.LoadBlockExtendedCommit(state.LastBlockHeight) - if extCommit == nil { - panic(fmt.Sprintf( - "failed to reconstruct last commit; commit for height %v not found", - state.LastBlockHeight, - )) + extensionsEnabled := cs.state.ConsensusParams.ABCI.VoteExtensionsEnabled(state.LastBlockHeight) + if !extensionsEnabled { + votes, err := cs.votesFromSeenCommit(state) + if err != nil { + panic(fmt.Sprintf("failed to reconstruct last commit; %s", err)) + } + cs.LastCommit = votes + return } - lastPrecommits := extCommit.ToVoteSet(state.ChainID, state.LastValidators) - if !lastPrecommits.HasTwoThirdsMajority() { - panic("failed to reconstruct last commit; does not have +2/3 maj") + votes, err := cs.votesFromExtendedCommit(state) + if err != nil { + panic(fmt.Sprintf("failed to reconstruct last extended commit; %s", err)) + } + cs.LastCommit = votes +} + +func (cs *State) votesFromExtendedCommit(state sm.State) (*types.VoteSet, error) { + ec := cs.blockStore.LoadBlockExtendedCommit(state.LastBlockHeight) + if ec == nil { + return nil, fmt.Errorf("extended commit for height %v not found", state.LastBlockHeight) + } + vs := ec.ToExtendedVoteSet(state.ChainID, state.LastValidators) + if !vs.HasTwoThirdsMajority() { + return nil, errors.New("extended commit does not have +2/3 majority") + } + return vs, nil +} + +func (cs *State) votesFromSeenCommit(state sm.State) (*types.VoteSet, error) { + commit := cs.blockStore.LoadSeenCommit() + if commit == nil || commit.Height != state.LastBlockHeight { + commit = cs.blockStore.LoadBlockCommit(state.LastBlockHeight) + } + if commit == nil { + return nil, fmt.Errorf("commit for height %v not found", state.LastBlockHeight) } - cs.LastCommit = lastPrecommits + vs := commit.ToVoteSet(state.ChainID, state.LastValidators) + if !vs.HasTwoThirdsMajority() { + return nil, errors.New("commit does not have +2/3 majority") + } + return vs, nil } // Updates State and increments height to match that of state. @@ -814,7 +845,11 @@ func (cs *State) updateToState(state sm.State) { cs.ValidRound = -1 cs.ValidBlock = nil cs.ValidBlockParts = nil - cs.Votes = cstypes.NewHeightVoteSet(state.ChainID, height, validators) + if state.ConsensusParams.ABCI.VoteExtensionsEnabled(height) { + cs.Votes = cstypes.NewExtendedHeightVoteSet(state.ChainID, height, validators) + } else { + cs.Votes = cstypes.NewHeightVoteSet(state.ChainID, height, validators) + } cs.CommitRound = -1 cs.LastValidators = state.LastValidators cs.TriggeredTimeoutPrecommit = false @@ -1925,8 +1960,12 @@ func (cs *State) finalizeCommit(ctx context.Context, height int64) { if cs.blockStore.Height() < block.Height { // NOTE: the seenCommit is local justification to commit this block, // but may differ from the LastCommit included in the next block - precommits := cs.Votes.Precommits(cs.CommitRound) - cs.blockStore.SaveBlock(block, blockParts, precommits.MakeExtendedCommit()) + seenExtendedCommit := cs.Votes.Precommits(cs.CommitRound).MakeExtendedCommit() + if cs.state.ConsensusParams.ABCI.VoteExtensionsEnabled(block.Height) { + cs.blockStore.SaveBlockWithExtendedCommit(block, blockParts, seenExtendedCommit) + } else { + cs.blockStore.SaveBlock(block, blockParts, seenExtendedCommit.ToCommit()) + } } else { // Happens during replay if we already saved the block but didn't commit logger.Debug("calling finalizeCommit on already stored block", "height", block.Height) @@ -2341,13 +2380,45 @@ func (cs *State) addVote( return } - // Verify VoteExtension if precommit and not nil - // https://github.com/tendermint/tendermint/issues/8487 - if vote.Type == tmproto.PrecommitType && !vote.BlockID.IsNil() { - err := cs.blockExec.VerifyVoteExtension(ctx, vote) - cs.metrics.MarkVoteExtensionReceived(err == nil) - if err != nil { - return false, err + // Check to see if the chain is configured to extend votes. + if cs.state.ConsensusParams.ABCI.VoteExtensionsEnabled(cs.Height) { + // The chain is configured to extend votes, check that the vote is + // not for a nil block and verify the extensions signature against the + // corresponding public key. + + var myAddr []byte + if cs.privValidatorPubKey != nil { + myAddr = cs.privValidatorPubKey.Address() + } + // Verify VoteExtension if precommit and not nil + // https://github.com/tendermint/tendermint/issues/8487 + if vote.Type == tmproto.PrecommitType && !vote.BlockID.IsNil() && + !bytes.Equal(vote.ValidatorAddress, myAddr) { // Skip the VerifyVoteExtension call if the vote was issued by this validator. + + // The core fields of the vote message were already validated in the + // consensus reactor when the vote was received. + // Here, we verify the signature of the vote extension included in the vote + // message. + _, val := cs.state.Validators.GetByIndex(vote.ValidatorIndex) + if err := vote.VerifyExtension(cs.state.ChainID, val.PubKey); err != nil { + return false, err + } + + err := cs.blockExec.VerifyVoteExtension(ctx, vote) + cs.metrics.MarkVoteExtensionReceived(err == nil) + if err != nil { + return false, err + } + } + } else { + // Vote extensions are not enabled on the network. + // strip the extension data from the vote in case any is present. + // + // TODO punish a peer if it sent a vote with an extension when the feature + // is disabled on the network. + // https://github.com/tendermint/tendermint/issues/8565 + if stripped := vote.StripExtension(); stripped { + cs.logger.Error("vote included extension data but vote extensions are not enabled", "peer", peerID) } } @@ -2496,18 +2567,18 @@ func (cs *State) signVote( // If the signedMessageType is for precommit, // use our local precommit Timeout as the max wait time for getting a singed commit. The same goes for prevote. - timeout := cs.voteTimeout(cs.Round) - + timeout := time.Second if msgType == tmproto.PrecommitType && !vote.BlockID.IsNil() { + timeout = cs.voteTimeout(cs.Round) // if the signedMessage type is for a non-nil precommit, add // VoteExtension - ext, err := cs.blockExec.ExtendVote(ctx, vote) - if err != nil { - return nil, err + if cs.state.ConsensusParams.ABCI.VoteExtensionsEnabled(cs.Height) { + ext, err := cs.blockExec.ExtendVote(ctx, vote) + if err != nil { + return nil, err + } + vote.Extension = ext } - vote.Extension = ext - } else { - timeout = time.Second } v := vote.ToProto() @@ -2547,14 +2618,17 @@ func (cs *State) signAddVote( // TODO: pass pubKey to signVote vote, err := cs.signVote(ctx, msgType, hash, header) - if err == nil { - cs.sendInternalMessage(ctx, msgInfo{&VoteMessage{vote}, "", tmtime.Now()}) - cs.logger.Debug("signed and pushed vote", "height", cs.Height, "round", cs.Round, "vote", vote) - return vote + if err != nil { + cs.logger.Error("failed signing vote", "height", cs.Height, "round", cs.Round, "vote", vote, "err", err) + return nil } - - cs.logger.Error("failed signing vote", "height", cs.Height, "round", cs.Round, "vote", vote, "err", err) - return nil + if !cs.state.ConsensusParams.ABCI.VoteExtensionsEnabled(vote.Height) { + // The signer will sign the extension, make sure to remove the data on the way out + vote.StripExtension() + } + cs.sendInternalMessage(ctx, msgInfo{&VoteMessage{vote}, "", tmtime.Now()}) + cs.logger.Debug("signed and pushed vote", "height", cs.Height, "round", cs.Round, "vote", vote) + return vote } // updatePrivValidatorPubKey get's the private validator public key and diff --git a/internal/consensus/state_test.go b/internal/consensus/state_test.go index 6fa69a1a3..797dcf59c 100644 --- a/internal/consensus/state_test.go +++ b/internal/consensus/state_test.go @@ -18,6 +18,7 @@ import ( "github.com/tendermint/tendermint/internal/eventbus" tmpubsub "github.com/tendermint/tendermint/internal/pubsub" tmquery "github.com/tendermint/tendermint/internal/pubsub/query" + "github.com/tendermint/tendermint/internal/test/factory" tmbytes "github.com/tendermint/tendermint/libs/bytes" "github.com/tendermint/tendermint/libs/log" tmrand "github.com/tendermint/tendermint/libs/rand" @@ -2026,77 +2027,101 @@ func TestFinalizeBlockCalled(t *testing.T) { } } -// TestExtendVoteCalled tests that the vote extension methods are called at the -// correct point in the consensus algorithm. -func TestExtendVoteCalled(t *testing.T) { - config := configSetup(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +// TestExtendVoteCalledWhenEnabled tests that the vote extension methods are called at the +// correct point in the consensus algorithm when vote extensions are enabled. +func TestExtendVoteCalledWhenEnabled(t *testing.T) { + for _, testCase := range []struct { + name string + enabled bool + }{ + { + name: "enabled", + enabled: true, + }, + { + name: "disabled", + enabled: false, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - m := abcimocks.NewApplication(t) - m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_ACCEPT}, nil) - m.On("PrepareProposal", mock.Anything, mock.Anything).Return(&abci.ResponsePrepareProposal{}, nil) - m.On("ExtendVote", mock.Anything, mock.Anything).Return(&abci.ResponseExtendVote{ - VoteExtension: []byte("extension"), - }, nil) - m.On("VerifyVoteExtension", mock.Anything, mock.Anything).Return(&abci.ResponseVerifyVoteExtension{ - Status: abci.ResponseVerifyVoteExtension_ACCEPT, - }, nil) - m.On("Commit", mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() - m.On("FinalizeBlock", mock.Anything, mock.Anything).Return(&abci.ResponseFinalizeBlock{}, nil).Maybe() - cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m}) - height, round := cs1.Height, cs1.Round + m := abcimocks.NewApplication(t) + m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_ACCEPT}, nil) + m.On("PrepareProposal", mock.Anything, mock.Anything).Return(&abci.ResponsePrepareProposal{}, nil) + if testCase.enabled { + m.On("ExtendVote", mock.Anything, mock.Anything).Return(&abci.ResponseExtendVote{ + VoteExtension: []byte("extension"), + }, nil) + m.On("VerifyVoteExtension", mock.Anything, mock.Anything).Return(&abci.ResponseVerifyVoteExtension{ + Status: abci.ResponseVerifyVoteExtension_ACCEPT, + }, nil) + } + m.On("Commit", mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() + m.On("FinalizeBlock", mock.Anything, mock.Anything).Return(&abci.ResponseFinalizeBlock{}, nil).Maybe() + c := factory.ConsensusParams() + if !testCase.enabled { + c.ABCI.VoteExtensionsEnableHeight = 0 + } + cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m, consensusParams: c}) + height, round := cs1.Height, cs1.Round - proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) - newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) - pv1, err := cs1.privValidator.GetPubKey(ctx) - require.NoError(t, err) - addr := pv1.Address() - voteCh := subscribeToVoter(ctx, t, cs1, addr) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey(ctx) + require.NoError(t, err) + addr := pv1.Address() + voteCh := subscribeToVoter(ctx, t, cs1, addr) - startTestRound(ctx, cs1, cs1.Height, round) - ensureNewRound(t, newRoundCh, height, round) - ensureNewProposal(t, proposalCh, height, round) + startTestRound(ctx, cs1, cs1.Height, round) + ensureNewRound(t, newRoundCh, height, round) + ensureNewProposal(t, proposalCh, height, round) - m.AssertNotCalled(t, "ExtendVote", mock.Anything, mock.Anything) + m.AssertNotCalled(t, "ExtendVote", mock.Anything, mock.Anything) - rs := cs1.GetRoundState() + rs := cs1.GetRoundState() - blockID := types.BlockID{ - Hash: rs.ProposalBlock.Hash(), - PartSetHeader: rs.ProposalBlockParts.Header(), - } - signAddVotes(ctx, t, cs1, tmproto.PrevoteType, config.ChainID(), blockID, vss[1:]...) - ensurePrevoteMatch(t, voteCh, height, round, blockID.Hash) + blockID := types.BlockID{ + Hash: rs.ProposalBlock.Hash(), + PartSetHeader: rs.ProposalBlockParts.Header(), + } + signAddVotes(ctx, t, cs1, tmproto.PrevoteType, config.ChainID(), blockID, vss[1:]...) + ensurePrevoteMatch(t, voteCh, height, round, blockID.Hash) - ensurePrecommit(t, voteCh, height, round) + ensurePrecommit(t, voteCh, height, round) - m.AssertCalled(t, "ExtendVote", ctx, &abci.RequestExtendVote{ - Height: height, - Hash: blockID.Hash, - }) + if testCase.enabled { + m.AssertCalled(t, "ExtendVote", ctx, &abci.RequestExtendVote{ + Height: height, + Hash: blockID.Hash, + }) + } else { + m.AssertNotCalled(t, "ExtendVote", mock.Anything, mock.Anything) + } - m.AssertCalled(t, "VerifyVoteExtension", ctx, &abci.RequestVerifyVoteExtension{ - Hash: blockID.Hash, - ValidatorAddress: addr, - Height: height, - VoteExtension: []byte("extension"), - }) - signAddVotes(ctx, t, cs1, tmproto.PrecommitType, config.ChainID(), blockID, vss[1:]...) - ensureNewRound(t, newRoundCh, height+1, 0) - m.AssertExpectations(t) + signAddVotes(ctx, t, cs1, tmproto.PrecommitType, config.ChainID(), blockID, vss[1:]...) + ensureNewRound(t, newRoundCh, height+1, 0) + m.AssertExpectations(t) - // Only 3 of the vote extensions are seen, as consensus proceeds as soon as the +2/3 threshold - // is observed by the consensus engine. - for _, pv := range vss[:3] { - pv, err := pv.GetPubKey(ctx) - require.NoError(t, err) - addr := pv.Address() - m.AssertCalled(t, "VerifyVoteExtension", ctx, &abci.RequestVerifyVoteExtension{ - Hash: blockID.Hash, - ValidatorAddress: addr, - Height: height, - VoteExtension: []byte("extension"), + // Only 3 of the vote extensions are seen, as consensus proceeds as soon as the +2/3 threshold + // is observed by the consensus engine. + for _, pv := range vss[1:3] { + pv, err := pv.GetPubKey(ctx) + require.NoError(t, err) + addr := pv.Address() + if testCase.enabled { + m.AssertCalled(t, "VerifyVoteExtension", ctx, &abci.RequestVerifyVoteExtension{ + Hash: blockID.Hash, + ValidatorAddress: addr, + Height: height, + VoteExtension: []byte("extension"), + }) + } else { + m.AssertNotCalled(t, "VerifyVoteExtension", mock.Anything, mock.Anything) + } + } }) } @@ -2121,6 +2146,7 @@ func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { m.On("FinalizeBlock", mock.Anything, mock.Anything).Return(&abci.ResponseFinalizeBlock{}, nil).Maybe() cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m}) height, round := cs1.Height, cs1.Round + cs1.state.ConsensusParams.ABCI.VoteExtensionsEnableHeight = cs1.Height proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) @@ -2138,7 +2164,7 @@ func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { Hash: rs.ProposalBlock.Hash(), PartSetHeader: rs.ProposalBlockParts.Header(), } - signAddVotes(ctx, t, cs1, tmproto.PrevoteType, config.ChainID(), blockID, vss[2:]...) + signAddVotes(ctx, t, cs1, tmproto.PrevoteType, config.ChainID(), blockID, vss...) ensurePrevoteMatch(t, voteCh, height, round, blockID.Hash) ensurePrecommit(t, voteCh, height, round) @@ -2148,13 +2174,6 @@ func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { Hash: blockID.Hash, }) - m.AssertCalled(t, "VerifyVoteExtension", mock.Anything, &abci.RequestVerifyVoteExtension{ - Hash: blockID.Hash, - ValidatorAddress: addr, - Height: height, - VoteExtension: []byte("extension"), - }) - m.On("Commit", mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() signAddVotes(ctx, t, cs1, tmproto.PrecommitType, config.ChainID(), blockID, vss[2:]...) ensureNewRound(t, newRoundCh, height+1, 0) @@ -2266,6 +2285,134 @@ func TestPrepareProposalReceivesVoteExtensions(t *testing.T) { } } +// TestVoteExtensionEnableHeight tests that 'ExtensionRequireHeight' correctly +// enforces that vote extensions be present in consensus for heights greater than +// or equal to the configured value. +func TestVoteExtensionEnableHeight(t *testing.T) { + for _, testCase := range []struct { + name string + enableHeight int64 + hasExtension bool + expectExtendCalled bool + expectVerifyCalled bool + expectSuccessfulRound bool + }{ + { + name: "extension present but not enabled", + hasExtension: true, + enableHeight: 0, + expectExtendCalled: false, + expectVerifyCalled: false, + expectSuccessfulRound: true, + }, + { + name: "extension absent but not required", + hasExtension: false, + enableHeight: 0, + expectExtendCalled: false, + expectVerifyCalled: false, + expectSuccessfulRound: true, + }, + { + name: "extension present and required", + hasExtension: true, + enableHeight: 1, + expectExtendCalled: true, + expectVerifyCalled: true, + expectSuccessfulRound: true, + }, + { + name: "extension absent but required", + hasExtension: false, + enableHeight: 1, + expectExtendCalled: true, + expectVerifyCalled: false, + expectSuccessfulRound: false, + }, + { + name: "extension absent but required in future height", + hasExtension: false, + enableHeight: 2, + expectExtendCalled: false, + expectVerifyCalled: false, + expectSuccessfulRound: true, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + config := configSetup(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + numValidators := 3 + m := abcimocks.NewApplication(t) + m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{ + Status: abci.ResponseProcessProposal_ACCEPT, + }, nil) + m.On("PrepareProposal", mock.Anything, mock.Anything).Return(&abci.ResponsePrepareProposal{}, nil) + if testCase.expectExtendCalled { + m.On("ExtendVote", mock.Anything, mock.Anything).Return(&abci.ResponseExtendVote{}, nil) + } + if testCase.expectVerifyCalled { + m.On("VerifyVoteExtension", mock.Anything, mock.Anything).Return(&abci.ResponseVerifyVoteExtension{ + Status: abci.ResponseVerifyVoteExtension_ACCEPT, + }, nil).Times(numValidators - 1) + } + m.On("FinalizeBlock", mock.Anything, mock.Anything).Return(&abci.ResponseFinalizeBlock{}, nil).Maybe() + m.On("Commit", mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() + c := factory.ConsensusParams() + c.ABCI.VoteExtensionsEnableHeight = testCase.enableHeight + cs1, vss := makeState(ctx, t, makeStateArgs{config: config, application: m, validators: numValidators, consensusParams: c}) + cs1.state.ConsensusParams.ABCI.VoteExtensionsEnableHeight = testCase.enableHeight + height, round := cs1.Height, cs1.Round + + timeoutCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryTimeoutPropose) + proposalCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(ctx, t, cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey(ctx) + require.NoError(t, err) + addr := pv1.Address() + voteCh := subscribeToVoter(ctx, t, cs1, addr) + + startTestRound(ctx, cs1, cs1.Height, round) + ensureNewRound(t, newRoundCh, height, round) + ensureNewProposal(t, proposalCh, height, round) + rs := cs1.GetRoundState() + + blockID := types.BlockID{ + Hash: rs.ProposalBlock.Hash(), + PartSetHeader: rs.ProposalBlockParts.Header(), + } + + // sign all of the votes + signAddVotes(ctx, t, cs1, tmproto.PrevoteType, config.ChainID(), blockID, vss[1:]...) + ensurePrevoteMatch(t, voteCh, height, round, rs.ProposalBlock.Hash()) + + var ext []byte + if testCase.hasExtension { + ext = []byte("extension") + } + + for _, vs := range vss[1:] { + vote, err := vs.signVote(ctx, tmproto.PrecommitType, config.ChainID(), blockID, ext) + if !testCase.hasExtension { + vote.ExtensionSignature = nil + } + require.NoError(t, err) + addVotes(cs1, vote) + } + if testCase.expectSuccessfulRound { + ensurePrecommit(t, voteCh, height, round) + height++ + ensureNewRound(t, newRoundCh, height, round) + } else { + ensureNoNewTimeout(t, timeoutCh, cs1.state.ConsensusParams.Timeout.VoteTimeout(round).Nanoseconds()) + } + + m.AssertExpectations(t) + }) + } +} + // 4 vals, 3 Nil Precommits at P0 // What we want: // P0 waits for timeoutPrecommit before starting next round diff --git a/internal/consensus/types/height_vote_set.go b/internal/consensus/types/height_vote_set.go index 661c5120e..389c02356 100644 --- a/internal/consensus/types/height_vote_set.go +++ b/internal/consensus/types/height_vote_set.go @@ -38,9 +38,10 @@ We let each peer provide us with up to 2 unexpected "catchup" rounds. One for their LastCommit round, and another for the official commit round. */ type HeightVoteSet struct { - chainID string - height int64 - valSet *types.ValidatorSet + chainID string + height int64 + valSet *types.ValidatorSet + extensionsEnabled bool mtx sync.Mutex round int32 // max tracked round @@ -50,7 +51,17 @@ type HeightVoteSet struct { func NewHeightVoteSet(chainID string, height int64, valSet *types.ValidatorSet) *HeightVoteSet { hvs := &HeightVoteSet{ - chainID: chainID, + chainID: chainID, + extensionsEnabled: false, + } + hvs.Reset(height, valSet) + return hvs +} + +func NewExtendedHeightVoteSet(chainID string, height int64, valSet *types.ValidatorSet) *HeightVoteSet { + hvs := &HeightVoteSet{ + chainID: chainID, + extensionsEnabled: true, } hvs.Reset(height, valSet) return hvs @@ -108,7 +119,12 @@ func (hvs *HeightVoteSet) addRound(round int32) { } // log.Debug("addRound(round)", "round", round) prevotes := types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrevoteType, hvs.valSet) - precommits := types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrecommitType, hvs.valSet) + var precommits *types.VoteSet + if hvs.extensionsEnabled { + precommits = types.NewExtendedVoteSet(hvs.chainID, hvs.height, round, tmproto.PrecommitType, hvs.valSet) + } else { + precommits = types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrecommitType, hvs.valSet) + } hvs.roundVoteSets[round] = RoundVoteSet{ Prevotes: prevotes, Precommits: precommits, diff --git a/internal/consensus/types/height_vote_set_test.go b/internal/consensus/types/height_vote_set_test.go index acffa794c..a2cfd84ec 100644 --- a/internal/consensus/types/height_vote_set_test.go +++ b/internal/consensus/types/height_vote_set_test.go @@ -27,7 +27,7 @@ func TestPeerCatchupRounds(t *testing.T) { valSet, privVals := factory.ValidatorSet(ctx, t, 10, 1) chainID := cfg.ChainID() - hvs := NewHeightVoteSet(chainID, 1, valSet) + hvs := NewExtendedHeightVoteSet(chainID, 1, valSet) vote999_0 := makeVoteHR(ctx, t, 1, 0, 999, privVals, chainID) added, err := hvs.AddVote(vote999_0, "peer1") diff --git a/internal/evidence/pool_test.go b/internal/evidence/pool_test.go index 4047d3e7f..f80d13c7b 100644 --- a/internal/evidence/pool_test.go +++ b/internal/evidence/pool_test.go @@ -250,7 +250,7 @@ func TestEvidencePoolUpdate(t *testing.T) { ) require.NoError(t, err) lastExtCommit := makeExtCommit(height, val.PrivKey.PubKey().Address()) - block := types.MakeBlock(height+1, []types.Tx{}, lastExtCommit.StripExtensions(), []types.Evidence{ev}) + block := types.MakeBlock(height+1, []types.Tx{}, lastExtCommit.ToCommit(), []types.Evidence{ev}) // update state (partially) state.LastBlockHeight = height + 1 @@ -569,7 +569,7 @@ func initializeBlockStore(db dbm.DB, state sm.State, valAddr []byte) (*store.Blo for i := int64(1); i <= state.LastBlockHeight; i++ { lastCommit := makeExtCommit(i-1, valAddr) - block := sf.MakeBlock(state, i, lastCommit.StripExtensions()) + block := sf.MakeBlock(state, i, lastCommit.ToCommit()) block.Header.Time = defaultEvidenceTime.Add(time.Duration(i) * time.Minute) block.Header.Version = version.Consensus{Block: version.BlockProtocol, App: 1} @@ -580,7 +580,7 @@ func initializeBlockStore(db dbm.DB, state sm.State, valAddr []byte) (*store.Blo } seenCommit := makeExtCommit(i, valAddr) - blockStore.SaveBlock(block, partSet, seenCommit) + blockStore.SaveBlockWithExtendedCommit(block, partSet, seenCommit) } return blockStore, nil @@ -596,6 +596,7 @@ func makeExtCommit(height int64, valAddr []byte) *types.ExtendedCommit { Timestamp: defaultEvidenceTime, Signature: []byte("Signature"), }, + ExtensionSignature: []byte("Extended Signature"), }}, } } diff --git a/internal/evidence/verify_test.go b/internal/evidence/verify_test.go index 2ed84fa69..6341fde1d 100644 --- a/internal/evidence/verify_test.go +++ b/internal/evidence/verify_test.go @@ -233,10 +233,10 @@ func TestVerifyLightClientAttack_Equivocation(t *testing.T) { // we are simulating a duplicate vote attack where all the validators in the conflictingVals set // except the last validator vote twice blockID := factory.MakeBlockIDWithHash(conflictingHeader.Hash()) - voteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) + voteSet := types.NewExtendedVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) extCommit, err := factory.MakeExtendedCommit(ctx, blockID, 10, 1, voteSet, conflictingPrivVals[:4], defaultEvidenceTime) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() ev := &types.LightClientAttackEvidence{ ConflictingBlock: &types.LightBlock{ @@ -253,11 +253,11 @@ func TestVerifyLightClientAttack_Equivocation(t *testing.T) { } trustedBlockID := makeBlockID(trustedHeader.Hash(), 1000, []byte("partshash")) - trustedVoteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) + trustedVoteSet := types.NewExtendedVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) trustedExtCommit, err := factory.MakeExtendedCommit(ctx, trustedBlockID, 10, 1, trustedVoteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) - trustedCommit := trustedExtCommit.StripExtensions() + trustedCommit := trustedExtCommit.ToCommit() trustedSignedHeader := &types.SignedHeader{ Header: trustedHeader, @@ -336,10 +336,10 @@ func TestVerifyLightClientAttack_Amnesia(t *testing.T) { // we are simulating an amnesia attack where all the validators in the conflictingVals set // except the last validator vote twice. However this time the commits are of different rounds. blockID := makeBlockID(conflictingHeader.Hash(), 1000, []byte("partshash")) - voteSet := types.NewVoteSet(evidenceChainID, height, 0, tmproto.SignedMsgType(2), conflictingVals) + voteSet := types.NewExtendedVoteSet(evidenceChainID, height, 0, tmproto.SignedMsgType(2), conflictingVals) extCommit, err := factory.MakeExtendedCommit(ctx, blockID, height, 0, voteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() ev := &types.LightClientAttackEvidence{ ConflictingBlock: &types.LightBlock{ @@ -356,11 +356,11 @@ func TestVerifyLightClientAttack_Amnesia(t *testing.T) { } trustedBlockID := makeBlockID(trustedHeader.Hash(), 1000, []byte("partshash")) - trustedVoteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals) + trustedVoteSet := types.NewExtendedVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals) trustedExtCommit, err := factory.MakeExtendedCommit(ctx, trustedBlockID, height, 1, trustedVoteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) - trustedCommit := trustedExtCommit.StripExtensions() + trustedCommit := trustedExtCommit.ToCommit() trustedSignedHeader := &types.SignedHeader{ Header: trustedHeader, @@ -553,10 +553,10 @@ func makeLunaticEvidence( }) blockID := factory.MakeBlockIDWithHash(conflictingHeader.Hash()) - voteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals) + voteSet := types.NewExtendedVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals) extCommit, err := factory.MakeExtendedCommit(ctx, blockID, height, 1, voteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() ev = &types.LightClientAttackEvidence{ ConflictingBlock: &types.LightBlock{ @@ -582,10 +582,10 @@ func makeLunaticEvidence( } trustedBlockID := factory.MakeBlockIDWithHash(trustedHeader.Hash()) trustedVals, privVals := factory.ValidatorSet(ctx, t, totalVals, defaultVotingPower) - trustedVoteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), trustedVals) + trustedVoteSet := types.NewExtendedVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), trustedVals) trustedExtCommit, err := factory.MakeExtendedCommit(ctx, trustedBlockID, height, 1, trustedVoteSet, privVals, defaultEvidenceTime) require.NoError(t, err) - trustedCommit := trustedExtCommit.StripExtensions() + trustedCommit := trustedExtCommit.ToCommit() trusted = &types.LightBlock{ SignedHeader: &types.SignedHeader{ diff --git a/internal/state/execution.go b/internal/state/execution.go index 2c88c793b..47c2cb7ae 100644 --- a/internal/state/execution.go +++ b/internal/state/execution.go @@ -3,6 +3,7 @@ package state import ( "bytes" "context" + "errors" "fmt" "time" @@ -100,15 +101,14 @@ func (blockExec *BlockExecutor) CreateProposalBlock( maxDataBytes := types.MaxDataBytes(maxBytes, evSize, state.Validators.Size()) txs := blockExec.mempool.ReapMaxBytesMaxGas(maxDataBytes, maxGas) - commit := lastExtCommit.StripExtensions() + commit := lastExtCommit.ToCommit() block := state.MakeBlock(height, txs, commit, evidence, proposerAddr) - rpp, err := blockExec.appClient.PrepareProposal( ctx, &abci.RequestPrepareProposal{ MaxTxBytes: maxDataBytes, Txs: block.Txs.ToSliceOfBytes(), - LocalLastCommit: buildExtendedCommitInfo(lastExtCommit, blockExec.store, state.InitialHeight), + LocalLastCommit: buildExtendedCommitInfo(lastExtCommit, blockExec.store, state.InitialHeight, state.ConsensusParams.ABCI), ByzantineValidators: block.Evidence.ToABCI(), Height: block.Height, Time: block.Time, @@ -321,7 +321,7 @@ func (blockExec *BlockExecutor) VerifyVoteExtension(ctx context.Context, vote *t } if !resp.IsOK() { - return types.ErrVoteInvalidExtension + return errors.New("invalid vote extension") } return nil @@ -428,7 +428,7 @@ func buildLastCommitInfo(block *types.Block, store Store, initialHeight int64) a // data, it returns an empty record. // // Assumes that the commit signatures are sorted according to validator index. -func buildExtendedCommitInfo(ec *types.ExtendedCommit, store Store, initialHeight int64) abci.ExtendedCommitInfo { +func buildExtendedCommitInfo(ec *types.ExtendedCommit, store Store, initialHeight int64, ap types.ABCIParams) abci.ExtendedCommitInfo { if ec.Height < initialHeight { // There are no extended commits for heights below the initial height. return abci.ExtendedCommitInfo{} @@ -466,9 +466,15 @@ func buildExtendedCommitInfo(ec *types.ExtendedCommit, store Store, initialHeigh } var ext []byte - if ecs.BlockIDFlag == types.BlockIDFlagCommit { - // We only care about vote extensions if a validator has voted to - // commit. + // Check if vote extensions were enabled during the commit's height: ec.Height. + // ec is the commit from the previous height, so if extensions were enabled + // during that height, we ensure they are present and deliver the data to + // the proposer. If they were not enabled during this previous height, we + // will not deliver extension data. + if ap.VoteExtensionsEnabled(ec.Height) && ecs.BlockIDFlag == types.BlockIDFlagCommit { + if err := ecs.EnsureExtension(); err != nil { + panic(fmt.Errorf("commit at height %d received with missing vote extensions data", ec.Height)) + } ext = ecs.Extension } diff --git a/internal/state/execution_test.go b/internal/state/execution_test.go index ffe9cb6f8..5fb4dc297 100644 --- a/internal/state/execution_test.go +++ b/internal/state/execution_test.go @@ -140,7 +140,7 @@ func TestFinalizeBlockDecidedLastCommit(t *testing.T) { } // block for height 2 - block := sf.MakeBlock(state, 2, lastCommit.StripExtensions()) + block := sf.MakeBlock(state, 2, lastCommit.ToCommit()) bps, err := block.MakePartSet(testPartSize) require.NoError(t, err) blockID := types.BlockID{Hash: block.Hash(), PartSetHeader: bps.Header()} @@ -1004,6 +1004,116 @@ func TestPrepareProposalErrorOnPrepareProposalError(t *testing.T) { mp.AssertExpectations(t) } +// TestCreateProposalBlockPanicOnAbsentVoteExtensions ensures that the CreateProposalBlock +// call correctly panics when the vote extension data is missing from the extended commit +// data that the method receives. +func TestCreateProposalAbsentVoteExtensions(t *testing.T) { + for _, testCase := range []struct { + name string + + // The height that is about to be proposed + height int64 + + // The first height during which vote extensions will be required for consensus to proceed. + extensionEnableHeight int64 + expectPanic bool + }{ + { + name: "missing extension data on first required height", + height: 2, + extensionEnableHeight: 1, + expectPanic: true, + }, + { + name: "missing extension during before required height", + height: 2, + extensionEnableHeight: 2, + expectPanic: false, + }, + { + name: "missing extension data and not required", + height: 2, + extensionEnableHeight: 0, + expectPanic: false, + }, + { + name: "missing extension data and required in two heights", + height: 2, + extensionEnableHeight: 3, + expectPanic: false, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewNopLogger() + + eventBus := eventbus.NewDefault(logger) + require.NoError(t, eventBus.Start(ctx)) + + app := abcimocks.NewApplication(t) + if !testCase.expectPanic { + app.On("PrepareProposal", mock.Anything, mock.Anything).Return(&abci.ResponsePrepareProposal{}, nil) + } + cc := abciclient.NewLocalClient(logger, app) + proxyApp := proxy.New(cc, logger, proxy.NopMetrics()) + err := proxyApp.Start(ctx) + require.NoError(t, err) + + state, stateDB, privVals := makeState(t, 1, int(testCase.height-1)) + stateStore := sm.NewStore(stateDB) + state.ConsensusParams.ABCI.VoteExtensionsEnableHeight = testCase.extensionEnableHeight + mp := &mpmocks.Mempool{} + mp.On("Lock").Return() + mp.On("Unlock").Return() + mp.On("FlushAppConn", mock.Anything).Return(nil) + mp.On("Update", + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything).Return(nil) + mp.On("ReapMaxBytesMaxGas", mock.Anything, mock.Anything).Return(types.Txs{}) + + blockExec := sm.NewBlockExecutor( + stateStore, + logger, + proxyApp, + mp, + sm.EmptyEvidencePool{}, + nil, + eventBus, + sm.NopMetrics(), + ) + block := sf.MakeBlock(state, testCase.height, new(types.Commit)) + bps, err := block.MakePartSet(testPartSize) + require.NoError(t, err) + blockID := types.BlockID{Hash: block.Hash(), PartSetHeader: bps.Header()} + pa, _ := state.Validators.GetByIndex(0) + lastCommit, _ := makeValidCommit(ctx, t, testCase.height-1, blockID, state.Validators, privVals) + stripSignatures(lastCommit) + if testCase.expectPanic { + require.Panics(t, func() { + blockExec.CreateProposalBlock(ctx, testCase.height, state, lastCommit, pa) //nolint:errcheck + }) + } else { + _, err = blockExec.CreateProposalBlock(ctx, testCase.height, state, lastCommit, pa) + require.NoError(t, err) + } + }) + } +} + +func stripSignatures(ec *types.ExtendedCommit) { + for i, commitSig := range ec.ExtendedSignatures { + commitSig.Extension = nil + commitSig.ExtensionSignature = nil + ec.ExtendedSignatures[i] = commitSig + } +} + func makeBlockID(hash []byte, partSetSize uint32, partSetHash []byte) types.BlockID { var ( h = make([]byte, crypto.HashSize) diff --git a/internal/state/mocks/block_store.go b/internal/state/mocks/block_store.go index 4eafb1273..58fc640fc 100644 --- a/internal/state/mocks/block_store.go +++ b/internal/state/mocks/block_store.go @@ -209,7 +209,12 @@ func (_m *BlockStore) PruneBlocks(height int64) (uint64, error) { } // SaveBlock provides a mock function with given fields: block, blockParts, seenCommit -func (_m *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) { +func (_m *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.Commit) { + _m.Called(block, blockParts, seenCommit) +} + +// SaveBlockWithExtendedCommit provides a mock function with given fields: block, blockParts, seenCommit +func (_m *BlockStore) SaveBlockWithExtendedCommit(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) { _m.Called(block, blockParts, seenCommit) } diff --git a/internal/state/services.go b/internal/state/services.go index 35a91aa11..f86c4e3cf 100644 --- a/internal/state/services.go +++ b/internal/state/services.go @@ -26,7 +26,8 @@ type BlockStore interface { LoadBlockMeta(height int64) *types.BlockMeta LoadBlock(height int64) *types.Block - SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) + SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.Commit) + SaveBlockWithExtendedCommit(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) PruneBlocks(height int64) (uint64, error) diff --git a/internal/state/store.go b/internal/state/store.go index 2d2e4dc81..d592016f6 100644 --- a/internal/state/store.go +++ b/internal/state/store.go @@ -2,6 +2,7 @@ package state import ( "bytes" + "encoding/binary" "errors" "fmt" @@ -59,6 +60,7 @@ func abciResponsesKey(height int64) []byte { // stateKey should never change after being set in init() var stateKey []byte +var tmpABCIKey []byte func init() { var err error @@ -66,6 +68,12 @@ func init() { if err != nil { panic(err) } + // temporary extra key before consensus param protos are regenerated + // TODO(wbanfield) remove in next PR + tmpABCIKey, err = orderedcode.Append(nil, int64(10000)) + if err != nil { + panic(err) + } } //---------------------- @@ -137,6 +145,12 @@ func (store dbStore) loadState(key []byte) (state State, err error) { if err != nil { return state, err } + buf, err = store.db.Get(tmpABCIKey) + if err != nil { + return state, err + } + h, _ := binary.Varint(buf) + sm.ConsensusParams.ABCI.VoteExtensionsEnableHeight = h return *sm, nil } @@ -181,6 +195,11 @@ func (store dbStore) save(state State, key []byte) error { if err := batch.Set(key, stateBz); err != nil { return err } + bz := make([]byte, 5) + binary.PutVarint(bz, state.ConsensusParams.ABCI.VoteExtensionsEnableHeight) + if err := batch.Set(tmpABCIKey, bz); err != nil { + return err + } return batch.WriteSync() } diff --git a/internal/state/validation_test.go b/internal/state/validation_test.go index b29cfd0f9..0f43db5eb 100644 --- a/internal/state/validation_test.go +++ b/internal/state/validation_test.go @@ -124,7 +124,7 @@ func TestValidateBlockHeader(t *testing.T) { */ state, _, lastExtCommit = makeAndCommitGoodBlock(ctx, t, state, height, lastCommit, state.Validators.GetProposer().Address, blockExec, privVals, nil) - lastCommit = lastExtCommit.StripExtensions() + lastCommit = lastExtCommit.ToCommit() } nextHeight := validationTestsStopHeight @@ -234,7 +234,7 @@ func TestValidateBlockCommit(t *testing.T) { privVals, nil, ) - lastCommit = lastExtCommit.StripExtensions() + lastCommit = lastExtCommit.ToCommit() /* wrongSigsCommit is fine except for the extra bad precommit @@ -384,7 +384,7 @@ func TestValidateBlockEvidence(t *testing.T) { privVals, evidence, ) - lastCommit = lastExtCommit.StripExtensions() + lastCommit = lastExtCommit.ToCommit() } } diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index 904fb2b74..f57e228a7 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -855,13 +855,13 @@ func mockLB(ctx context.Context, t *testing.T, height int64, time time.Time, las header.NextValidatorsHash = nextVals.Hash() header.ConsensusHash = types.DefaultConsensusParams().HashConsensusParams() lastBlockID = factory.MakeBlockIDWithHash(header.Hash()) - voteSet := types.NewVoteSet(factory.DefaultTestChainID, height, 0, tmproto.PrecommitType, currentVals) + voteSet := types.NewExtendedVoteSet(factory.DefaultTestChainID, height, 0, tmproto.PrecommitType, currentVals) extCommit, err := factory.MakeExtendedCommit(ctx, lastBlockID, height, 0, voteSet, currentPrivVals, time) require.NoError(t, err) return nextVals, nextPrivVals, &types.LightBlock{ SignedHeader: &types.SignedHeader{ Header: header, - Commit: extCommit.StripExtensions(), + Commit: extCommit.ToCommit(), }, ValidatorSet: currentVals, } diff --git a/internal/store/store.go b/internal/store/store.go index 5617674a2..1ba7e398d 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -2,6 +2,7 @@ package store import ( "bytes" + "errors" "fmt" "strconv" @@ -278,6 +279,9 @@ func (bs *BlockStore) LoadBlockCommit(height int64) *types.Commit { return commit } +// LoadExtendedCommit returns the ExtendedCommit for the given height. +// The extended commit is not guaranteed to contain the same +2/3 precommits data +// as the commit in the block. func (bs *BlockStore) LoadBlockExtendedCommit(height int64) *types.ExtendedCommit { pbec := new(tmproto.ExtendedCommit) bz, err := bs.db.Get(extCommitKey(height)) @@ -466,25 +470,73 @@ func (bs *BlockStore) batchDelete( // If all the nodes restart after committing a block, // we need this to reload the precommits to catch-up nodes to the // most recent height. Otherwise they'd stall at H-1. -func (bs *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) { +func (bs *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.Commit) { if block == nil { panic("BlockStore can only save a non-nil block") } - batch := bs.db.NewBatch() + if err := bs.saveBlockToBatch(batch, block, blockParts, seenCommit); err != nil { + panic(err) + } + + if err := batch.WriteSync(); err != nil { + panic(err) + } + + if err := batch.Close(); err != nil { + panic(err) + } +} + +// SaveBlockWithExtendedCommit persists the given block, blockParts, and +// seenExtendedCommit to the underlying db. seenExtendedCommit is stored under +// two keys in the database: as the seenCommit and as the ExtendedCommit data for the +// height. This allows the vote extension data to be persisted for all blocks +// that are saved. +func (bs *BlockStore) SaveBlockWithExtendedCommit(block *types.Block, blockParts *types.PartSet, seenExtendedCommit *types.ExtendedCommit) { + if block == nil { + panic("BlockStore can only save a non-nil block") + } + if err := seenExtendedCommit.EnsureExtensions(); err != nil { + panic(fmt.Errorf("saving block with extensions: %w", err)) + } + batch := bs.db.NewBatch() + if err := bs.saveBlockToBatch(batch, block, blockParts, seenExtendedCommit.ToCommit()); err != nil { + panic(err) + } + height := block.Height + + pbec := seenExtendedCommit.ToProto() + extCommitBytes := mustEncode(pbec) + if err := batch.Set(extCommitKey(height), extCommitBytes); err != nil { + panic(err) + } + + if err := batch.WriteSync(); err != nil { + panic(err) + } + + if err := batch.Close(); err != nil { + panic(err) + } +} + +func (bs *BlockStore) saveBlockToBatch(batch dbm.Batch, block *types.Block, blockParts *types.PartSet, seenCommit *types.Commit) error { + if block == nil { + panic("BlockStore can only save a non-nil block") + } height := block.Height hash := block.Hash() if g, w := height, bs.Height()+1; bs.Base() > 0 && g != w { - panic(fmt.Sprintf("BlockStore can only save contiguous blocks. Wanted %v, got %v", w, g)) + return fmt.Errorf("BlockStore can only save contiguous blocks. Wanted %v, got %v", w, g) } if !blockParts.IsComplete() { - panic("BlockStore can only save complete block part sets") + return errors.New("BlockStore can only save complete block part sets") } if height != seenCommit.Height { - panic(fmt.Sprintf("BlockStore cannot save seen commit of a different height (block: %d, commit: %d)", - height, seenCommit.Height)) + return fmt.Errorf("BlockStore cannot save seen commit of a different height (block: %d, commit: %d)", height, seenCommit.Height) } // Save block parts. This must be done before the block meta, since callers @@ -499,44 +551,32 @@ func (bs *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, s blockMeta := types.NewBlockMeta(block, blockParts) pbm := blockMeta.ToProto() if pbm == nil { - panic("nil blockmeta") + return errors.New("nil blockmeta") } metaBytes := mustEncode(pbm) if err := batch.Set(blockMetaKey(height), metaBytes); err != nil { - panic(err) + return err } if err := batch.Set(blockHashKey(hash), []byte(fmt.Sprintf("%d", height))); err != nil { - panic(err) + return err } pbc := block.LastCommit.ToProto() blockCommitBytes := mustEncode(pbc) if err := batch.Set(blockCommitKey(height-1), blockCommitBytes); err != nil { - panic(err) + return err } // Save seen commit (seen +2/3 precommits for block) - pbsc := seenCommit.StripExtensions().ToProto() + pbsc := seenCommit.ToProto() seenCommitBytes := mustEncode(pbsc) if err := batch.Set(seenCommitKey(), seenCommitBytes); err != nil { - panic(err) + return err } - pbec := seenCommit.ToProto() - extCommitBytes := mustEncode(pbec) - if err := batch.Set(extCommitKey(height), extCommitBytes); err != nil { - panic(err) - } - - if err := batch.WriteSync(); err != nil { - panic(err) - } - - if err := batch.Close(); err != nil { - panic(err) - } + return nil } func (bs *BlockStore) saveBlockPart(height int64, index int, part *types.Part, batch dbm.Batch) { diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 9df3eed9f..771129cc0 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -36,6 +36,7 @@ func makeTestExtCommit(height int64, timestamp time.Time) *types.ExtendedCommit Timestamp: timestamp, Signature: []byte("Signature"), }, + ExtensionSignature: []byte("ExtensionSignature"), }} return &types.ExtendedCommit{ Height: height, @@ -89,7 +90,7 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { part2 := validPartSet.GetPart(1) seenCommit := makeTestExtCommit(block.Header.Height, tmtime.Now()) - bs.SaveBlock(block, validPartSet, seenCommit) + bs.SaveBlockWithExtendedCommit(block, validPartSet, seenCommit) require.EqualValues(t, 1, bs.Base(), "expecting the new height to be changed") require.EqualValues(t, block.Header.Height, bs.Height(), "expecting the new height to be changed") @@ -107,7 +108,7 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { } // End of setup, test data - commitAtH10 := makeTestExtCommit(10, tmtime.Now()).StripExtensions() + commitAtH10 := makeTestExtCommit(10, tmtime.Now()).ToCommit() tuples := []struct { block *types.Block parts *types.PartSet @@ -140,16 +141,17 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { ChainID: "block_test", Time: tmtime.Now(), ProposerAddress: tmrand.Bytes(crypto.AddressSize)}, - makeTestExtCommit(5, tmtime.Now()).StripExtensions(), + makeTestExtCommit(5, tmtime.Now()).ToCommit(), ), parts: validPartSet, seenCommit: makeTestExtCommit(5, tmtime.Now()), }, { - block: newBlock(header1, commitAtH10), - parts: incompletePartSet, - wantPanic: "only save complete block", // incomplete parts + block: newBlock(header1, commitAtH10), + parts: incompletePartSet, + wantPanic: "only save complete block", // incomplete parts + seenCommit: makeTestExtCommit(10, tmtime.Now()), }, { @@ -178,7 +180,7 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { }, { - block: newBlock(header1, commitAtH10), + block: block, parts: validPartSet, seenCommit: seenCommit, @@ -187,7 +189,7 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { }, { - block: newBlock(header1, commitAtH10), + block: block, parts: validPartSet, seenCommit: seenCommit, @@ -209,7 +211,7 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { bs, db := newInMemoryBlockStore() // SaveBlock res, err, panicErr := doFn(func() (interface{}, error) { - bs.SaveBlock(tuple.block, tuple.parts, tuple.seenCommit) + bs.SaveBlockWithExtendedCommit(tuple.block, tuple.parts, tuple.seenCommit) if tuple.block == nil { return nil, nil } @@ -279,6 +281,90 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { } } +// TestSaveBlockWithExtendedCommitPanicOnAbsentExtension tests that saving a +// block with an extended commit panics when the extension data is absent. +func TestSaveBlockWithExtendedCommitPanicOnAbsentExtension(t *testing.T) { + for _, testCase := range []struct { + name string + malleateCommit func(*types.ExtendedCommit) + shouldPanic bool + }{ + { + name: "basic save", + malleateCommit: func(_ *types.ExtendedCommit) {}, + shouldPanic: false, + }, + { + name: "save commit with no extensions", + malleateCommit: func(c *types.ExtendedCommit) { + c.StripExtensions() + }, + shouldPanic: true, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + state, bs, cleanup, err := makeStateAndBlockStore(t.TempDir()) + require.NoError(t, err) + defer cleanup() + block := factory.MakeBlock(state, bs.Height()+1, new(types.Commit)) + seenCommit := makeTestExtCommit(block.Header.Height, tmtime.Now()) + ps, err := block.MakePartSet(2) + require.NoError(t, err) + testCase.malleateCommit(seenCommit) + if testCase.shouldPanic { + require.Panics(t, func() { + bs.SaveBlockWithExtendedCommit(block, ps, seenCommit) + }) + } else { + bs.SaveBlockWithExtendedCommit(block, ps, seenCommit) + } + }) + } +} + +// TestLoadBlockExtendedCommit tests loading the extended commit for a previously +// saved block. The load method should return nil when only a commit was saved and +// return the extended commit otherwise. +func TestLoadBlockExtendedCommit(t *testing.T) { + for _, testCase := range []struct { + name string + saveExtended bool + expectResult bool + }{ + { + name: "save commit", + saveExtended: false, + expectResult: false, + }, + { + name: "save extended commit", + saveExtended: true, + expectResult: true, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + state, bs, cleanup, err := makeStateAndBlockStore(t.TempDir()) + require.NoError(t, err) + defer cleanup() + block := factory.MakeBlock(state, bs.Height()+1, new(types.Commit)) + seenCommit := makeTestExtCommit(block.Header.Height, tmtime.Now()) + ps, err := block.MakePartSet(2) + require.NoError(t, err) + if testCase.saveExtended { + bs.SaveBlockWithExtendedCommit(block, ps, seenCommit) + } else { + bs.SaveBlock(block, ps, seenCommit.ToCommit()) + } + res := bs.LoadBlockExtendedCommit(block.Height) + if testCase.expectResult { + require.Equal(t, seenCommit, res) + } else { + require.Nil(t, res) + } + }) + } +} + func TestLoadBaseMeta(t *testing.T) { cfg, err := config.ResetTestRoot(t.TempDir(), "blockchain_reactor_test") require.NoError(t, err) @@ -293,7 +379,7 @@ func TestLoadBaseMeta(t *testing.T) { partSet, err := block.MakePartSet(2) require.NoError(t, err) seenCommit := makeTestExtCommit(h, tmtime.Now()) - bs.SaveBlock(block, partSet, seenCommit) + bs.SaveBlockWithExtendedCommit(block, partSet, seenCommit) } pruned, err := bs.PruneBlocks(4) @@ -371,7 +457,7 @@ func TestPruneBlocks(t *testing.T) { partSet, err := block.MakePartSet(2) require.NoError(t, err) seenCommit := makeTestExtCommit(h, tmtime.Now()) - bs.SaveBlock(block, partSet, seenCommit) + bs.SaveBlockWithExtendedCommit(block, partSet, seenCommit) } assert.EqualValues(t, 1, bs.Base()) @@ -479,7 +565,7 @@ func TestBlockFetchAtHeight(t *testing.T) { partSet, err := block.MakePartSet(2) require.NoError(t, err) seenCommit := makeTestExtCommit(block.Header.Height, tmtime.Now()) - bs.SaveBlock(block, partSet, seenCommit) + bs.SaveBlockWithExtendedCommit(block, partSet, seenCommit) require.Equal(t, bs.Height(), block.Header.Height, "expecting the new height to be changed") blockAtHeight := bs.LoadBlock(bs.Height()) @@ -518,16 +604,16 @@ func TestSeenAndCanonicalCommit(t *testing.T) { // produce a few blocks and check that the correct seen and cannoncial commits // are persisted. for h := int64(3); h <= 5; h++ { - blockCommit := makeTestExtCommit(h-1, tmtime.Now()).StripExtensions() + blockCommit := makeTestExtCommit(h-1, tmtime.Now()).ToCommit() block := factory.MakeBlock(state, h, blockCommit) partSet, err := block.MakePartSet(2) require.NoError(t, err) seenCommit := makeTestExtCommit(h, tmtime.Now()) - store.SaveBlock(block, partSet, seenCommit) + store.SaveBlockWithExtendedCommit(block, partSet, seenCommit) c3 := store.LoadSeenCommit() require.NotNil(t, c3) require.Equal(t, h, c3.Height) - require.Equal(t, seenCommit.StripExtensions().Hash(), c3.Hash()) + require.Equal(t, seenCommit.ToCommit().Hash(), c3.Hash()) c5 := store.LoadBlockCommit(h) require.Nil(t, c5) c6 := store.LoadBlockCommit(h - 1) diff --git a/internal/test/factory/params.go b/internal/test/factory/params.go index dda8e2b3c..c6fa3f9fc 100644 --- a/internal/test/factory/params.go +++ b/internal/test/factory/params.go @@ -18,5 +18,6 @@ func ConsensusParams() *types.ConsensusParams { VoteDelta: 1 * time.Millisecond, BypassCommitTimeout: true, } + c.ABCI.VoteExtensionsEnableHeight = 1 return c } diff --git a/node/node_test.go b/node/node_test.go index b1d7a9481..245e39b3c 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -526,7 +526,7 @@ func TestMaxProposalBlockSize(t *testing.T) { } state.ChainID = maxChainID - voteSet := types.NewVoteSet(state.ChainID, math.MaxInt64-1, math.MaxInt32, tmproto.PrecommitType, state.Validators) + voteSet := types.NewExtendedVoteSet(state.ChainID, math.MaxInt64-1, math.MaxInt32, tmproto.PrecommitType, state.Validators) // add maximum amount of signatures to a single commit for i := 0; i < types.MaxVotesCount; i++ { diff --git a/test/e2e/runner/evidence.go b/test/e2e/runner/evidence.go index a71ea14fb..fab1d7b20 100644 --- a/test/e2e/runner/evidence.go +++ b/test/e2e/runner/evidence.go @@ -165,9 +165,9 @@ func generateLightClientAttackEvidence( // create a commit for the forged header blockID := makeBlockID(header.Hash(), 1000, []byte("partshash")) - voteSet := types.NewVoteSet(chainID, forgedHeight, 0, tmproto.SignedMsgType(2), conflictingVals) + voteSet := types.NewExtendedVoteSet(chainID, forgedHeight, 0, tmproto.SignedMsgType(2), conflictingVals) - commit, err := factory.MakeExtendedCommit(ctx, blockID, forgedHeight, 0, voteSet, pv, forgedTime) + ec, err := factory.MakeExtendedCommit(ctx, blockID, forgedHeight, 0, voteSet, pv, forgedTime) if err != nil { return nil, err } @@ -176,7 +176,7 @@ func generateLightClientAttackEvidence( ConflictingBlock: &types.LightBlock{ SignedHeader: &types.SignedHeader{ Header: header, - Commit: commit.StripExtensions(), + Commit: ec.ToCommit(), }, ValidatorSet: conflictingVals, }, diff --git a/test/e2e/tests/app_test.go b/test/e2e/tests/app_test.go index ed041e186..20a153c1d 100644 --- a/test/e2e/tests/app_test.go +++ b/test/e2e/tests/app_test.go @@ -190,6 +190,7 @@ func TestApp_Tx(t *testing.T) { func TestApp_VoteExtensions(t *testing.T) { testNode(t, func(ctx context.Context, t *testing.T, node e2e.Node) { + t.Skip() client, err := node.Client() require.NoError(t, err) diff --git a/types/block.go b/types/block.go index 32a4f9a0a..f508245bb 100644 --- a/types/block.go +++ b/types/block.go @@ -757,22 +757,23 @@ func (ecs ExtendedCommitSig) ValidateBasic() error { if len(ecs.Extension) > MaxVoteExtensionSize { return fmt.Errorf("vote extension is too big (max: %d)", MaxVoteExtensionSize) } - if len(ecs.ExtensionSignature) == 0 { - return errors.New("vote extension signature is missing") - } if len(ecs.ExtensionSignature) > MaxSignatureSize { return fmt.Errorf("vote extension signature is too big (max: %d)", MaxSignatureSize) } return nil } - // We expect there to not be any vote extension or vote extension signature - // on nil or absent votes. - if len(ecs.Extension) != 0 { - return fmt.Errorf("vote extension is present for commit sig with block ID flag %v", ecs.BlockIDFlag) + if len(ecs.ExtensionSignature) == 0 && len(ecs.Extension) != 0 { + return errors.New("vote extension signature absent on vote with extension") } - if len(ecs.ExtensionSignature) != 0 { - return fmt.Errorf("vote extension signature is present for commit sig with block ID flag %v", ecs.BlockIDFlag) + return nil +} + +// EnsureExtensions validates that a vote extensions signature is present for +// this ExtendedCommitSig. +func (ecs ExtendedCommitSig) EnsureExtension() error { + if ecs.BlockIDFlag == BlockIDFlagCommit && len(ecs.ExtensionSignature) == 0 { + return errors.New("vote extension data is missing") } return nil } @@ -916,6 +917,26 @@ func (commit *Commit) Hash() tmbytes.HexBytes { return commit.hash } +// WrappedExtendedCommit wraps a commit as an ExtendedCommit. +// The VoteExtension fields of the resulting value will by nil. +// Wrapping a Commit as an ExtendedCommit is useful when an API +// requires an ExtendedCommit wire type but does not +// need the VoteExtension data. +func (commit *Commit) WrappedExtendedCommit() *ExtendedCommit { + cs := make([]ExtendedCommitSig, len(commit.Signatures)) + for idx, s := range commit.Signatures { + cs[idx] = ExtendedCommitSig{ + CommitSig: s, + } + } + return &ExtendedCommit{ + Height: commit.Height, + Round: commit.Round, + BlockID: commit.BlockID, + ExtendedSignatures: cs, + } +} + // StringIndented returns a string representation of the commit. func (commit *Commit) StringIndented(indent string) string { if commit == nil { @@ -1013,17 +1034,33 @@ func (ec *ExtendedCommit) Clone() *ExtendedCommit { return &ecc } +// ToExtendedVoteSet constructs a VoteSet from the Commit and validator set. +// Panics if signatures from the ExtendedCommit can't be added to the voteset. +// Panics if any of the votes have invalid or absent vote extension data. +// Inverse of VoteSet.MakeExtendedCommit(). +func (ec *ExtendedCommit) ToExtendedVoteSet(chainID string, vals *ValidatorSet) *VoteSet { + voteSet := NewExtendedVoteSet(chainID, ec.Height, ec.Round, tmproto.PrecommitType, vals) + ec.addSigsToVoteSet(voteSet) + return voteSet +} + // ToVoteSet constructs a VoteSet from the Commit and validator set. -// Panics if signatures from the commit can't be added to the voteset. +// Panics if signatures from the ExtendedCommit can't be added to the voteset. // Inverse of VoteSet.MakeExtendedCommit(). func (ec *ExtendedCommit) ToVoteSet(chainID string, vals *ValidatorSet) *VoteSet { voteSet := NewVoteSet(chainID, ec.Height, ec.Round, tmproto.PrecommitType, vals) + ec.addSigsToVoteSet(voteSet) + return voteSet +} + +// addSigsToVoteSet adds all of the signature to voteSet. +func (ec *ExtendedCommit) addSigsToVoteSet(voteSet *VoteSet) { for idx, ecs := range ec.ExtendedSignatures { if ecs.BlockIDFlag == BlockIDFlagAbsent { continue // OK, some precommits can be missing. } vote := ec.GetExtendedVote(int32(idx)) - if err := vote.ValidateWithExtension(); err != nil { + if err := vote.ValidateBasic(); err != nil { panic(fmt.Errorf("failed to validate vote reconstructed from LastCommit: %w", err)) } added, err := voteSet.AddVote(vote) @@ -1031,12 +1068,58 @@ func (ec *ExtendedCommit) ToVoteSet(chainID string, vals *ValidatorSet) *VoteSet panic(fmt.Errorf("failed to reconstruct vote set from extended commit: %w", err)) } } +} + +// ToVoteSet constructs a VoteSet from the Commit and validator set. +// Panics if signatures from the commit can't be added to the voteset. +// Inverse of VoteSet.MakeCommit(). +func (commit *Commit) ToVoteSet(chainID string, vals *ValidatorSet) *VoteSet { + voteSet := NewVoteSet(chainID, commit.Height, commit.Round, tmproto.PrecommitType, vals) + for idx, cs := range commit.Signatures { + if cs.BlockIDFlag == BlockIDFlagAbsent { + continue // OK, some precommits can be missing. + } + vote := commit.GetVote(int32(idx)) + if err := vote.ValidateBasic(); err != nil { + panic(fmt.Errorf("failed to validate vote reconstructed from commit: %w", err)) + } + added, err := voteSet.AddVote(vote) + if !added || err != nil { + panic(fmt.Errorf("failed to reconstruct vote set from commit: %w", err)) + } + } return voteSet } -// StripExtensions converts an ExtendedCommit to a Commit by removing all vote +// EnsureExtensions validates that a vote extensions signature is present for +// every ExtendedCommitSig in the ExtendedCommit. +func (ec *ExtendedCommit) EnsureExtensions() error { + for _, ecs := range ec.ExtendedSignatures { + if err := ecs.EnsureExtension(); err != nil { + return err + } + } + return nil +} + +// StripExtensions removes all VoteExtension data from an ExtendedCommit. This +// is useful when dealing with an ExendedCommit but vote extension data is +// expected to be absent. +func (ec *ExtendedCommit) StripExtensions() bool { + stripped := false + for idx := range ec.ExtendedSignatures { + if len(ec.ExtendedSignatures[idx].Extension) > 0 || len(ec.ExtendedSignatures[idx].ExtensionSignature) > 0 { + stripped = true + } + ec.ExtendedSignatures[idx].Extension = nil + ec.ExtendedSignatures[idx].ExtensionSignature = nil + } + return stripped +} + +// ToCommit converts an ExtendedCommit to a Commit by removing all vote // extension-related fields. -func (ec *ExtendedCommit) StripExtensions() *Commit { +func (ec *ExtendedCommit) ToCommit() *Commit { cs := make([]CommitSig, len(ec.ExtendedSignatures)) for idx, ecs := range ec.ExtendedSignatures { cs[idx] = ecs.CommitSig diff --git a/types/block_test.go b/types/block_test.go index 09a8b602e..4c3a74d8f 100644 --- a/types/block_test.go +++ b/types/block_test.go @@ -49,7 +49,7 @@ func TestBlockAddEvidence(t *testing.T) { require.NoError(t, err) evList := []Evidence{ev} - block := MakeBlock(h, txs, extCommit.StripExtensions(), evList) + block := MakeBlock(h, txs, extCommit.ToCommit(), evList) require.NotNil(t, block) require.Equal(t, 1, len(block.Evidence)) require.NotNil(t, block.EvidenceHash) @@ -68,7 +68,7 @@ func TestBlockValidateBasic(t *testing.T) { voteSet, valSet, vals := randVoteSet(ctx, t, h-1, 1, tmproto.PrecommitType, 10, 1) extCommit, err := makeExtCommit(ctx, lastID, h-1, 1, voteSet, vals, time.Now()) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() ev, err := NewMockDuplicateVoteEvidenceWithValidator(ctx, h, time.Now(), vals[0], "block-test-chain") require.NoError(t, err) @@ -163,7 +163,7 @@ func TestBlockMakePartSetWithEvidence(t *testing.T) { require.NoError(t, err) evList := []Evidence{ev} - partSet, err := MakeBlock(h, []Tx{Tx("Hello World")}, extCommit.StripExtensions(), evList).MakePartSet(512) + partSet, err := MakeBlock(h, []Tx{Tx("Hello World")}, extCommit.ToCommit(), evList).MakePartSet(512) require.NoError(t, err) assert.NotNil(t, partSet) @@ -187,7 +187,7 @@ func TestBlockHashesTo(t *testing.T) { require.NoError(t, err) evList := []Evidence{ev} - block := MakeBlock(h, []Tx{Tx("Hello World")}, extCommit.StripExtensions(), evList) + block := MakeBlock(h, []Tx{Tx("Hello World")}, extCommit.ToCommit(), evList) block.ValidatorsHash = valSet.Hash() assert.False(t, block.HashesTo([]byte{})) assert.False(t, block.HashesTo([]byte("something else"))) @@ -483,7 +483,7 @@ func randCommit(ctx context.Context, t *testing.T, now time.Time) *Commit { require.NoError(t, err) - return commit.StripExtensions() + return commit.ToCommit() } func hexBytesFromString(t *testing.T, s string) bytes.HexBytes { @@ -556,33 +556,138 @@ func TestBlockMaxDataBytesNoEvidence(t *testing.T) { } } +// TestVoteSetToExtendedCommit tests that the extended commit produced from a +// vote set contains the same vote information as the vote set. The test ensures +// that the MakeExtendedCommit method behaves as expected, whether vote extensions +// are present in the original votes or not. +func TestVoteSetToExtendedCommit(t *testing.T) { + for _, testCase := range []struct { + name string + includeExtension bool + }{ + { + name: "no extensions", + includeExtension: false, + }, + { + name: "with extensions", + includeExtension: true, + }, + } { + + t.Run(testCase.name, func(t *testing.T) { + blockID := makeBlockIDRandom() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + valSet, vals := randValidatorPrivValSet(ctx, t, 10, 1) + var voteSet *VoteSet + if testCase.includeExtension { + voteSet = NewExtendedVoteSet("test_chain_id", 3, 1, tmproto.PrecommitType, valSet) + } else { + voteSet = NewVoteSet("test_chain_id", 3, 1, tmproto.PrecommitType, valSet) + } + for i := 0; i < len(vals); i++ { + pubKey, err := vals[i].GetPubKey(ctx) + require.NoError(t, err) + vote := &Vote{ + ValidatorAddress: pubKey.Address(), + ValidatorIndex: int32(i), + Height: 3, + Round: 1, + Type: tmproto.PrecommitType, + BlockID: blockID, + Timestamp: time.Now(), + } + v := vote.ToProto() + err = vals[i].SignVote(ctx, voteSet.ChainID(), v) + require.NoError(t, err) + vote.Signature = v.Signature + if testCase.includeExtension { + vote.ExtensionSignature = v.ExtensionSignature + } + added, err := voteSet.AddVote(vote) + require.NoError(t, err) + require.True(t, added) + } + ec := voteSet.MakeExtendedCommit() + + for i := int32(0); int(i) < len(vals); i++ { + vote1 := voteSet.GetByIndex(i) + vote2 := ec.GetExtendedVote(i) + + vote1bz, err := vote1.ToProto().Marshal() + require.NoError(t, err) + vote2bz, err := vote2.ToProto().Marshal() + require.NoError(t, err) + assert.Equal(t, vote1bz, vote2bz) + } + }) + } +} + +// TestExtendedCommitToVoteSet tests that the vote set produced from an extended commit +// contains the same vote information as the extended commit. The test ensures +// that the ToVoteSet method behaves as expected, whether vote extensions +// are present in the original votes or not. func TestExtendedCommitToVoteSet(t *testing.T) { - lastID := makeBlockIDRandom() - h := int64(3) + for _, testCase := range []struct { + name string + includeExtension bool + }{ + { + name: "no extensions", + includeExtension: false, + }, + { + name: "with extensions", + includeExtension: true, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + lastID := makeBlockIDRandom() + h := int64(3) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - voteSet, valSet, vals := randVoteSet(ctx, t, h-1, 1, tmproto.PrecommitType, 10, 1) - extCommit, err := makeExtCommit(ctx, lastID, h-1, 1, voteSet, vals, time.Now()) - assert.NoError(t, err) + voteSet, valSet, vals := randVoteSet(ctx, t, h-1, 1, tmproto.PrecommitType, 10, 1) + extCommit, err := makeExtCommit(ctx, lastID, h-1, 1, voteSet, vals, time.Now()) + assert.NoError(t, err) - chainID := voteSet.ChainID() - voteSet2 := extCommit.ToVoteSet(chainID, valSet) + if !testCase.includeExtension { + for i := 0; i < len(vals); i++ { + v := voteSet.GetByIndex(int32(i)) + v.Extension = nil + v.ExtensionSignature = nil + extCommit.ExtendedSignatures[i].Extension = nil + extCommit.ExtendedSignatures[i].ExtensionSignature = nil + } + } - for i := int32(0); int(i) < len(vals); i++ { - vote1 := voteSet.GetByIndex(i) - vote2 := voteSet2.GetByIndex(i) - vote3 := extCommit.GetExtendedVote(i) + chainID := voteSet.ChainID() + var voteSet2 *VoteSet + if testCase.includeExtension { + voteSet2 = extCommit.ToExtendedVoteSet(chainID, valSet) + } else { + voteSet2 = extCommit.ToVoteSet(chainID, valSet) + } - vote1bz, err := vote1.ToProto().Marshal() - require.NoError(t, err) - vote2bz, err := vote2.ToProto().Marshal() - require.NoError(t, err) - vote3bz, err := vote3.ToProto().Marshal() - require.NoError(t, err) - assert.Equal(t, vote1bz, vote2bz) - assert.Equal(t, vote1bz, vote3bz) + for i := int32(0); int(i) < len(vals); i++ { + vote1 := voteSet.GetByIndex(i) + vote2 := voteSet2.GetByIndex(i) + vote3 := extCommit.GetExtendedVote(i) + + vote1bz, err := vote1.ToProto().Marshal() + require.NoError(t, err) + vote2bz, err := vote2.ToProto().Marshal() + require.NoError(t, err) + vote3bz, err := vote3.ToProto().Marshal() + require.NoError(t, err) + assert.Equal(t, vote1bz, vote2bz) + assert.Equal(t, vote1bz, vote3bz) + } + }) } } @@ -637,7 +742,7 @@ func TestCommitToVoteSetWithVotesForNilBlock(t *testing.T) { if tc.valid { extCommit := voteSet.MakeExtendedCommit() // panics without > 2/3 valid votes assert.NotNil(t, extCommit) - err := valSet.VerifyCommit(voteSet.ChainID(), blockID, height-1, extCommit.StripExtensions()) + err := valSet.VerifyCommit(voteSet.ChainID(), blockID, height-1, extCommit.ToCommit()) assert.NoError(t, err) } else { assert.Panics(t, func() { voteSet.MakeExtendedCommit() }) diff --git a/types/evidence_test.go b/types/evidence_test.go index 8b06a6218..d014a3ecc 100644 --- a/types/evidence_test.go +++ b/types/evidence_test.go @@ -155,7 +155,7 @@ func TestLightClientAttackEvidenceBasic(t *testing.T) { blockID := makeBlockID(crypto.Checksum([]byte("blockhash")), math.MaxInt32, crypto.Checksum([]byte("partshash"))) extCommit, err := makeExtCommit(ctx, blockID, height, 1, voteSet, privVals, defaultVoteTime) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() lcae := &LightClientAttackEvidence{ ConflictingBlock: &LightBlock{ @@ -221,7 +221,7 @@ func TestLightClientAttackEvidenceValidation(t *testing.T) { blockID := makeBlockID(header.Hash(), math.MaxInt32, crypto.Checksum([]byte("partshash"))) extCommit, err := makeExtCommit(ctx, blockID, height, 1, voteSet, privVals, time.Now()) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() lcae := &LightClientAttackEvidence{ ConflictingBlock: &LightBlock{ @@ -434,7 +434,7 @@ func TestEvidenceVectors(t *testing.T) { ConflictingBlock: &LightBlock{ SignedHeader: &SignedHeader{ Header: header, - Commit: extCommit.StripExtensions(), + Commit: extCommit.ToCommit(), }, ValidatorSet: valSet, }, diff --git a/types/params.go b/types/params.go index e8ee6fcdf..3b5e9a250 100644 --- a/types/params.go +++ b/types/params.go @@ -43,6 +43,7 @@ type ConsensusParams struct { Version VersionParams `json:"version"` Synchrony SynchronyParams `json:"synchrony"` Timeout TimeoutParams `json:"timeout"` + ABCI ABCIParams `json:"abci"` } // HashedParams is a subset of ConsensusParams. @@ -96,6 +97,21 @@ type TimeoutParams struct { BypassCommitTimeout bool `json:"bypass_commit_timeout"` } +// ABCIParams configure ABCI functionality specific to the Application Blockchain +// Interface. +type ABCIParams struct { + VoteExtensionsEnableHeight int64 `json:"vote_extensions_enable_height"` +} + +// VoteExtensionsEnabled returns true if vote extensions are enabled at height h +// and false otherwise. +func (a ABCIParams) VoteExtensionsEnabled(h int64) bool { + if a.VoteExtensionsEnableHeight == 0 { + return false + } + return a.VoteExtensionsEnableHeight <= h +} + // DefaultConsensusParams returns a default ConsensusParams. func DefaultConsensusParams() *ConsensusParams { return &ConsensusParams{ @@ -105,6 +121,7 @@ func DefaultConsensusParams() *ConsensusParams { Version: DefaultVersionParams(), Synchrony: DefaultSynchronyParams(), Timeout: DefaultTimeoutParams(), + ABCI: DefaultABCIParams(), } } @@ -176,6 +193,13 @@ func DefaultTimeoutParams() TimeoutParams { } } +func DefaultABCIParams() ABCIParams { + return ABCIParams{ + // When set to 0, vote extensions are not required. + VoteExtensionsEnableHeight: 0, + } +} + // TimeoutParamsOrDefaults returns the SynchronyParams, filling in any zero values // with the Tendermint defined default values. func (t TimeoutParams) TimeoutParamsOrDefaults() TimeoutParams { diff --git a/types/validation_test.go b/types/validation_test.go index f63c34450..c28f63000 100644 --- a/types/validation_test.go +++ b/types/validation_test.go @@ -153,7 +153,7 @@ func TestValidatorSet_VerifyCommit_CheckAllSignatures(t *testing.T) { voteSet, valSet, vals := randVoteSet(ctx, t, h, 0, tmproto.PrecommitType, 4, 10) extCommit, err := makeExtCommit(ctx, blockID, h, 0, voteSet, vals, time.Now()) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() require.NoError(t, valSet.VerifyCommit(chainID, blockID, h, commit)) @@ -184,7 +184,7 @@ func TestValidatorSet_VerifyCommitLight_ReturnsAsSoonAsMajorityOfVotingPowerSign voteSet, valSet, vals := randVoteSet(ctx, t, h, 0, tmproto.PrecommitType, 4, 10) extCommit, err := makeExtCommit(ctx, blockID, h, 0, voteSet, vals, time.Now()) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() require.NoError(t, valSet.VerifyCommit(chainID, blockID, h, commit)) @@ -212,7 +212,7 @@ func TestValidatorSet_VerifyCommitLightTrusting_ReturnsAsSoonAsTrustLevelOfVotin voteSet, valSet, vals := randVoteSet(ctx, t, h, 0, tmproto.PrecommitType, 4, 10) extCommit, err := makeExtCommit(ctx, blockID, h, 0, voteSet, vals, time.Now()) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() require.NoError(t, valSet.VerifyCommit(chainID, blockID, h, commit)) @@ -239,7 +239,7 @@ func TestValidatorSet_VerifyCommitLightTrusting(t *testing.T) { newValSet, _ = randValidatorPrivValSet(ctx, t, 2, 1) ) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() testCases := []struct { valSet *ValidatorSet @@ -284,7 +284,7 @@ func TestValidatorSet_VerifyCommitLightTrustingErrorsOnOverflow(t *testing.T) { ) require.NoError(t, err) - err = valSet.VerifyCommitLightTrusting("test_chain_id", extCommit.StripExtensions(), + err = valSet.VerifyCommitLightTrusting("test_chain_id", extCommit.ToCommit(), tmmath.Fraction{Numerator: 25, Denominator: 55}) if assert.Error(t, err) { assert.Contains(t, err.Error(), "int64 overflow") diff --git a/types/validator_set_test.go b/types/validator_set_test.go index 8b5846da9..81e81919d 100644 --- a/types/validator_set_test.go +++ b/types/validator_set_test.go @@ -1541,7 +1541,7 @@ func BenchmarkValidatorSet_VerifyCommit_Ed25519(b *testing.B) { // nolint // create a commit with n validators extCommit, err := makeExtCommit(ctx, blockID, h, 0, voteSet, vals, time.Now()) require.NoError(b, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() for i := 0; i < b.N/n; i++ { err = valSet.VerifyCommit(chainID, blockID, h, commit) @@ -1570,7 +1570,7 @@ func BenchmarkValidatorSet_VerifyCommitLight_Ed25519(b *testing.B) { // nolint // create a commit with n validators extCommit, err := makeExtCommit(ctx, blockID, h, 0, voteSet, vals, time.Now()) require.NoError(b, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() for i := 0; i < b.N/n; i++ { err = valSet.VerifyCommitLight(chainID, blockID, h, commit) @@ -1598,7 +1598,7 @@ func BenchmarkValidatorSet_VerifyCommitLightTrusting_Ed25519(b *testing.B) { // create a commit with n validators extCommit, err := makeExtCommit(ctx, blockID, h, 0, voteSet, vals, time.Now()) require.NoError(b, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() for i := 0; i < b.N/n; i++ { err = valSet.VerifyCommitLightTrusting(chainID, commit, tmmath.Fraction{Numerator: 1, Denominator: 3}) diff --git a/types/vote.go b/types/vote.go index 446de130a..f7006b8cd 100644 --- a/types/vote.go +++ b/types/vote.go @@ -27,7 +27,7 @@ var ( ErrVoteInvalidBlockHash = errors.New("invalid block hash") ErrVoteNonDeterministicSignature = errors.New("non-deterministic signature") ErrVoteNil = errors.New("nil vote") - ErrVoteInvalidExtension = errors.New("invalid vote extension") + ErrVoteExtensionAbsent = errors.New("vote extension absent") ) type ErrVoteConflictingVotes struct { @@ -112,6 +112,16 @@ func (vote *Vote) CommitSig() CommitSig { } } +// StripExtension removes any extension data from the vote. Useful if the +// chain has not enabled vote extensions. +// Returns true if extension data was present before stripping and false otherwise. +func (vote *Vote) StripExtension() bool { + stripped := len(vote.Extension) > 0 || len(vote.ExtensionSignature) > 0 + vote.Extension = nil + vote.ExtensionSignature = nil + return stripped +} + // ExtendedCommitSig attempts to construct an ExtendedCommitSig from this vote. // Panics if either the vote extension signature is missing or if the block ID // is not either empty or complete. @@ -120,13 +130,8 @@ func (vote *Vote) ExtendedCommitSig() ExtendedCommitSig { return NewExtendedCommitSigAbsent() } - cs := vote.CommitSig() - if vote.BlockID.IsComplete() && len(vote.ExtensionSignature) == 0 { - panic(fmt.Sprintf("Invalid vote %v - BlockID is complete but missing vote extension signature", vote)) - } - return ExtendedCommitSig{ - CommitSig: cs, + CommitSig: vote.CommitSig(), Extension: vote.Extension, ExtensionSignature: vote.ExtensionSignature, } @@ -230,11 +235,11 @@ func (vote *Vote) Verify(chainID string, pubKey crypto.PubKey) error { return err } -// VerifyWithExtension performs the same verification as Verify, but +// VerifyVoteAndExtension performs the same verification as Verify, but // additionally checks whether the vote extension signature corresponds to the // given chain ID and public key. We only verify vote extension signatures for // precommits. -func (vote *Vote) VerifyWithExtension(chainID string, pubKey crypto.PubKey) error { +func (vote *Vote) VerifyVoteAndExtension(chainID string, pubKey crypto.PubKey) error { v, err := vote.verifyAndReturnProto(chainID, pubKey) if err != nil { return err @@ -249,6 +254,20 @@ func (vote *Vote) VerifyWithExtension(chainID string, pubKey crypto.PubKey) erro return nil } +// VerifyExtension checks whether the vote extension signature corresponds to the +// given chain ID and public key. +func (vote *Vote) VerifyExtension(chainID string, pubKey crypto.PubKey) error { + if vote.Type != tmproto.PrecommitType || vote.BlockID.IsNil() { + return nil + } + v := vote.ToProto() + extSignBytes := VoteExtensionSignBytes(chainID, v) + if !pubKey.VerifySignature(extSignBytes, vote.ExtensionSignature) { + return ErrVoteInvalidSignature + } + return nil +} + // ValidateBasic checks whether the vote is well-formed. It does not, however, // check vote extensions - for vote validation with vote extension validation, // use ValidateWithExtension. @@ -306,30 +325,34 @@ func (vote *Vote) ValidateBasic() error { } } - return nil -} - -// ValidateWithExtension performs the same validations as ValidateBasic, but -// additionally checks whether a vote extension signature is present. This -// function is used in places where vote extension signatures are expected. -func (vote *Vote) ValidateWithExtension() error { - if err := vote.ValidateBasic(); err != nil { - return err - } - - // We should always see vote extension signatures in non-nil precommits if vote.Type == tmproto.PrecommitType && !vote.BlockID.IsNil() { - if len(vote.ExtensionSignature) == 0 { - return errors.New("vote extension signature is missing") - } if len(vote.ExtensionSignature) > MaxSignatureSize { return fmt.Errorf("vote extension signature is too big (max: %d)", MaxSignatureSize) } + if len(vote.ExtensionSignature) == 0 && len(vote.Extension) != 0 { + return fmt.Errorf("vote extension signature absent on vote with extension") + } } return nil } +// EnsureExtension checks for the presence of extensions signature data +// on precommit vote types. +func (vote *Vote) EnsureExtension() error { + // We should always see vote extension signatures in non-nil precommits + if vote.Type != tmproto.PrecommitType { + return nil + } + if vote.BlockID.IsNil() { + return nil + } + if len(vote.ExtensionSignature) > 0 { + return nil + } + return ErrVoteExtensionAbsent +} + // ToProto converts the handwritten type to proto generated type // return type, nil if everything converts safely, otherwise nil, error func (vote *Vote) ToProto() *tmproto.Vote { diff --git a/types/vote_set.go b/types/vote_set.go index 224d4e4f8..6d83ac85d 100644 --- a/types/vote_set.go +++ b/types/vote_set.go @@ -3,6 +3,7 @@ package types import ( "bytes" "encoding/json" + "errors" "fmt" "strings" "sync" @@ -53,11 +54,12 @@ const ( NOTE: Assumes that the sum total of voting power does not exceed MaxUInt64. */ type VoteSet struct { - chainID string - height int64 - round int32 - signedMsgType tmproto.SignedMsgType - valSet *ValidatorSet + chainID string + height int64 + round int32 + signedMsgType tmproto.SignedMsgType + valSet *ValidatorSet + extensionsEnabled bool mtx sync.Mutex votesBitArray *bits.BitArray @@ -68,7 +70,8 @@ type VoteSet struct { peerMaj23s map[string]BlockID // Maj23 for each peer } -// Constructs a new VoteSet struct used to accumulate votes for given height/round. +// NewVoteSet instantiates all fields of a new vote set. This constructor requires +// that no vote extension data be present on the votes that are added to the set. func NewVoteSet(chainID string, height int64, round int32, signedMsgType tmproto.SignedMsgType, valSet *ValidatorSet) *VoteSet { if height == 0 { @@ -89,6 +92,16 @@ func NewVoteSet(chainID string, height int64, round int32, } } +// NewExtendedVoteSet constructs a vote set with additional vote verification logic. +// The VoteSet constructed with NewExtendedVoteSet verifies the vote extension +// data for every vote added to the set. +func NewExtendedVoteSet(chainID string, height int64, round int32, + signedMsgType tmproto.SignedMsgType, valSet *ValidatorSet) *VoteSet { + vs := NewVoteSet(chainID, height, round, signedMsgType, valSet) + vs.extensionsEnabled = true + return vs +} + func (voteSet *VoteSet) ChainID() string { return voteSet.chainID } @@ -194,8 +207,17 @@ func (voteSet *VoteSet) addVote(vote *Vote) (added bool, err error) { } // Check signature. - if err := vote.VerifyWithExtension(voteSet.chainID, val.PubKey); err != nil { - return false, fmt.Errorf("failed to verify vote with ChainID %s and PubKey %s: %w", voteSet.chainID, val.PubKey, err) + if voteSet.extensionsEnabled { + if err := vote.VerifyVoteAndExtension(voteSet.chainID, val.PubKey); err != nil { + return false, fmt.Errorf("failed to verify vote with ChainID %s and PubKey %s: %w", voteSet.chainID, val.PubKey, err) + } + } else { + if err := vote.Verify(voteSet.chainID, val.PubKey); err != nil { + return false, fmt.Errorf("failed to verify vote with ChainID %s and PubKey %s: %w", voteSet.chainID, val.PubKey, err) + } + if len(vote.ExtensionSignature) > 0 || len(vote.Extension) > 0 { + return false, errors.New("unexpected vote extension data present in vote") + } } // Add vote and get conflicting vote if any. diff --git a/types/vote_set_test.go b/types/vote_set_test.go index 8d166d508..e35da7491 100644 --- a/types/vote_set_test.go +++ b/types/vote_set_test.go @@ -498,6 +498,92 @@ func TestVoteSet_MakeCommit(t *testing.T) { } } +// TestVoteSet_VoteExtensionsEnabled tests that the vote set correctly validates +// vote extensions data when either required or not required. +func TestVoteSet_VoteExtensionsEnabled(t *testing.T) { + for _, tc := range []struct { + name string + requireExtensions bool + addExtension bool + exepectError bool + }{ + { + name: "no extension but expected", + requireExtensions: true, + addExtension: false, + exepectError: true, + }, + { + name: "invalid extensions but not expected", + requireExtensions: true, + addExtension: false, + exepectError: true, + }, + { + name: "no extension and not expected", + requireExtensions: false, + addExtension: false, + exepectError: false, + }, + { + name: "extension and expected", + requireExtensions: true, + addExtension: true, + exepectError: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + height, round := int64(1), int32(0) + valSet, privValidators := randValidatorPrivValSet(ctx, t, 5, 10) + var voteSet *VoteSet + if tc.requireExtensions { + voteSet = NewExtendedVoteSet("test_chain_id", height, round, tmproto.PrecommitType, valSet) + } else { + voteSet = NewVoteSet("test_chain_id", height, round, tmproto.PrecommitType, valSet) + } + + val0 := privValidators[0] + + val0p, err := val0.GetPubKey(ctx) + require.NoError(t, err) + val0Addr := val0p.Address() + blockHash := crypto.CRandBytes(32) + blockPartsTotal := uint32(123) + blockPartSetHeader := PartSetHeader{blockPartsTotal, crypto.CRandBytes(32)} + + vote := &Vote{ + ValidatorAddress: val0Addr, + ValidatorIndex: 0, + Height: height, + Round: round, + Type: tmproto.PrecommitType, + Timestamp: tmtime.Now(), + BlockID: BlockID{blockHash, blockPartSetHeader}, + } + v := vote.ToProto() + err = val0.SignVote(ctx, voteSet.ChainID(), v) + require.NoError(t, err) + vote.Signature = v.Signature + + if tc.addExtension { + vote.ExtensionSignature = v.ExtensionSignature + } + + added, err := voteSet.AddVote(vote) + if tc.exepectError { + require.Error(t, err) + require.False(t, added) + } else { + require.NoError(t, err) + require.True(t, added) + } + }) + } +} + // NOTE: privValidators are in order func randVoteSet( ctx context.Context, @@ -510,7 +596,7 @@ func randVoteSet( ) (*VoteSet, *ValidatorSet, []PrivValidator) { t.Helper() valSet, privValidators := randValidatorPrivValSet(ctx, t, numValidators, votingPower) - return NewVoteSet("test_chain_id", height, round, signedMsgType, valSet), valSet, privValidators + return NewExtendedVoteSet("test_chain_id", height, round, signedMsgType, valSet), valSet, privValidators } func deterministicVoteSet( @@ -523,7 +609,7 @@ func deterministicVoteSet( ) (*VoteSet, *ValidatorSet, []PrivValidator) { t.Helper() valSet, privValidators := deterministicValidatorSet(ctx, t) - return NewVoteSet("test_chain_id", height, round, signedMsgType, valSet), valSet, privValidators + return NewExtendedVoteSet("test_chain_id", height, round, signedMsgType, valSet), valSet, privValidators } func randValidatorPrivValSet(ctx context.Context, t testing.TB, numValidators int, votingPower int64) (*ValidatorSet, []PrivValidator) { diff --git a/types/vote_test.go b/types/vote_test.go index 70cd91381..d0819d7c4 100644 --- a/types/vote_test.go +++ b/types/vote_test.go @@ -267,7 +267,7 @@ func TestVoteExtension(t *testing.T) { if tc.includeSignature { vote.ExtensionSignature = v.ExtensionSignature } - err = vote.VerifyWithExtension("test_chain_id", pk) + err = vote.VerifyExtension("test_chain_id", pk) if tc.expectError { require.Error(t, err) } else { @@ -361,7 +361,7 @@ func TestValidVotes(t *testing.T) { signVote(ctx, t, privVal, "test_chain_id", tc.vote) tc.malleateVote(tc.vote) require.NoError(t, tc.vote.ValidateBasic(), "ValidateBasic for %s", tc.name) - require.NoError(t, tc.vote.ValidateWithExtension(), "ValidateWithExtension for %s", tc.name) + require.NoError(t, tc.vote.EnsureExtension(), "EnsureExtension for %s", tc.name) } } @@ -387,13 +387,13 @@ func TestInvalidVotes(t *testing.T) { signVote(ctx, t, privVal, "test_chain_id", prevote) tc.malleateVote(prevote) require.Error(t, prevote.ValidateBasic(), "ValidateBasic for %s in invalid prevote", tc.name) - require.Error(t, prevote.ValidateWithExtension(), "ValidateWithExtension for %s in invalid prevote", tc.name) + require.NoError(t, prevote.EnsureExtension(), "EnsureExtension for %s in invalid prevote", tc.name) precommit := examplePrecommit(t) signVote(ctx, t, privVal, "test_chain_id", precommit) tc.malleateVote(precommit) require.Error(t, precommit.ValidateBasic(), "ValidateBasic for %s in invalid precommit", tc.name) - require.Error(t, precommit.ValidateWithExtension(), "ValidateWithExtension for %s in invalid precommit", tc.name) + require.NoError(t, precommit.EnsureExtension(), "EnsureExtension for %s in invalid precommit", tc.name) } } @@ -414,7 +414,7 @@ func TestInvalidPrevotes(t *testing.T) { signVote(ctx, t, privVal, "test_chain_id", prevote) tc.malleateVote(prevote) require.Error(t, prevote.ValidateBasic(), "ValidateBasic for %s", tc.name) - require.Error(t, prevote.ValidateWithExtension(), "ValidateWithExtension for %s", tc.name) + require.NoError(t, prevote.EnsureExtension(), "EnsureExtension for %s", tc.name) } } @@ -431,18 +431,44 @@ func TestInvalidPrecommitExtensions(t *testing.T) { v.Extension = []byte("extension") v.ExtensionSignature = nil }}, - // TODO(thane): Re-enable once https://github.com/tendermint/tendermint/issues/8272 is resolved - //{"missing vote extension signature", func(v *Vote) { v.ExtensionSignature = nil }}, {"oversized vote extension signature", func(v *Vote) { v.ExtensionSignature = make([]byte, MaxSignatureSize+1) }}, } for _, tc := range testCases { precommit := examplePrecommit(t) signVote(ctx, t, privVal, "test_chain_id", precommit) tc.malleateVote(precommit) - // We don't expect an error from ValidateBasic, because it doesn't - // handle vote extensions. - require.NoError(t, precommit.ValidateBasic(), "ValidateBasic for %s", tc.name) - require.Error(t, precommit.ValidateWithExtension(), "ValidateWithExtension for %s", tc.name) + // ValidateBasic ensures that vote extensions, if present, are well formed + require.Error(t, precommit.ValidateBasic(), "ValidateBasic for %s", tc.name) + } +} + +func TestEnsureVoteExtension(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + privVal := NewMockPV() + + testCases := []struct { + name string + malleateVote func(*Vote) + expectError bool + }{ + {"vote extension signature absent", func(v *Vote) { + v.Extension = nil + v.ExtensionSignature = nil + }, true}, + {"vote extension signature present", func(v *Vote) { + v.ExtensionSignature = []byte("extension signature") + }, false}, + } + for _, tc := range testCases { + precommit := examplePrecommit(t) + signVote(ctx, t, privVal, "test_chain_id", precommit) + tc.malleateVote(precommit) + if tc.expectError { + require.Error(t, precommit.EnsureExtension(), "EnsureExtension for %s", tc.name) + } else { + require.NoError(t, precommit.EnsureExtension(), "EnsureExtension for %s", tc.name) + } } } @@ -497,7 +523,7 @@ func getSampleCommit(ctx context.Context, t testing.TB) *Commit { require.NoError(t, err) - return commit.StripExtensions() + return commit.ToCommit() } func BenchmarkVoteSignBytes(b *testing.B) {