From 03bdcad31e9ef8ced3683d602116f2cf30f032d8 Mon Sep 17 00:00:00 2001 From: Sergio Mena Date: Fri, 9 Dec 2022 22:30:37 +0100 Subject: [PATCH] Vote extensions: Add consensus param for extension activation logic (#9862) * [cherry-picked] abci++: add consensus parameter logic to control vote extension require height (#8547) This PR makes vote extensions optional within Tendermint. A new ConsensusParams field, called ABCIParams.VoteExtensionsEnableHeight, has been added to toggle whether or not extensions should be enabled or disabled depending on the current height of the consensus engine. Related to: #8453 * Fix UTs * fix blocksync reactor import of state store * fixes1 * fixed_more_UTs * Fix TestHandshakeReplaySome * Fix all unit tests * Added hunk in original commit Co-authored-by: William Banfield <4561443+williambanfield@users.noreply.github.com> Co-authored-by: Callum Waters --- blocksync/pool.go | 6 +- blocksync/reactor.go | 120 +++++++---- blocksync/reactor_test.go | 4 +- consensus/byzantine_test.go | 2 +- consensus/common_test.go | 49 ++++- consensus/mempool_test.go | 10 +- consensus/reactor.go | 23 +- consensus/reactor_test.go | 104 ++++++++- consensus/replay_test.go | 8 +- consensus/state.go | 138 +++++++++--- consensus/state_test.go | 267 ++++++++++++++++++------ consensus/types/height_vote_set.go | 26 ++- consensus/types/height_vote_set_test.go | 2 +- evidence/pool_test.go | 7 +- evidence/verify_test.go | 24 +-- internal/test/params.go | 13 ++ state/execution.go | 22 +- state/execution_test.go | 154 ++++++++++++-- state/mocks/block_store.go | 7 +- state/rollback_test.go | 6 +- state/services.go | 3 +- state/store.go | 26 +++ state/validation_test.go | 6 +- store/store.go | 73 +++++-- store/store_test.go | 113 ++++++++-- test/e2e/runner/evidence.go | 6 +- test/e2e/tests/app_test.go | 1 + types/block.go | 109 ++++++++-- types/block_test.go | 154 +++++++++++--- types/evidence_test.go | 4 +- types/params.go | 24 +++ types/validation_test.go | 10 +- types/vote.go | 71 ++++--- types/vote_set.go | 38 +++- types/vote_set_test.go | 85 +++++++- types/vote_test.go | 51 +++-- 36 files changed, 1410 insertions(+), 356 deletions(-) create mode 100644 internal/test/params.go diff --git a/blocksync/pool.go b/blocksync/pool.go index b6a3f3db0..147ce0603 100644 --- a/blocksync/pool.go +++ b/blocksync/pool.go @@ -259,7 +259,7 @@ func (pool *BlockPool) AddBlock(peerID p2p.ID, block *types.Block, extCommit *ty pool.mtx.Lock() defer pool.mtx.Unlock() - if block.Height != extCommit.Height { + if extCommit != nil && block.Height != extCommit.Height { return fmt.Errorf("heights don't match, not adding block (block height: %d, commit height: %d)", block.Height, extCommit.Height) } @@ -566,7 +566,9 @@ func (bpr *bpRequester) setBlock(block *types.Block, extCommit *types.ExtendedCo return false } bpr.block = block - bpr.extCommit = extCommit + if extCommit != nil { + bpr.extCommit = extCommit + } bpr.mtx.Unlock() select { diff --git a/blocksync/reactor.go b/blocksync/reactor.go index 63e1fc56e..8574e464f 100644 --- a/blocksync/reactor.go +++ b/blocksync/reactor.go @@ -177,31 +177,40 @@ func (bcR *Reactor) respondToPeer(msg *bcproto.BlockRequest, src p2p.Peer) (queued bool) { block := bcR.store.LoadBlock(msg.Height) - if block != nil { - extCommit := bcR.store.LoadBlockExtendedCommit(msg.Height) - if extCommit == nil { - bcR.Logger.Error("found block in store without extended commit", "block", block) - return false - } - bl, err := block.ToProto() - if err != nil { - bcR.Logger.Error("could not convert msg to protobuf", "err", err) - return false - } - + if block == nil { + bcR.Logger.Info("Peer asking for a block we don't have", "src", src, "height", msg.Height) return src.TrySend(p2p.Envelope{ ChannelID: BlocksyncChannel, - Message: &bcproto.BlockResponse{ - Block: bl, - ExtCommit: extCommit.ToProto(), - }, + Message: &bcproto.NoBlockResponse{Height: msg.Height}, }) } - bcR.Logger.Info("Peer asking for a block we don't have", "src", src, "height", msg.Height) + state, err := bcR.blockExec.Store().Load() + if err != nil { + bcR.Logger.Error("loading state", "err", err) + return false + } + var extCommit *types.ExtendedCommit + if state.ConsensusParams.ABCI.VoteExtensionsEnabled(msg.Height) { + extCommit = bcR.store.LoadBlockExtendedCommit(msg.Height) + if extCommit == nil { + bcR.Logger.Error("found block in store with no extended commit", "block", block) + return false + } + } + + bl, err := block.ToProto() + if err != nil { + bcR.Logger.Error("could not convert msg to protobuf", "err", err) + return false + } + return src.TrySend(p2p.Envelope{ ChannelID: BlocksyncChannel, - Message: &bcproto.NoBlockResponse{Height: msg.Height}, + Message: &bcproto.BlockResponse{ + Block: bl, + ExtCommit: extCommit.ToProto(), + }, }) } @@ -224,12 +233,16 @@ func (bcR *Reactor) Receive(e p2p.Envelope) { bcR.Logger.Error("Block content is invalid", "err", err) return } - extCommit, err := types.ExtendedCommitFromProto(msg.ExtCommit) - if err != nil { - bcR.Logger.Error("failed to convert extended commit from proto", - "peer", e.Src, - "err", err) - return + var extCommit *types.ExtendedCommit + if msg.ExtCommit != nil { + var err error + extCommit, err = types.ExtendedCommitFromProto(msg.ExtCommit) + if err != nil { + bcR.Logger.Error("failed to convert extended commit from proto", + "peer", e.Src, + "err", err) + return + } } if err := bcR.pool.AddBlock(e.Src.ID(), bi, extCommit, msg.Block.Size()); err != nil { @@ -279,6 +292,8 @@ func (bcR *Reactor) poolRoutine(stateSynced bool) { didProcessCh := make(chan struct{}, 1) + initialCommitHasExtensions := (bcR.initialState.LastBlockHeight > 0 && bcR.store.LoadBlockExtendedCommit(bcR.initialState.LastBlockHeight) != nil) + go func() { for { select { @@ -321,11 +336,26 @@ FOR_LOOP: bcR.Logger.Debug("Consensus ticker", "numPending", numPending, "total", lenRequesters, "outbound", outbound, "inbound", inbound) - // TODO(sergio) Might be needed for implementing the upgrading solution. Remove after that - if state.LastBlockHeight > 0 && blocksSynced == 0 { - // Having state-synced, we need to blocksync at least one block + // The "if" statement below is a bit confusing, so here is a breakdown + // of its logic and purpose: + // + // If VoteExtensions are enabled we cannot switch to consensus without + // the vote extension data for the previous height, i.e. state.LastBlockHeight. + // + // If extensions were required during state.LastBlockHeight and we have + // sync'd at least one block, then we are guaranteed to have extensions. + // BlockSync requires that the blocks it fetches have extensions if + // extensions were enabled during the height. + // + // If extensions were required during state.LastBlockHeight and we have + // not sync'd any blocks, then we can only transition to Consensus + // if we already had extensions for the initial height. + // If any of these conditions is not met, we continue the loop, looking + // for extensions. + if state.ConsensusParams.ABCI.VoteExtensionsEnabled(state.LastBlockHeight) && + (blocksSynced == 0 && !initialCommitHasExtensions) { bcR.Logger.Info( - "no seen commit yet", + "no extended commit yet", "height", height, "last_block_height", state.LastBlockHeight, "initial_height", state.InitialHeight, @@ -366,19 +396,19 @@ FOR_LOOP: // See if there are any blocks to sync. first, second, extCommit := bcR.pool.PeekTwoBlocks() - // bcR.Logger.Info("TrySync peeked", "first", first, "second", second) - if first == nil || second == nil || extCommit == nil { - if first != nil && extCommit == nil { - // See https://github.com/tendermint/tendermint/pull/8433#discussion_r866790631 - panic(fmt.Errorf("peeked first block without extended commit at height %d - possible node store corruption", first.Height)) - } - // we need all to sync the first block + if first != nil && extCommit == nil && + state.ConsensusParams.ABCI.VoteExtensionsEnabled(first.Height) { + // See https://github.com/tendermint/tendermint/pull/8433#discussion_r866790631 + panic(fmt.Errorf("peeked first block without extended commit at height %d - possible node store corruption", first.Height)) + } else if first == nil || second == nil { + // we need to have fetched two consecutive blocks in order to + // perform blocksync verification continue FOR_LOOP - } else { - // Try again quickly next loop. - didProcessCh <- struct{}{} } + // Try again quickly next loop. + didProcessCh <- struct{}{} + firstParts, err := first.MakePartSet(types.BlockPartSizeBytes) if err != nil { bcR.Logger.Error("failed to make ", @@ -400,6 +430,10 @@ FOR_LOOP: // validate the block before we persist it err = bcR.blockExec.ValidateBlock(state, first) } + if err == nil && state.ConsensusParams.ABCI.VoteExtensionsEnabled(first.Height) { + // if vote extensions were required at this height, ensure they exist. + err = extCommit.EnsureExtensions() + } if err != nil { bcR.Logger.Error("Error in validation", "err", err) @@ -423,7 +457,15 @@ FOR_LOOP: bcR.pool.PopRequest() // TODO: batch saves so we dont persist to disk every block - bcR.store.SaveBlock(first, firstParts, extCommit) + if state.ConsensusParams.ABCI.VoteExtensionsEnabled(first.Height) { + bcR.store.SaveBlockWithExtendedCommit(first, firstParts, extCommit) + } else { + // We use LastCommit here instead of extCommit. extCommit is not + // guaranteed to be populated by the peer if extensions are not enabled. + // Currently, the peer should provide an extCommit even if the vote extension data are absent + // but this may change so using second.LastCommit is safer. + bcR.store.SaveBlock(first, firstParts, second.LastCommit) + } // TODO: same thing for app - but we would need a way to // get the hash without persisting the state diff --git a/blocksync/reactor_test.go b/blocksync/reactor_test.go index c40cd06a7..96feb8b9a 100644 --- a/blocksync/reactor_test.go +++ b/blocksync/reactor_test.go @@ -118,7 +118,7 @@ func newReactor( for blockHeight := int64(1); blockHeight <= maxBlockHeight; blockHeight++ { lastExtCommit = seenExtCommit.Clone() - thisBlock := state.MakeBlock(blockHeight, nil, lastExtCommit.StripExtensions(), nil, state.Validators.Proposer.Address) + thisBlock := state.MakeBlock(blockHeight, nil, lastExtCommit.ToCommit(), nil, state.Validators.Proposer.Address) thisParts, err := thisBlock.MakePartSet(types.BlockPartSizeBytes) require.NoError(t, err) @@ -148,7 +148,7 @@ func newReactor( panic(fmt.Errorf("error apply block: %w", err)) } - blockStore.SaveBlock(thisBlock, thisParts, seenExtCommit) + blockStore.SaveBlockWithExtendedCommit(thisBlock, thisParts, seenExtCommit) } bcReactor := NewReactor(state.Copy(), blockExec, blockStore, fastSync, NopMetrics()) diff --git a/consensus/byzantine_test.go b/consensus/byzantine_test.go index 52e97bdf7..7dd7e6479 100644 --- a/consensus/byzantine_test.go +++ b/consensus/byzantine_test.go @@ -46,7 +46,7 @@ func TestByzantinePrevoteEquivocation(t *testing.T) { tickerFunc := newMockTickerFunc(true) appFunc := newKVStore - genDoc, privVals := randGenesisDoc(nValidators, false, 30) + genDoc, privVals := randGenesisDoc(nValidators, false, 30, nil) css := make([]*State, nValidators) for i := 0; i < nValidators; i++ { diff --git a/consensus/common_test.go b/consensus/common_test.go index d020cf184..0b8074ed8 100644 --- a/consensus/common_test.go +++ b/consensus/common_test.go @@ -463,9 +463,27 @@ func randState(nValidators int) (*State, []*validatorStub) { return randStateWithApp(nValidators, kvstore.NewInMemoryApplication()) } +func randStateWithAppWithHeight( + nValidators int, + app abci.Application, + height int64, +) (*State, []*validatorStub) { + c := test.ConsensusParams() + c.ABCI.VoteExtensionsEnableHeight = height + return randStateWithAppImpl(nValidators, app, c) +} func randStateWithApp(nValidators int, app abci.Application) (*State, []*validatorStub) { + c := test.ConsensusParams() + return randStateWithAppImpl(nValidators, app, c) +} + +func randStateWithAppImpl( + nValidators int, + app abci.Application, + consensusParams *types.ConsensusParams, +) (*State, []*validatorStub) { // Get State - state, privVals := randGenesisState(nValidators, false, 10) + state, privVals := randGenesisState(nValidators, false, 10, consensusParams) vss := make([]*validatorStub, nValidators) @@ -751,7 +769,7 @@ func consensusLogger() log.Logger { func randConsensusNet(t *testing.T, nValidators int, testName string, tickerFunc func() TimeoutTicker, appFunc func() abci.Application, configOpts ...func(*cfg.Config)) ([]*State, cleanupFunc) { t.Helper() - genDoc, privVals := randGenesisDoc(nValidators, false, 30) + genDoc, privVals := randGenesisDoc(nValidators, false, 30, nil) css := make([]*State, nValidators) logger := consensusLogger() configRootDirs := make([]string, 0, nValidators) @@ -792,7 +810,8 @@ func randConsensusNetWithPeers( tickerFunc func() TimeoutTicker, appFunc func(string) abci.Application, ) ([]*State, *types.GenesisDoc, *cfg.Config, cleanupFunc) { - genDoc, privVals := randGenesisDoc(nValidators, false, testMinPower) + c := test.ConsensusParams() + genDoc, privVals := randGenesisDoc(nValidators, false, testMinPower, c) css := make([]*State, nPeers) logger := consensusLogger() var peer0Config *cfg.Config @@ -858,7 +877,11 @@ func getSwitchIndex(switches []*p2p.Switch, peer p2p.Peer) int { //------------------------------------------------------------------------------- // genesis -func randGenesisDoc(numValidators int, randPower bool, minPower int64) (*types.GenesisDoc, []types.PrivValidator) { +func randGenesisDoc(numValidators int, + randPower bool, + minPower int64, + consensusParams *types.ConsensusParams, +) (*types.GenesisDoc, []types.PrivValidator) { validators := make([]types.GenesisValidator, numValidators) privValidators := make([]types.PrivValidator, numValidators) for i := 0; i < numValidators; i++ { @@ -872,15 +895,21 @@ func randGenesisDoc(numValidators int, randPower bool, minPower int64) (*types.G sort.Sort(types.PrivValidatorsByAddress(privValidators)) return &types.GenesisDoc{ - GenesisTime: tmtime.Now(), - InitialHeight: 1, - ChainID: test.DefaultTestChainID, - Validators: validators, + GenesisTime: tmtime.Now(), + InitialHeight: 1, + ChainID: test.DefaultTestChainID, + Validators: validators, + ConsensusParams: consensusParams, }, privValidators } -func randGenesisState(numValidators int, randPower bool, minPower int64) (sm.State, []types.PrivValidator) { - genDoc, privValidators := randGenesisDoc(numValidators, randPower, minPower) +func randGenesisState( + numValidators int, + randPower bool, + minPower int64, + consensusParams *types.ConsensusParams, +) (sm.State, []types.PrivValidator) { + genDoc, privValidators := randGenesisDoc(numValidators, randPower, minPower, consensusParams) s0, _ := sm.MakeGenesisState(genDoc) return s0, privValidators } diff --git a/consensus/mempool_test.go b/consensus/mempool_test.go index a729e5b2d..c810a4527 100644 --- a/consensus/mempool_test.go +++ b/consensus/mempool_test.go @@ -29,7 +29,7 @@ func TestMempoolNoProgressUntilTxsAvailable(t *testing.T) { config := ResetConfig("consensus_mempool_txs_available_test") defer os.RemoveAll(config.RootDir) config.Consensus.CreateEmptyBlocks = false - state, privVals := randGenesisState(1, false, 10) + state, privVals := randGenesisState(1, false, 10, nil) app := kvstore.NewInMemoryApplication() resp, err := app.Info(context.Background(), proxy.RequestInfo) require.NoError(t, err) @@ -53,7 +53,7 @@ func TestMempoolProgressAfterCreateEmptyBlocksInterval(t *testing.T) { defer os.RemoveAll(config.RootDir) config.Consensus.CreateEmptyBlocksInterval = ensureTimeout - state, privVals := randGenesisState(1, false, 10) + state, privVals := randGenesisState(1, false, 10, nil) app := kvstore.NewInMemoryApplication() resp, err := app.Info(context.Background(), proxy.RequestInfo) require.NoError(t, err) @@ -74,7 +74,7 @@ func TestMempoolProgressInHigherRound(t *testing.T) { config := ResetConfig("consensus_mempool_txs_available_test") defer os.RemoveAll(config.RootDir) config.Consensus.CreateEmptyBlocks = false - state, privVals := randGenesisState(1, false, 10) + state, privVals := randGenesisState(1, false, 10, nil) cs := newStateWithConfig(config, state, privVals[0], kvstore.NewInMemoryApplication()) assertMempool(cs.txNotifier).EnableTxsAvailable() height, round := cs.Height, cs.Round @@ -116,7 +116,7 @@ func deliverTxsRange(t *testing.T, cs *State, start, end int) { } func TestMempoolTxConcurrentWithCommit(t *testing.T) { - state, privVals := randGenesisState(1, false, 10) + state, privVals := randGenesisState(1, false, 10, nil) blockDB := dbm.NewMemDB() stateStore := sm.NewStore(blockDB, sm.StoreOptions{DiscardABCIResponses: false}) cs := newStateWithConfigAndBlockStore(config, state, privVals[0], kvstore.NewInMemoryApplication(), blockDB) @@ -140,7 +140,7 @@ func TestMempoolTxConcurrentWithCommit(t *testing.T) { } func TestMempoolRmBadTx(t *testing.T) { - state, privVals := randGenesisState(1, false, 10) + state, privVals := randGenesisState(1, false, 10, nil) app := kvstore.NewInMemoryApplication() blockDB := dbm.NewMemDB() stateStore := sm.NewStore(blockDB, sm.StoreOptions{DiscardABCIResponses: false}) diff --git a/consensus/reactor.go b/consensus/reactor.go index a027bac50..dbcb0094c 100644 --- a/consensus/reactor.go +++ b/consensus/reactor.go @@ -737,11 +737,18 @@ OUTER_LOOP: if blockStoreBase > 0 && prs.Height != 0 && rs.Height >= prs.Height+2 && prs.Height >= blockStoreBase { // Load the block's extended commit for prs.Height, // which contains precommit signatures for prs.Height. - if ec := conR.conS.blockStore.LoadBlockExtendedCommit(prs.Height); ec != nil { - if ps.PickSendVote(ec) { - logger.Debug("Picked Catchup commit to send", "height", prs.Height) - continue OUTER_LOOP - } + var ec *types.ExtendedCommit + if conR.conS.state.ConsensusParams.ABCI.VoteExtensionsEnabled(prs.Height) { + ec = conR.conS.blockStore.LoadBlockExtendedCommit(prs.Height) + } else { + ec = conR.conS.blockStore.LoadBlockCommit(prs.Height).WrappedExtendedCommit() + } + if ec == nil { + continue + } + if ps.PickSendVote(ec) { + logger.Debug("Picked Catchup commit to send", "height", prs.Height) + continue OUTER_LOOP } } @@ -1685,11 +1692,7 @@ type VoteMessage struct { // ValidateBasic checks whether the vote within the message is well-formed. func (m *VoteMessage) ValidateBasic() error { - // Here we validate votes with vote extensions, since we require vote - // extensions to be sent in precommit messages during consensus. Prevote - // messages should never have vote extensions, and this is also validated - // here. - return m.Vote.ValidateWithExtension() + return m.Vote.ValidateBasic() } // String returns a string representation. diff --git a/consensus/reactor_test.go b/consensus/reactor_test.go index 1f52df754..5f1a346eb 100644 --- a/consensus/reactor_test.go +++ b/consensus/reactor_test.go @@ -133,7 +133,7 @@ func TestReactorWithEvidence(t *testing.T) { // to unroll unwieldy abstractions. Here we duplicate the code from: // css := randConsensusNet(N, "consensus_reactor_test", newMockTickerFunc(true), newKVStore) - genDoc, privVals := randGenesisDoc(nValidators, false, 30) + genDoc, privVals := randGenesisDoc(nValidators, false, 30, nil) css := make([]*State, nValidators) logger := consensusLogger() for i := 0; i < nValidators; i++ { @@ -309,6 +309,108 @@ func TestReactorReceivePanicsIfInitPeerHasntBeenCalledYet(t *testing.T) { }) } +// TestSwitchToConsensusVoteExtensions tests that the SwitchToConsensus correctly +// checks for vote extension data when required. +func TestSwitchToConsensusVoteExtensions(t *testing.T) { + for _, testCase := range []struct { + name string + storedHeight int64 + initialRequiredHeight int64 + includeExtensions bool + shouldPanic bool + }{ + { + name: "no vote extensions but not required", + initialRequiredHeight: 0, + storedHeight: 2, + includeExtensions: false, + shouldPanic: false, + }, + { + name: "no vote extensions but required this height", + initialRequiredHeight: 2, + storedHeight: 2, + includeExtensions: false, + shouldPanic: true, + }, + { + name: "no vote extensions and required in future", + initialRequiredHeight: 3, + storedHeight: 2, + includeExtensions: false, + shouldPanic: false, + }, + { + name: "no vote extensions and required previous height", + initialRequiredHeight: 1, + storedHeight: 2, + includeExtensions: false, + shouldPanic: true, + }, + { + name: "vote extensions and required previous height", + initialRequiredHeight: 1, + storedHeight: 2, + includeExtensions: true, + shouldPanic: false, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + cs, vs := randState(1) + validator := vs[0] + validator.Height = testCase.storedHeight + + cs.state.LastBlockHeight = testCase.storedHeight + cs.state.LastValidators = cs.state.Validators.Copy() + cs.state.ConsensusParams.ABCI.VoteExtensionsEnableHeight = testCase.initialRequiredHeight + + propBlock, err := cs.createProposalBlock() + require.NoError(t, err) + + // Consensus is preparing to do the next height after the stored height. + cs.Height = testCase.storedHeight + 1 + propBlock.Height = testCase.storedHeight + blockParts, err := propBlock.MakePartSet(types.BlockPartSizeBytes) + require.NoError(t, err) + + var voteSet *types.VoteSet + if testCase.includeExtensions { + voteSet = types.NewExtendedVoteSet(cs.state.ChainID, testCase.storedHeight, 0, tmproto.PrecommitType, cs.state.Validators) + } else { + voteSet = types.NewVoteSet(cs.state.ChainID, testCase.storedHeight, 0, tmproto.PrecommitType, cs.state.Validators) + } + signedVote := signVote(validator, tmproto.PrecommitType, propBlock.Hash(), blockParts.Header()) + + if !testCase.includeExtensions { + signedVote.Extension = nil + signedVote.ExtensionSignature = nil + } + + added, err := voteSet.AddVote(signedVote) + require.NoError(t, err) + require.True(t, added) + + if testCase.includeExtensions { + cs.blockStore.SaveBlockWithExtendedCommit(propBlock, blockParts, voteSet.MakeExtendedCommit()) + } else { + cs.blockStore.SaveBlock(propBlock, blockParts, voteSet.MakeExtendedCommit().ToCommit()) + } + reactor := NewReactor( + cs, + true, + ) + + if testCase.shouldPanic { + assert.Panics(t, func() { + reactor.SwitchToConsensus(cs.state, false) + }) + } else { + reactor.SwitchToConsensus(cs.state, false) + } + }) + } +} + // Test we record stats about votes and block parts from other peers. func TestReactorRecordsVotesAndBlockParts(t *testing.T) { N := 4 diff --git a/consensus/replay_test.go b/consensus/replay_test.go index a51e39f67..19379da52 100644 --- a/consensus/replay_test.go +++ b/consensus/replay_test.go @@ -1165,14 +1165,16 @@ func (bs *mockBlockStore) LoadBlockMeta(height int64) *types.BlockMeta { } } func (bs *mockBlockStore) LoadBlockPart(height int64, index int) *types.Part { return nil } -func (bs *mockBlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) { +func (bs *mockBlockStore) SaveBlockWithExtendedCommit(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) { +} +func (bs *mockBlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.Commit) { } func (bs *mockBlockStore) LoadBlockCommit(height int64) *types.Commit { - return bs.extCommits[height-1].StripExtensions() + return bs.extCommits[height-1].ToCommit() } func (bs *mockBlockStore) LoadSeenCommit(height int64) *types.Commit { - return bs.extCommits[height-1].StripExtensions() + return bs.extCommits[height-1].ToCommit() } func (bs *mockBlockStore) LoadBlockExtendedCommit(height int64) *types.ExtendedCommit { return bs.extCommits[height-1] diff --git a/consensus/state.go b/consensus/state.go index 8526c87bf..61c0a3a14 100644 --- a/consensus/state.go +++ b/consensus/state.go @@ -561,23 +561,54 @@ func (cs *State) sendInternalMessage(mi msgInfo) { } } -// Reconstruct LastCommit from SeenCommit, which we saved along with the block, -// (which happens even before saving the state) +// Reconstruct the LastCommit from either SeenCommit or the ExtendedCommit. SeenCommit +// and ExtendedCommit are saved along with the block. If VoteExtensions are required +// the method will panic on an absent ExtendedCommit or an ExtendedCommit without +// extension data. func (cs *State) reconstructLastCommit(state sm.State) { - extCommit := cs.blockStore.LoadBlockExtendedCommit(state.LastBlockHeight) - if extCommit == nil { - panic(fmt.Sprintf( - "failed to reconstruct last commit; seen commit for height %v not found", - state.LastBlockHeight, - )) + extensionsEnabled := cs.state.ConsensusParams.ABCI.VoteExtensionsEnabled(state.LastBlockHeight) + if !extensionsEnabled { + votes, err := cs.votesFromSeenCommit(state) + if err != nil { + panic(fmt.Sprintf("failed to reconstruct last commit; %s", err)) + } + cs.LastCommit = votes + return } - lastPrecommits := extCommit.ToVoteSet(state.ChainID, state.LastValidators) - if !lastPrecommits.HasTwoThirdsMajority() { - panic("failed to reconstruct last commit; does not have +2/3 maj") + votes, err := cs.votesFromExtendedCommit(state) + if err != nil { + panic(fmt.Sprintf("failed to reconstruct last extended commit; %s", err)) + } + cs.LastCommit = votes +} + +func (cs *State) votesFromExtendedCommit(state sm.State) (*types.VoteSet, error) { + ec := cs.blockStore.LoadBlockExtendedCommit(state.LastBlockHeight) + if ec == nil { + return nil, fmt.Errorf("extended commit for height %v not found", state.LastBlockHeight) + } + vs := ec.ToExtendedVoteSet(state.ChainID, state.LastValidators) + if !vs.HasTwoThirdsMajority() { + return nil, errors.New("extended commit does not have +2/3 majority") + } + return vs, nil +} + +func (cs *State) votesFromSeenCommit(state sm.State) (*types.VoteSet, error) { + commit := cs.blockStore.LoadSeenCommit(state.LastBlockHeight) + if commit == nil { + commit = cs.blockStore.LoadBlockCommit(state.LastBlockHeight) + } + if commit == nil { + return nil, fmt.Errorf("commit for height %v not found", state.LastBlockHeight) } - cs.LastCommit = lastPrecommits + vs := commit.ToVoteSet(state.ChainID, state.LastValidators) + if !vs.HasTwoThirdsMajority() { + return nil, errors.New("commit does not have +2/3 majority") + } + return vs, nil } // Updates State and increments height to match that of state. @@ -678,7 +709,11 @@ func (cs *State) updateToState(state sm.State) { cs.ValidRound = -1 cs.ValidBlock = nil cs.ValidBlockParts = nil - cs.Votes = cstypes.NewHeightVoteSet(state.ChainID, height, validators) + if state.ConsensusParams.ABCI.VoteExtensionsEnabled(height) { + cs.Votes = cstypes.NewExtendedHeightVoteSet(state.ChainID, height, validators) + } else { + cs.Votes = cstypes.NewHeightVoteSet(state.ChainID, height, validators) + } cs.CommitRound = -1 cs.LastValidators = state.LastValidators cs.TriggeredTimeoutPrecommit = false @@ -1656,8 +1691,12 @@ func (cs *State) finalizeCommit(height int64) { if cs.blockStore.Height() < block.Height { // NOTE: the seenCommit is local justification to commit this block, // but may differ from the LastCommit included in the next block - precommits := cs.Votes.Precommits(cs.CommitRound) - cs.blockStore.SaveBlock(block, blockParts, precommits.MakeExtendedCommit()) + seenExtendedCommit := cs.Votes.Precommits(cs.CommitRound).MakeExtendedCommit() + if cs.state.ConsensusParams.ABCI.VoteExtensionsEnabled(block.Height) { + cs.blockStore.SaveBlockWithExtendedCommit(block, blockParts, seenExtendedCommit) + } else { + cs.blockStore.SaveBlock(block, blockParts, seenExtendedCommit.ToCommit()) + } } else { // Happens during replay if we already saved the block but didn't commit logger.Debug("calling finalizeCommit on already stored block", "height", block.Height) @@ -2068,11 +2107,43 @@ func (cs *State) addVote(vote *types.Vote, peerID p2p.ID) (added bool, err error return } - // Verify VoteExtension if precommit and not nil - // https://github.com/tendermint/tendermint/issues/8487 - if vote.Type == tmproto.PrecommitType && len(vote.BlockID.Hash) != 0 { - if err = cs.blockExec.VerifyVoteExtension(vote); err != nil { - return false, err + // Check to see if the chain is configured to extend votes. + if cs.state.ConsensusParams.ABCI.VoteExtensionsEnabled(cs.Height) { + // The chain is configured to extend votes, check that the vote is + // not for a nil block and verify the extensions signature against the + // corresponding public key. + + var myAddr []byte + if cs.privValidatorPubKey != nil { + myAddr = cs.privValidatorPubKey.Address() + } + // Verify VoteExtension if precommit and not nil + // https://github.com/tendermint/tendermint/issues/8487 + if vote.Type == tmproto.PrecommitType && len(vote.BlockID.Hash) != 0 && + !bytes.Equal(vote.ValidatorAddress, myAddr) { // Skip the VerifyVoteExtension call if the vote was issued by this validator. + + // The core fields of the vote message were already validated in the + // consensus reactor when the vote was received. + // Here, we verify the signature of the vote extension included in the vote + // message. + _, val := cs.state.Validators.GetByIndex(vote.ValidatorIndex) + if err := vote.VerifyExtension(cs.state.ChainID, val.PubKey); err != nil { + return false, err + } + + if err = cs.blockExec.VerifyVoteExtension(vote); err != nil { + return false, err + } + } + } else { + // Vote extensions are not enabled on the network. + // strip the extension data from the vote in case any is present. + // + // TODO punish a peer if it sent a vote with an extension when the feature + // is disabled on the network. + // https://github.com/tendermint/tendermint/issues/8565 + if stripped := vote.StripExtension(); stripped { + cs.Logger.Error("vote included extension data but vote extensions are not enabled", "peer", peerID) } } @@ -2240,11 +2311,13 @@ func (cs *State) signVote( if msgType == tmproto.PrecommitType && len(vote.BlockID.Hash) != 0 { // if the signedMessage type is for a non-nil precommit, add // VoteExtension - ext, err := cs.blockExec.ExtendVote(vote) - if err != nil { - return nil, err + if cs.state.ConsensusParams.ABCI.VoteExtensionsEnabled(cs.Height) { + ext, err := cs.blockExec.ExtendVote(vote) + if err != nil { + return nil, err + } + vote.Extension = ext } - vote.Extension = ext } v := vote.ToProto() err := cs.privValidator.SignVote(cs.state.ChainID, v) @@ -2299,14 +2372,17 @@ func (cs *State) signAddVote( // TODO: pass pubKey to signVote vote, err := cs.signVote(msgType, hash, header) - if err == nil { - cs.sendInternalMessage(msgInfo{&VoteMessage{vote}, ""}) - cs.Logger.Debug("signed and pushed vote", "height", cs.Height, "round", cs.Round, "vote", vote) - return vote + if err != nil { + cs.Logger.Error("failed signing vote", "height", cs.Height, "round", cs.Round, "vote", vote, "err", err) + return nil } - - cs.Logger.Error("failed signing vote", "height", cs.Height, "round", cs.Round, "vote", vote, "err", err) - return nil + if !cs.state.ConsensusParams.ABCI.VoteExtensionsEnabled(vote.Height) { + // The signer will sign the extension, make sure to remove the data on the way out + vote.StripExtension() + } + cs.sendInternalMessage(msgInfo{&VoteMessage{vote}, ""}) + cs.Logger.Debug("signed and pushed vote", "height", cs.Height, "round", cs.Round, "vote", vote) + return vote } // updatePrivValidatorPubKey get's the private validator public key and diff --git a/consensus/state_test.go b/consensus/state_test.go index 29fad863f..af473bd1c 100644 --- a/consensus/state_test.go +++ b/consensus/state_test.go @@ -1418,74 +1418,98 @@ func TestProcessProposalAccept(t *testing.T) { } } -// TestExtendVoteCalled tests that the vote extension methods are called at the -// correct point in the consensus algorithm. -func TestExtendVoteCalled(t *testing.T) { - m := abcimocks.NewApplication(t) - m.On("PrepareProposal", mock.Anything, mock.Anything).Return(&abci.ResponsePrepareProposal{}, nil) - m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_ACCEPT}, nil) - m.On("ExtendVote", mock.Anything, mock.Anything).Return(&abci.ResponseExtendVote{ - VoteExtension: []byte("extension"), - }, nil) - m.On("VerifyVoteExtension", mock.Anything, mock.Anything).Return(&abci.ResponseVerifyVoteExtension{ - Status: abci.ResponseVerifyVoteExtension_ACCEPT, - }, nil) - m.On("Commit", mock.Anything, mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() - m.On("FinalizeBlock", mock.Anything, mock.Anything).Return(&abci.ResponseFinalizeBlock{}, nil).Maybe() - cs1, vss := randStateWithApp(4, m) - height, round := cs1.Height, cs1.Round +// TestExtendVoteCalledWhenEnabled tests that the vote extension methods are called at the +// correct point in the consensus algorithm when vote extensions are enabled. +func TestExtendVoteCalledWhenEnabled(t *testing.T) { + for _, testCase := range []struct { + name string + enabled bool + }{ + { + name: "enabled", + enabled: true, + }, + { + name: "disabled", + enabled: false, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + m := abcimocks.NewApplication(t) + m.On("PrepareProposal", mock.Anything, mock.Anything).Return(&abci.ResponsePrepareProposal{}, nil) + m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{Status: abci.ResponseProcessProposal_ACCEPT}, nil) + if testCase.enabled { + m.On("ExtendVote", mock.Anything, mock.Anything).Return(&abci.ResponseExtendVote{ + VoteExtension: []byte("extension"), + }, nil) + m.On("VerifyVoteExtension", mock.Anything, mock.Anything).Return(&abci.ResponseVerifyVoteExtension{ + Status: abci.ResponseVerifyVoteExtension_ACCEPT, + }, nil) + } + m.On("Commit", mock.Anything, mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() + m.On("FinalizeBlock", mock.Anything, mock.Anything).Return(&abci.ResponseFinalizeBlock{}, nil).Maybe() + height := int64(1) + if !testCase.enabled { + height = 0 + } + cs1, vss := randStateWithAppWithHeight(4, m, height) - proposalCh := subscribe(cs1.eventBus, types.EventQueryCompleteProposal) - newRoundCh := subscribe(cs1.eventBus, types.EventQueryNewRound) - pv1, err := cs1.privValidator.GetPubKey() - require.NoError(t, err) - addr := pv1.Address() - voteCh := subscribeToVoter(cs1, addr) + height, round := cs1.Height, cs1.Round - startTestRound(cs1, cs1.Height, round) - ensureNewRound(newRoundCh, height, round) - ensureNewProposal(proposalCh, height, round) + proposalCh := subscribe(cs1.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey() + require.NoError(t, err) + addr := pv1.Address() + voteCh := subscribeToVoter(cs1, addr) - m.AssertNotCalled(t, "ExtendVote", mock.Anything) + startTestRound(cs1, cs1.Height, round) + ensureNewRound(newRoundCh, height, round) + ensureNewProposal(proposalCh, height, round) - rs := cs1.GetRoundState() + m.AssertNotCalled(t, "ExtendVote", mock.Anything) - blockID := types.BlockID{ - Hash: rs.ProposalBlock.Hash(), - PartSetHeader: rs.ProposalBlockParts.Header(), - } - signAddVotes(cs1, tmproto.PrevoteType, blockID.Hash, blockID.PartSetHeader, vss[1:]...) - ensurePrevoteMatch(t, voteCh, height, round, blockID.Hash) + rs := cs1.GetRoundState() - ensurePrecommit(voteCh, height, round) + blockID := types.BlockID{ + Hash: rs.ProposalBlock.Hash(), + PartSetHeader: rs.ProposalBlockParts.Header(), + } + signAddVotes(cs1, tmproto.PrevoteType, blockID.Hash, blockID.PartSetHeader, vss[1:]...) + ensurePrevoteMatch(t, voteCh, height, round, blockID.Hash) - m.AssertCalled(t, "ExtendVote", context.TODO(), &abci.RequestExtendVote{ - Height: height, - Hash: blockID.Hash, - }) + ensurePrecommit(voteCh, height, round) - m.AssertCalled(t, "VerifyVoteExtension", context.TODO(), &abci.RequestVerifyVoteExtension{ - Hash: blockID.Hash, - ValidatorAddress: addr, - Height: height, - VoteExtension: []byte("extension"), - }) + if testCase.enabled { + m.AssertCalled(t, "ExtendVote", context.TODO(), &abci.RequestExtendVote{ + Height: height, + Hash: blockID.Hash, + }) + } else { + m.AssertNotCalled(t, "ExtendVote", mock.Anything, mock.Anything) + } - signAddVotes(cs1, tmproto.PrecommitType, blockID.Hash, blockID.PartSetHeader, vss[1:]...) - ensureNewRound(newRoundCh, height+1, 0) - m.AssertExpectations(t) + signAddVotes(cs1, tmproto.PrecommitType, blockID.Hash, blockID.PartSetHeader, vss[1:]...) + ensureNewRound(newRoundCh, height+1, 0) + m.AssertExpectations(t) - // Only 3 of the vote extensions are seen, as consensus proceeds as soon as the +2/3 threshold - // is observed by the consensus engine. - for _, pv := range vss[:3] { - pv, err := pv.GetPubKey() - require.NoError(t, err) - addr := pv.Address() - m.AssertCalled(t, "VerifyVoteExtension", context.TODO(), &abci.RequestVerifyVoteExtension{ - Hash: blockID.Hash, - ValidatorAddress: addr, - Height: height, - VoteExtension: []byte("extension"), + // Only 3 of the vote extensions are seen, as consensus proceeds as soon as the +2/3 threshold + // is observed by the consensus engine. + for _, pv := range vss[1:3] { + pv, err := pv.GetPubKey() + require.NoError(t, err) + addr := pv.Address() + if testCase.enabled { + m.AssertCalled(t, "VerifyVoteExtension", context.TODO(), &abci.RequestVerifyVoteExtension{ + Hash: blockID.Hash, + ValidatorAddress: addr, + Height: height, + VoteExtension: []byte("extension"), + }) + } else { + m.AssertNotCalled(t, "VerifyVoteExtension", mock.Anything, mock.Anything) + } + } }) } @@ -1507,6 +1531,7 @@ func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { m.On("Commit", mock.Anything, mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() cs1, vss := randStateWithApp(4, m) height, round := cs1.Height, cs1.Round + cs1.state.ConsensusParams.ABCI.VoteExtensionsEnableHeight = cs1.Height proposalCh := subscribe(cs1.eventBus, types.EventQueryCompleteProposal) newRoundCh := subscribe(cs1.eventBus, types.EventQueryNewRound) @@ -1524,7 +1549,7 @@ func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { Hash: rs.ProposalBlock.Hash(), PartSetHeader: rs.ProposalBlockParts.Header(), } - signAddVotes(cs1, tmproto.PrevoteType, blockID.Hash, blockID.PartSetHeader, vss[2:]...) + signAddVotes(cs1, tmproto.PrevoteType, blockID.Hash, blockID.PartSetHeader, vss...) ensurePrevoteMatch(t, voteCh, height, round, blockID.Hash) ensurePrecommit(voteCh, height, round) @@ -1534,13 +1559,6 @@ func TestVerifyVoteExtensionNotCalledOnAbsentPrecommit(t *testing.T) { Hash: blockID.Hash, }) - m.AssertCalled(t, "VerifyVoteExtension", context.TODO(), &abci.RequestVerifyVoteExtension{ - Hash: blockID.Hash, - ValidatorAddress: addr, - Height: height, - VoteExtension: []byte("extension"), - }) - signAddVotes(cs1, tmproto.PrecommitType, blockID.Hash, blockID.PartSetHeader, vss[2:]...) ensureNewRound(newRoundCh, height+1, 0) m.AssertExpectations(t) @@ -1725,6 +1743,123 @@ func TestFinalizeBlockCalled(t *testing.T) { } } +// TestVoteExtensionEnableHeight tests that 'ExtensionRequireHeight' correctly +// enforces that vote extensions be present in consensus for heights greater than +// or equal to the configured value. +func TestVoteExtensionEnableHeight(t *testing.T) { + for _, testCase := range []struct { + name string + enableHeight int64 + hasExtension bool + expectExtendCalled bool + expectVerifyCalled bool + expectSuccessfulRound bool + }{ + { + name: "extension present but not enabled", + hasExtension: true, + enableHeight: 0, + expectExtendCalled: false, + expectVerifyCalled: false, + expectSuccessfulRound: true, + }, + { + name: "extension absent but not required", + hasExtension: false, + enableHeight: 0, + expectExtendCalled: false, + expectVerifyCalled: false, + expectSuccessfulRound: true, + }, + { + name: "extension present and required", + hasExtension: true, + enableHeight: 1, + expectExtendCalled: true, + expectVerifyCalled: true, + expectSuccessfulRound: true, + }, + { + name: "extension absent but required", + hasExtension: false, + enableHeight: 1, + expectExtendCalled: true, + expectVerifyCalled: false, + expectSuccessfulRound: false, + }, + { + name: "extension absent but required in future height", + hasExtension: false, + enableHeight: 2, + expectExtendCalled: false, + expectVerifyCalled: false, + expectSuccessfulRound: true, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + numValidators := 3 + m := abcimocks.NewApplication(t) + m.On("ProcessProposal", mock.Anything, mock.Anything).Return(&abci.ResponseProcessProposal{ + Status: abci.ResponseProcessProposal_ACCEPT, + }, nil) + m.On("PrepareProposal", mock.Anything, mock.Anything).Return(&abci.ResponsePrepareProposal{}, nil) + if testCase.expectExtendCalled { + m.On("ExtendVote", mock.Anything, mock.Anything).Return(&abci.ResponseExtendVote{}, nil) + } + if testCase.expectVerifyCalled { + m.On("VerifyVoteExtension", mock.Anything, mock.Anything).Return(&abci.ResponseVerifyVoteExtension{ + Status: abci.ResponseVerifyVoteExtension_ACCEPT, + }, nil).Times(numValidators - 1) + } + m.On("FinalizeBlock", mock.Anything, mock.Anything).Return(&abci.ResponseFinalizeBlock{}, nil).Maybe() + m.On("Commit", mock.Anything, mock.Anything).Return(&abci.ResponseCommit{}, nil).Maybe() + cs1, vss := randStateWithAppWithHeight(numValidators, m, testCase.enableHeight) + cs1.state.ConsensusParams.ABCI.VoteExtensionsEnableHeight = testCase.enableHeight + height, round := cs1.Height, cs1.Round + + timeoutCh := subscribe(cs1.eventBus, types.EventQueryTimeoutPropose) + proposalCh := subscribe(cs1.eventBus, types.EventQueryCompleteProposal) + newRoundCh := subscribe(cs1.eventBus, types.EventQueryNewRound) + pv1, err := cs1.privValidator.GetPubKey() + require.NoError(t, err) + addr := pv1.Address() + voteCh := subscribeToVoter(cs1, addr) + + startTestRound(cs1, cs1.Height, round) + ensureNewRound(newRoundCh, height, round) + ensureNewProposal(proposalCh, height, round) + rs := cs1.GetRoundState() + + // sign all of the votes + signAddVotes(cs1, tmproto.PrevoteType, rs.ProposalBlock.Hash(), rs.ProposalBlockParts.Header(), vss[1:]...) + ensurePrevoteMatch(t, voteCh, height, round, rs.ProposalBlock.Hash()) + + var ext []byte + if testCase.hasExtension { + ext = []byte("extension") + } + + for _, vs := range vss[1:] { + vote, err := vs.signVote(tmproto.PrecommitType, rs.ProposalBlock.Hash(), rs.ProposalBlockParts.Header(), ext) + if !testCase.hasExtension { + vote.ExtensionSignature = nil + } + require.NoError(t, err) + addVotes(cs1, vote) + } + if testCase.expectSuccessfulRound { + ensurePrecommit(voteCh, height, round) + height++ + ensureNewRound(newRoundCh, height, round) + } else { + ensureNoNewTimeout(timeoutCh, cs1.config.Precommit(round).Nanoseconds()) + } + + m.AssertExpectations(t) + }) + } +} + // 4 vals, 3 Nil Precommits at P0 // What we want: // P0 waits for timeoutPrecommit before starting next round diff --git a/consensus/types/height_vote_set.go b/consensus/types/height_vote_set.go index 6a5c0b495..4cdd0234d 100644 --- a/consensus/types/height_vote_set.go +++ b/consensus/types/height_vote_set.go @@ -39,9 +39,10 @@ We let each peer provide us with up to 2 unexpected "catchup" rounds. One for their LastCommit round, and another for the official commit round. */ type HeightVoteSet struct { - chainID string - height int64 - valSet *types.ValidatorSet + chainID string + height int64 + valSet *types.ValidatorSet + extensionsEnabled bool mtx sync.Mutex round int32 // max tracked round @@ -51,7 +52,17 @@ type HeightVoteSet struct { func NewHeightVoteSet(chainID string, height int64, valSet *types.ValidatorSet) *HeightVoteSet { hvs := &HeightVoteSet{ - chainID: chainID, + chainID: chainID, + extensionsEnabled: false, + } + hvs.Reset(height, valSet) + return hvs +} + +func NewExtendedHeightVoteSet(chainID string, height int64, valSet *types.ValidatorSet) *HeightVoteSet { + hvs := &HeightVoteSet{ + chainID: chainID, + extensionsEnabled: true, } hvs.Reset(height, valSet) return hvs @@ -105,7 +116,12 @@ func (hvs *HeightVoteSet) addRound(round int32) { } // log.Debug("addRound(round)", "round", round) prevotes := types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrevoteType, hvs.valSet) - precommits := types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrecommitType, hvs.valSet) + var precommits *types.VoteSet + if hvs.extensionsEnabled { + precommits = types.NewExtendedVoteSet(hvs.chainID, hvs.height, round, tmproto.PrecommitType, hvs.valSet) + } else { + precommits = types.NewVoteSet(hvs.chainID, hvs.height, round, tmproto.PrecommitType, hvs.valSet) + } hvs.roundVoteSets[round] = RoundVoteSet{ Prevotes: prevotes, Precommits: precommits, diff --git a/consensus/types/height_vote_set_test.go b/consensus/types/height_vote_set_test.go index 6e07c4c27..b5db6ce19 100644 --- a/consensus/types/height_vote_set_test.go +++ b/consensus/types/height_vote_set_test.go @@ -26,7 +26,7 @@ func TestMain(m *testing.M) { func TestPeerCatchupRounds(t *testing.T) { valSet, privVals := types.RandValidatorSet(10, 1) - hvs := NewHeightVoteSet(test.DefaultTestChainID, 1, valSet) + hvs := NewExtendedHeightVoteSet(test.DefaultTestChainID, 1, valSet) vote999_0 := makeVoteHR(t, 1, 0, 999, privVals) added, err := hvs.AddVote(vote999_0, "peer1") diff --git a/evidence/pool_test.go b/evidence/pool_test.go index 514394d6e..8b892eb8c 100644 --- a/evidence/pool_test.go +++ b/evidence/pool_test.go @@ -197,7 +197,7 @@ func TestEvidencePoolUpdate(t *testing.T) { val, evidenceChainID) require.NoError(t, err) lastExtCommit := makeExtCommit(height, val.PrivKey.PubKey().Address()) - block := types.MakeBlock(height+1, []types.Tx{}, lastExtCommit.StripExtensions(), []types.Evidence{ev}) + block := types.MakeBlock(height+1, []types.Tx{}, lastExtCommit.ToCommit(), []types.Evidence{ev}) // update state (partially) state.LastBlockHeight = height + 1 state.LastBlockTime = defaultEvidenceTime.Add(22 * time.Minute) @@ -415,7 +415,7 @@ func initializeBlockStore(db dbm.DB, state sm.State, valAddr []byte) (*store.Blo for i := int64(1); i <= state.LastBlockHeight; i++ { lastCommit := makeExtCommit(i-1, valAddr) - block := state.MakeBlock(i, test.MakeNTxs(i, 1), lastCommit.StripExtensions(), nil, state.Validators.Proposer.Address) + block := state.MakeBlock(i, test.MakeNTxs(i, 1), lastCommit.ToCommit(), nil, state.Validators.Proposer.Address) block.Header.Time = defaultEvidenceTime.Add(time.Duration(i) * time.Minute) block.Header.Version = tmversion.Consensus{Block: version.BlockProtocol, App: 1} const parts = 1 @@ -425,7 +425,7 @@ func initializeBlockStore(db dbm.DB, state sm.State, valAddr []byte) (*store.Blo } seenCommit := makeExtCommit(i, valAddr) - blockStore.SaveBlock(block, partSet, seenCommit) + blockStore.SaveBlockWithExtendedCommit(block, partSet, seenCommit) } return blockStore, nil @@ -441,6 +441,7 @@ func makeExtCommit(height int64, valAddr []byte) *types.ExtendedCommit { Timestamp: defaultEvidenceTime, Signature: []byte("Signature"), }, + ExtensionSignature: []byte("Extended Signature"), }}, } } diff --git a/evidence/verify_test.go b/evidence/verify_test.go index 3380de313..71bbac986 100644 --- a/evidence/verify_test.go +++ b/evidence/verify_test.go @@ -207,10 +207,10 @@ func TestVerifyLightClientAttack_Equivocation(t *testing.T) { // we are simulating a duplicate vote attack where all the validators in the conflictingVals set // except the last validator vote twice blockID := makeBlockID(conflictingHeader.Hash(), 1000, []byte("partshash")) - voteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) + voteSet := types.NewExtendedVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) extCommit, err := test.MakeExtendedCommitFromVoteSet(blockID, voteSet, conflictingPrivVals[:4], defaultEvidenceTime) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() ev := &types.LightClientAttackEvidence{ ConflictingBlock: &types.LightBlock{ SignedHeader: &types.SignedHeader{ @@ -226,10 +226,10 @@ func TestVerifyLightClientAttack_Equivocation(t *testing.T) { } trustedBlockID := makeBlockID(trustedHeader.Hash(), 1000, []byte("partshash")) - trustedVoteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) + trustedVoteSet := types.NewExtendedVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) trustedExtCommit, err := test.MakeExtendedCommitFromVoteSet(trustedBlockID, trustedVoteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) - trustedCommit := trustedExtCommit.StripExtensions() + trustedCommit := trustedExtCommit.ToCommit() trustedSignedHeader := &types.SignedHeader{ Header: trustedHeader, Commit: trustedCommit, @@ -293,10 +293,10 @@ func TestVerifyLightClientAttack_Amnesia(t *testing.T) { // we are simulating an amnesia attack where all the validators in the conflictingVals set // except the last validator vote twice. However this time the commits are of different rounds. blockID := makeBlockID(conflictingHeader.Hash(), 1000, []byte("partshash")) - voteSet := types.NewVoteSet(evidenceChainID, 10, 0, tmproto.SignedMsgType(2), conflictingVals) + voteSet := types.NewExtendedVoteSet(evidenceChainID, 10, 0, tmproto.SignedMsgType(2), conflictingVals) extCommit, err := test.MakeExtendedCommitFromVoteSet(blockID, voteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() ev := &types.LightClientAttackEvidence{ ConflictingBlock: &types.LightBlock{ SignedHeader: &types.SignedHeader{ @@ -312,10 +312,10 @@ func TestVerifyLightClientAttack_Amnesia(t *testing.T) { } trustedBlockID := makeBlockID(trustedHeader.Hash(), 1000, []byte("partshash")) - trustedVoteSet := types.NewVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) + trustedVoteSet := types.NewExtendedVoteSet(evidenceChainID, 10, 1, tmproto.SignedMsgType(2), conflictingVals) trustedExtCommit, err := test.MakeExtendedCommitFromVoteSet(trustedBlockID, trustedVoteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) - trustedCommit := trustedExtCommit.StripExtensions() + trustedCommit := trustedExtCommit.ToCommit() trustedSignedHeader := &types.SignedHeader{ Header: trustedHeader, Commit: trustedCommit, @@ -487,10 +487,10 @@ func makeLunaticEvidence( conflictingHeader.ValidatorsHash = conflictingVals.Hash() blockID := makeBlockID(conflictingHeader.Hash(), 1000, []byte("partshash")) - voteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals) + voteSet := types.NewExtendedVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), conflictingVals) extCommit, err := test.MakeExtendedCommitFromVoteSet(blockID, voteSet, conflictingPrivVals, defaultEvidenceTime) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() ev = &types.LightClientAttackEvidence{ ConflictingBlock: &types.LightBlock{ SignedHeader: &types.SignedHeader{ @@ -515,10 +515,10 @@ func makeLunaticEvidence( } trustedBlockID := makeBlockID(trustedHeader.Hash(), 1000, []byte("partshash")) trustedVals, privVals := types.RandValidatorSet(totalVals, defaultVotingPower) - trustedVoteSet := types.NewVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), trustedVals) + trustedVoteSet := types.NewExtendedVoteSet(evidenceChainID, height, 1, tmproto.SignedMsgType(2), trustedVals) trustedExtCommit, err := test.MakeExtendedCommitFromVoteSet(trustedBlockID, trustedVoteSet, privVals, defaultEvidenceTime) require.NoError(t, err) - trustedCommit := trustedExtCommit.StripExtensions() + trustedCommit := trustedExtCommit.ToCommit() trusted = &types.LightBlock{ SignedHeader: &types.SignedHeader{ Header: trustedHeader, diff --git a/internal/test/params.go b/internal/test/params.go new file mode 100644 index 000000000..a45a6fb58 --- /dev/null +++ b/internal/test/params.go @@ -0,0 +1,13 @@ +package test + +import ( + "github.com/tendermint/tendermint/types" +) + +// ConsensusParams returns a default set of ConsensusParams that are suitable +// for use in testing +func ConsensusParams() *types.ConsensusParams { + c := types.DefaultConsensusParams() + c.ABCI.VoteExtensionsEnableHeight = 1 + return c +} diff --git a/state/execution.go b/state/execution.go index c67692ac7..7a27694f9 100644 --- a/state/execution.go +++ b/state/execution.go @@ -3,6 +3,7 @@ package state import ( "bytes" "context" + "errors" "fmt" "time" @@ -113,14 +114,13 @@ func (blockExec *BlockExecutor) CreateProposalBlock( maxDataBytes := types.MaxDataBytes(maxBytes, evSize, state.Validators.Size()) txs := blockExec.mempool.ReapMaxBytesMaxGas(maxDataBytes, maxGas) - commit := lastExtCommit.StripExtensions() + commit := lastExtCommit.ToCommit() block := state.MakeBlock(height, txs, commit, evidence, proposerAddr) - rpp, err := blockExec.proxyApp.PrepareProposal(context.TODO(), &abci.RequestPrepareProposal{ MaxTxBytes: maxDataBytes, Txs: block.Txs.ToSliceOfBytes(), - LocalLastCommit: buildExtendedCommitInfo(lastExtCommit, blockExec.store, state.InitialHeight), + LocalLastCommit: buildExtendedCommitInfo(lastExtCommit, blockExec.store, state.InitialHeight, state.ConsensusParams.ABCI), Misbehavior: block.Evidence.Evidence.ToABCI(), Height: block.Height, Time: block.Time, @@ -320,7 +320,7 @@ func (blockExec *BlockExecutor) VerifyVoteExtension(vote *types.Vote) error { } if !resp.IsAccepted() { - return types.ErrVoteInvalidExtension + return errors.New("invalid vote extension") } return nil @@ -426,7 +426,7 @@ func buildLastCommitInfo(block *types.Block, store Store, initialHeight int64) a // data, it returns an empty record. // // Assumes that the commit signatures are sorted according to validator index. -func buildExtendedCommitInfo(ec *types.ExtendedCommit, store Store, initialHeight int64) abci.ExtendedCommitInfo { +func buildExtendedCommitInfo(ec *types.ExtendedCommit, store Store, initialHeight int64, ap types.ABCIParams) abci.ExtendedCommitInfo { if ec.Height < initialHeight { // There are no extended commits for heights below the initial height. return abci.ExtendedCommitInfo{} @@ -464,9 +464,15 @@ func buildExtendedCommitInfo(ec *types.ExtendedCommit, store Store, initialHeigh } var ext []byte - if ecs.BlockIDFlag == types.BlockIDFlagCommit { - // We only care about vote extensions if a validator has voted to - // commit. + // Check if vote extensions were enabled during the commit's height: ec.Height. + // ec is the commit from the previous height, so if extensions were enabled + // during that height, we ensure they are present and deliver the data to + // the proposer. If they were not enabled during this previous height, we + // will not deliver extension data. + if ap.VoteExtensionsEnabled(ec.Height) && ecs.BlockIDFlag == types.BlockIDFlagCommit { + if err := ecs.EnsureExtension(); err != nil { + panic(fmt.Errorf("commit at height %d received with missing vote extensions data", ec.Height)) + } ext = ecs.Extension } diff --git a/state/execution_test.go b/state/execution_test.go index 22070a009..46207ea2b 100644 --- a/state/execution_test.go +++ b/state/execution_test.go @@ -139,7 +139,7 @@ func TestFinalizeBlockDecidedLastCommit(t *testing.T) { } // block for height 2 - block := makeBlock(state, 2, lastCommit.StripExtensions()) + block := makeBlock(state, 2, lastCommit.ToCommit()) bps, err := block.MakePartSet(testPartSize) require.NoError(t, err) blockID := types.BlockID{Hash: block.Hash(), PartSetHeader: bps.Header()} @@ -175,41 +175,49 @@ func TestFinalizeBlockValidators(t *testing.T) { var ( now = tmtime.Now() - commitSig0 = types.CommitSig{ - BlockIDFlag: types.BlockIDFlagCommit, - ValidatorAddress: state.Validators.Validators[0].Address, - Timestamp: now, - Signature: []byte("Signature1"), + commitSig0 = types.ExtendedCommitSig{ + CommitSig: types.CommitSig{ + BlockIDFlag: types.BlockIDFlagCommit, + ValidatorAddress: state.Validators.Validators[0].Address, + Timestamp: now, + Signature: []byte("Signature1"), + }, + Extension: []byte("extension1"), + ExtensionSignature: []byte("extensionSig1"), } - commitSig1 = types.CommitSig{ - BlockIDFlag: types.BlockIDFlagCommit, - ValidatorAddress: state.Validators.Validators[1].Address, - Timestamp: now, - Signature: []byte("Signature2"), + commitSig1 = types.ExtendedCommitSig{ + CommitSig: types.CommitSig{ + BlockIDFlag: types.BlockIDFlagCommit, + ValidatorAddress: state.Validators.Validators[1].Address, + Timestamp: now, + Signature: []byte("Signature2"), + }, + Extension: []byte("extension2"), + ExtensionSignature: []byte("extensionSig2"), } - absentSig = types.NewCommitSigAbsent() + absentSig = types.NewExtendedCommitSigAbsent() ) testCases := []struct { desc string - lastCommitSigs []types.CommitSig + lastCommitSigs []types.ExtendedCommitSig expectedAbsentValidators []int }{ - {"none absent", []types.CommitSig{commitSig0, commitSig1}, []int{}}, - {"one absent", []types.CommitSig{commitSig0, absentSig}, []int{1}}, - {"multiple absent", []types.CommitSig{absentSig, absentSig}, []int{0, 1}}, + {"none absent", []types.ExtendedCommitSig{commitSig0, commitSig1}, []int{}}, + {"one absent", []types.ExtendedCommitSig{commitSig0, absentSig}, []int{1}}, + {"multiple absent", []types.ExtendedCommitSig{absentSig, absentSig}, []int{0, 1}}, } for _, tc := range testCases { - lastCommit := &types.Commit{ - Height: 1, - BlockID: prevBlockID, - Signatures: tc.lastCommitSigs, + lastCommit := &types.ExtendedCommit{ + Height: 1, + BlockID: prevBlockID, + ExtendedSignatures: tc.lastCommitSigs, } // block for height 2 - block := makeBlock(state, 2, lastCommit) + block := makeBlock(state, 2, lastCommit.ToCommit()) _, err = sm.ExecCommitBlock(proxyApp.Consensus(), block, log.TestingLogger(), stateStore, 1) require.Nil(t, err, tc.desc) @@ -933,6 +941,110 @@ func TestPrepareProposalErrorOnPrepareProposalError(t *testing.T) { mp.AssertExpectations(t) } +// TestCreateProposalBlockPanicOnAbsentVoteExtensions ensures that the CreateProposalBlock +// call correctly panics when the vote extension data is missing from the extended commit +// data that the method receives. +func TestCreateProposalAbsentVoteExtensions(t *testing.T) { + for _, testCase := range []struct { + name string + + // The height that is about to be proposed + height int64 + + // The first height during which vote extensions will be required for consensus to proceed. + extensionEnableHeight int64 + expectPanic bool + }{ + { + name: "missing extension data on first required height", + height: 2, + extensionEnableHeight: 1, + expectPanic: true, + }, + { + name: "missing extension during before required height", + height: 2, + extensionEnableHeight: 2, + expectPanic: false, + }, + { + name: "missing extension data and not required", + height: 2, + extensionEnableHeight: 0, + expectPanic: false, + }, + { + name: "missing extension data and required in two heights", + height: 2, + extensionEnableHeight: 3, + expectPanic: false, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + app := abcimocks.NewApplication(t) + if !testCase.expectPanic { + app.On("PrepareProposal", mock.Anything, mock.Anything).Return(&abci.ResponsePrepareProposal{}, nil) + } + cc := proxy.NewLocalClientCreator(app) + proxyApp := proxy.NewAppConns(cc, proxy.NopMetrics()) + err := proxyApp.Start() + require.NoError(t, err) + + state, stateDB, privVals := makeState(1, int(testCase.height-1)) + stateStore := sm.NewStore(stateDB, sm.StoreOptions{ + DiscardABCIResponses: false, + }) + state.ConsensusParams.ABCI.VoteExtensionsEnableHeight = testCase.extensionEnableHeight + mp := &mpmocks.Mempool{} + mp.On("Lock").Return() + mp.On("Unlock").Return() + mp.On("FlushAppConn", mock.Anything).Return(nil) + mp.On("Update", + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything).Return(nil) + mp.On("ReapMaxBytesMaxGas", mock.Anything, mock.Anything).Return(types.Txs{}) + + blockStore := store.NewBlockStore(dbm.NewMemDB()) + blockExec := sm.NewBlockExecutor( + stateStore, + log.NewNopLogger(), + proxyApp.Consensus(), + mp, + sm.EmptyEvidencePool{}, + blockStore, + ) + block := makeBlock(state, testCase.height, new(types.Commit)) + + bps, err := block.MakePartSet(testPartSize) + require.NoError(t, err) + blockID := types.BlockID{Hash: block.Hash(), PartSetHeader: bps.Header()} + pa, _ := state.Validators.GetByIndex(0) + lastCommit, _, _ := makeValidCommit(testCase.height-1, blockID, state.Validators, privVals) + stripSignatures(lastCommit) + if testCase.expectPanic { + require.Panics(t, func() { + blockExec.CreateProposalBlock(testCase.height, state, lastCommit, pa) //nolint:errcheck + }) + } else { + _, err = blockExec.CreateProposalBlock(testCase.height, state, lastCommit, pa) + require.NoError(t, err) + } + }) + } +} + +func stripSignatures(ec *types.ExtendedCommit) { + for i, commitSig := range ec.ExtendedSignatures { + commitSig.Extension = nil + commitSig.ExtensionSignature = nil + ec.ExtendedSignatures[i] = commitSig + } +} + func makeBlockID(hash []byte, partSetSize uint32, partSetHash []byte) types.BlockID { var ( h = make([]byte, tmhash.Size) diff --git a/state/mocks/block_store.go b/state/mocks/block_store.go index 4e88d9ebe..e5472eb01 100644 --- a/state/mocks/block_store.go +++ b/state/mocks/block_store.go @@ -229,7 +229,12 @@ func (_m *BlockStore) PruneBlocks(height int64, _a1 state.State) (uint64, int64, } // SaveBlock provides a mock function with given fields: block, blockParts, seenCommit -func (_m *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) { +func (_m *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.Commit) { + _m.Called(block, blockParts, seenCommit) +} + +// SaveBlockWithExtendedCommit provides a mock function with given fields: block, blockParts, seenCommit +func (_m *BlockStore) SaveBlockWithExtendedCommit(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) { _m.Called(block, blockParts, seenCommit) } diff --git a/state/rollback_test.go b/state/rollback_test.go index 3e8f33c82..9e2d03efc 100644 --- a/state/rollback_test.go +++ b/state/rollback_test.go @@ -118,7 +118,7 @@ func TestRollbackHard(t *testing.T) { partSet, err := block.MakePartSet(types.BlockPartSizeBytes) require.NoError(t, err) - blockStore.SaveBlock(block, partSet, &types.ExtendedCommit{Height: block.Height}) + blockStore.SaveBlock(block, partSet, &types.Commit{Height: block.Height}) currState := state.State{ Version: tmstate.Version{ @@ -160,7 +160,7 @@ func TestRollbackHard(t *testing.T) { nextPartSet, err := nextBlock.MakePartSet(types.BlockPartSizeBytes) require.NoError(t, err) - blockStore.SaveBlock(nextBlock, nextPartSet, &types.ExtendedCommit{Height: nextBlock.Height}) + blockStore.SaveBlock(nextBlock, nextPartSet, &types.Commit{Height: nextBlock.Height}) rollbackHeight, rollbackHash, err := state.Rollback(blockStore, stateStore, true) require.NoError(t, err) @@ -173,7 +173,7 @@ func TestRollbackHard(t *testing.T) { require.Equal(t, currState, loadedState) // resave the same block - blockStore.SaveBlock(nextBlock, nextPartSet, &types.ExtendedCommit{Height: nextBlock.Height}) + blockStore.SaveBlock(nextBlock, nextPartSet, &types.Commit{Height: nextBlock.Height}) params.Version.App = 11 diff --git a/state/services.go b/state/services.go index 5553d160b..5230f8715 100644 --- a/state/services.go +++ b/state/services.go @@ -24,7 +24,8 @@ type BlockStore interface { LoadBlockMeta(height int64) *types.BlockMeta LoadBlock(height int64) *types.Block - SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) + SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.Commit) + SaveBlockWithExtendedCommit(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) PruneBlocks(height int64, state State) (uint64, int64, error) diff --git a/state/store.go b/state/store.go index b6096adf8..002fbd1ce 100644 --- a/state/store.go +++ b/state/store.go @@ -1,10 +1,12 @@ package state import ( + "encoding/binary" "errors" "fmt" "github.com/cosmos/gogoproto/proto" + "github.com/google/orderedcode" dbm "github.com/tendermint/tm-db" abci "github.com/tendermint/tendermint/abci/types" @@ -37,6 +39,18 @@ func calcABCIResponsesKey(height int64) []byte { return []byte(fmt.Sprintf("abciResponsesKey:%v", height)) } +var tmpABCIKey []byte + +func init() { + var err error + // temporary extra key before consensus param protos are regenerated + // TODO(wbanfield) remove in next PR + tmpABCIKey, err = orderedcode.Append(nil, int64(10000)) + if err != nil { + panic(err) + } +} + //---------------------- var lastABCIResponseKey = []byte("lastABCIResponseKey") @@ -162,6 +176,12 @@ func (store dbStore) loadState(key []byte) (state State, err error) { if err != nil { return state, err } + buf, err = store.db.Get(tmpABCIKey) + if err != nil { + return state, err + } + h, _ := binary.Varint(buf) + sm.ConsensusParams.ABCI.VoteExtensionsEnableHeight = h return *sm, nil } @@ -197,6 +217,12 @@ func (store dbStore) save(state State, key []byte) error { if err != nil { return err } + bz := make([]byte, 5) + binary.PutVarint(bz, state.ConsensusParams.ABCI.VoteExtensionsEnableHeight) + if err := store.db.SetSync(tmpABCIKey, bz); err != nil { + return err + } + return nil } diff --git a/state/validation_test.go b/state/validation_test.go index 443e2ea85..391b847f1 100644 --- a/state/validation_test.go +++ b/state/validation_test.go @@ -112,7 +112,7 @@ func TestValidateBlockHeader(t *testing.T) { state, _, lastExtCommit, err = makeAndCommitGoodBlock( state, height, lastCommit, state.Validators.GetProposer().Address, blockExec, privVals, nil) require.NoError(t, err, "height %d", height) - lastCommit = lastExtCommit.StripExtensions() + lastCommit = lastExtCommit.ToCommit() } } @@ -210,7 +210,7 @@ func TestValidateBlockCommit(t *testing.T) { nil, ) require.NoError(t, err, "height %d", height) - lastCommit = lastExtCommit.StripExtensions() + lastCommit = lastExtCommit.ToCommit() /* wrongSigsCommit is fine except for the extra bad precommit @@ -352,7 +352,7 @@ func TestValidateBlockEvidence(t *testing.T) { evidence, ) require.NoError(t, err, "height %d", height) - lastCommit = lastExtCommit.StripExtensions() + lastCommit = lastExtCommit.ToCommit() } } diff --git a/store/store.go b/store/store.go index 3a33d6cac..660c1d415 100644 --- a/store/store.go +++ b/store/store.go @@ -1,6 +1,7 @@ package store import ( + "errors" "fmt" "strconv" @@ -242,6 +243,9 @@ func (bs *BlockStore) LoadBlockCommit(height int64) *types.Commit { return commit } +// LoadExtendedCommit returns the ExtendedCommit for the given height. +// The extended commit is not guaranteed to contain the same +2/3 precommits data +// as the commit in the block. func (bs *BlockStore) LoadBlockExtendedCommit(height int64) *types.ExtendedCommit { var pbec = new(tmproto.ExtendedCommit) bz, err := bs.db.Get(calcExtCommitKey(height)) @@ -387,7 +391,46 @@ func (bs *BlockStore) PruneBlocks(height int64, state sm.State) (uint64, int64, // If all the nodes restart after committing a block, // we need this to reload the precommits to catch-up nodes to the // most recent height. Otherwise they'd stall at H-1. -func (bs *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.ExtendedCommit) { +func (bs *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, seenCommit *types.Commit) { + if block == nil { + panic("BlockStore can only save a non-nil block") + } + if err := bs.saveBlockToBatch(block, blockParts, seenCommit); err != nil { + panic(err) + } + + // Save new BlockStoreState descriptor. This also flushes the database. + bs.saveState() +} + +// SaveBlockWithExtendedCommit persists the given block, blockParts, and +// seenExtendedCommit to the underlying db. seenExtendedCommit is stored under +// two keys in the database: as the seenCommit and as the ExtendedCommit data for the +// height. This allows the vote extension data to be persisted for all blocks +// that are saved. +func (bs *BlockStore) SaveBlockWithExtendedCommit(block *types.Block, blockParts *types.PartSet, seenExtendedCommit *types.ExtendedCommit) { + if block == nil { + panic("BlockStore can only save a non-nil block") + } + if err := seenExtendedCommit.EnsureExtensions(); err != nil { + panic(fmt.Errorf("saving block with extensions: %w", err)) + } + if err := bs.saveBlockToBatch(block, blockParts, seenExtendedCommit.ToCommit()); err != nil { + panic(err) + } + height := block.Height + + pbec := seenExtendedCommit.ToProto() + extCommitBytes := mustEncode(pbec) + if err := bs.db.Set(calcExtCommitKey(height), extCommitBytes); err != nil { + panic(err) + } + + // Save new BlockStoreState descriptor. This also flushes the database. + bs.saveState() +} + +func (bs *BlockStore) saveBlockToBatch(block *types.Block, blockParts *types.PartSet, seenCommit *types.Commit) error { if block == nil { panic("BlockStore can only save a non-nil block") } @@ -396,14 +439,13 @@ func (bs *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, s hash := block.Hash() if g, w := height, bs.Height()+1; bs.Base() > 0 && g != w { - panic(fmt.Sprintf("BlockStore can only save contiguous blocks. Wanted %v, got %v", w, g)) + return fmt.Errorf("BlockStore can only save contiguous blocks. Wanted %v, got %v", w, g) } if !blockParts.IsComplete() { - panic("BlockStore can only save complete block part sets") + return errors.New("BlockStore can only save complete block part sets") } if height != seenCommit.Height { - panic(fmt.Sprintf("BlockStore cannot save seen commit of a different height (block: %d, commit: %d)", - height, seenCommit.Height)) + return fmt.Errorf("BlockStore cannot save seen commit of a different height (block: %d, commit: %d)", height, seenCommit.Height) } // Save block parts. This must be done before the block meta, since callers @@ -419,35 +461,29 @@ func (bs *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, s blockMeta := types.NewBlockMeta(block, blockParts) pbm := blockMeta.ToProto() if pbm == nil { - panic("nil blockmeta") + return errors.New("nil blockmeta") } metaBytes := mustEncode(pbm) if err := bs.db.Set(calcBlockMetaKey(height), metaBytes); err != nil { - panic(err) + return err } if err := bs.db.Set(calcBlockHashKey(hash), []byte(fmt.Sprintf("%d", height))); err != nil { - panic(err) + return err } // Save block commit (duplicate and separate from the Block) pbc := block.LastCommit.ToProto() blockCommitBytes := mustEncode(pbc) if err := bs.db.Set(calcBlockCommitKey(height-1), blockCommitBytes); err != nil { - panic(err) + return err } // Save seen commit (seen +2/3 precommits for block) // NOTE: we can delete this at a later height - pbsc := seenCommit.StripExtensions().ToProto() + pbsc := seenCommit.ToProto() seenCommitBytes := mustEncode(pbsc) if err := bs.db.Set(calcSeenCommitKey(height), seenCommitBytes); err != nil { - panic(err) - } - - pbec := seenCommit.ToProto() - extCommitBytes := mustEncode(pbec) - if err := bs.db.Set(calcExtCommitKey(height), extCommitBytes); err != nil { - panic(err) + return err } // Done! @@ -458,8 +494,7 @@ func (bs *BlockStore) SaveBlock(block *types.Block, blockParts *types.PartSet, s } bs.mtx.Unlock() - // Save new BlockStoreState descriptor. This also flushes the database. - bs.saveState() + return nil } func (bs *BlockStore) saveBlockPart(height int64, index int, part *types.Part) { diff --git a/store/store_test.go b/store/store_test.go index 31ad2d482..dc324df5f 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -40,6 +40,7 @@ func makeTestExtCommit(height int64, timestamp time.Time) *types.ExtendedCommit Timestamp: timestamp, Signature: []byte("Signature"), }, + ExtensionSignature: []byte("ExtensionSignature"), }} return &types.ExtendedCommit{ Height: height, @@ -155,7 +156,7 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { part2 := validPartSet.GetPart(1) seenCommit := makeTestExtCommit(block.Header.Height, tmtime.Now()) - bs.SaveBlock(block, validPartSet, seenCommit) + bs.SaveBlockWithExtendedCommit(block, validPartSet, seenCommit) require.EqualValues(t, 1, bs.Base(), "expecting the new height to be changed") require.EqualValues(t, block.Header.Height, bs.Height(), "expecting the new height to be changed") @@ -174,7 +175,7 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { // End of setup, test data - commitAtH10 := makeTestExtCommit(10, tmtime.Now()).StripExtensions() + commitAtH10 := makeTestExtCommit(10, tmtime.Now()).ToCommit() tuples := []struct { block *types.Block parts *types.PartSet @@ -207,16 +208,17 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { ChainID: "block_test", Time: tmtime.Now(), ProposerAddress: tmrand.Bytes(crypto.AddressSize)}, - makeTestExtCommit(5, tmtime.Now()).StripExtensions(), + makeTestExtCommit(5, tmtime.Now()).ToCommit(), ), parts: validPartSet, seenCommit: makeTestExtCommit(5, tmtime.Now()), }, { - block: newBlock(header1, commitAtH10), - parts: incompletePartSet, - wantPanic: "only save complete block", // incomplete parts + block: newBlock(header1, commitAtH10), + parts: incompletePartSet, + wantPanic: "only save complete block", // incomplete parts + seenCommit: makeTestExtCommit(10, tmtime.Now()), }, { @@ -245,7 +247,7 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { }, { - block: newBlock(header1, commitAtH10), + block: block, parts: validPartSet, seenCommit: seenCommit, @@ -254,7 +256,7 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { }, { - block: newBlock(header1, commitAtH10), + block: block, parts: validPartSet, seenCommit: seenCommit, @@ -276,7 +278,7 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { bs, db := newInMemoryBlockStore() // SaveBlock res, err, panicErr := doFn(func() (interface{}, error) { - bs.SaveBlock(tuple.block, tuple.parts, tuple.seenCommit) + bs.SaveBlockWithExtendedCommit(tuple.block, tuple.parts, tuple.seenCommit) if tuple.block == nil { return nil, nil } @@ -346,6 +348,91 @@ func TestBlockStoreSaveLoadBlock(t *testing.T) { } } +// TestSaveBlockWithExtendedCommitPanicOnAbsentExtension tests that saving a +// block with an extended commit panics when the extension data is absent. +func TestSaveBlockWithExtendedCommitPanicOnAbsentExtension(t *testing.T) { + for _, testCase := range []struct { + name string + malleateCommit func(*types.ExtendedCommit) + shouldPanic bool + }{ + { + name: "basic save", + malleateCommit: func(_ *types.ExtendedCommit) {}, + shouldPanic: false, + }, + { + name: "save commit with no extensions", + malleateCommit: func(c *types.ExtendedCommit) { + c.StripExtensions() + }, + shouldPanic: true, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + state, bs, cleanup := makeStateAndBlockStore(log.NewTMLogger(new(bytes.Buffer))) + defer cleanup() + h := bs.Height() + 1 + block := state.MakeBlock(h, test.MakeNTxs(h, 10), new(types.Commit), nil, state.Validators.GetProposer().Address) + + seenCommit := makeTestExtCommit(block.Header.Height, tmtime.Now()) + ps, err := block.MakePartSet(2) + require.NoError(t, err) + testCase.malleateCommit(seenCommit) + if testCase.shouldPanic { + require.Panics(t, func() { + bs.SaveBlockWithExtendedCommit(block, ps, seenCommit) + }) + } else { + bs.SaveBlockWithExtendedCommit(block, ps, seenCommit) + } + }) + } +} + +// TestLoadBlockExtendedCommit tests loading the extended commit for a previously +// saved block. The load method should return nil when only a commit was saved and +// return the extended commit otherwise. +func TestLoadBlockExtendedCommit(t *testing.T) { + for _, testCase := range []struct { + name string + saveExtended bool + expectResult bool + }{ + { + name: "save commit", + saveExtended: false, + expectResult: false, + }, + { + name: "save extended commit", + saveExtended: true, + expectResult: true, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + state, bs, cleanup := makeStateAndBlockStore(log.NewTMLogger(new(bytes.Buffer))) + defer cleanup() + h := bs.Height() + 1 + block := state.MakeBlock(h, test.MakeNTxs(h, 10), new(types.Commit), nil, state.Validators.GetProposer().Address) + seenCommit := makeTestExtCommit(block.Header.Height, tmtime.Now()) + ps, err := block.MakePartSet(2) + require.NoError(t, err) + if testCase.saveExtended { + bs.SaveBlockWithExtendedCommit(block, ps, seenCommit) + } else { + bs.SaveBlock(block, ps, seenCommit.ToCommit()) + } + res := bs.LoadBlockExtendedCommit(block.Height) + if testCase.expectResult { + require.Equal(t, seenCommit, res) + } else { + require.Nil(t, res) + } + }) + } +} + func TestLoadBaseMeta(t *testing.T) { config := test.ResetTestRoot("blockchain_reactor_test") defer os.RemoveAll(config.RootDir) @@ -361,7 +448,7 @@ func TestLoadBaseMeta(t *testing.T) { partSet, err := block.MakePartSet(2) require.NoError(t, err) seenCommit := makeTestExtCommit(h, tmtime.Now()) - bs.SaveBlock(block, partSet, seenCommit) + bs.SaveBlockWithExtendedCommit(block, partSet, seenCommit) } _, _, err = bs.PruneBlocks(4, state) @@ -445,7 +532,7 @@ func TestPruneBlocks(t *testing.T) { partSet, err := block.MakePartSet(2) require.NoError(t, err) seenCommit := makeTestExtCommit(h, tmtime.Now()) - bs.SaveBlock(block, partSet, seenCommit) + bs.SaveBlockWithExtendedCommit(block, partSet, seenCommit) } assert.EqualValues(t, 1, bs.Base()) @@ -572,7 +659,7 @@ func TestLoadBlockMetaByHash(t *testing.T) { partSet, err := b1.MakePartSet(2) require.NoError(t, err) seenCommit := makeTestExtCommit(1, tmtime.Now()) - bs.SaveBlock(b1, partSet, seenCommit) + bs.SaveBlock(b1, partSet, seenCommit.ToCommit()) baseBlock := bs.LoadBlockMetaByHash(b1.Hash()) assert.EqualValues(t, b1.Header.Height, baseBlock.Header.Height) @@ -589,7 +676,7 @@ func TestBlockFetchAtHeight(t *testing.T) { partSet, err := block.MakePartSet(2) require.NoError(t, err) seenCommit := makeTestExtCommit(block.Header.Height, tmtime.Now()) - bs.SaveBlock(block, partSet, seenCommit) + bs.SaveBlockWithExtendedCommit(block, partSet, seenCommit) require.Equal(t, bs.Height(), block.Header.Height, "expecting the new height to be changed") blockAtHeight := bs.LoadBlock(bs.Height()) diff --git a/test/e2e/runner/evidence.go b/test/e2e/runner/evidence.go index 5ac109f47..d21c2f8f4 100644 --- a/test/e2e/runner/evidence.go +++ b/test/e2e/runner/evidence.go @@ -166,8 +166,8 @@ func generateLightClientAttackEvidence( // create a commit for the forged header blockID := makeBlockID(header.Hash(), 1000, []byte("partshash")) - voteSet := types.NewVoteSet(chainID, forgedHeight, 0, tmproto.SignedMsgType(2), conflictingVals) - commit, err := test.MakeExtendedCommitFromVoteSet(blockID, voteSet, pv, forgedTime) + voteSet := types.NewExtendedVoteSet(chainID, forgedHeight, 0, tmproto.SignedMsgType(2), conflictingVals) + ec, err := test.MakeExtendedCommitFromVoteSet(blockID, voteSet, pv, forgedTime) if err != nil { return nil, err } @@ -176,7 +176,7 @@ func generateLightClientAttackEvidence( ConflictingBlock: &types.LightBlock{ SignedHeader: &types.SignedHeader{ Header: header, - Commit: commit.StripExtensions(), + Commit: ec.ToCommit(), }, ValidatorSet: conflictingVals, }, diff --git a/test/e2e/tests/app_test.go b/test/e2e/tests/app_test.go index 7b700da5e..34ab5e14f 100644 --- a/test/e2e/tests/app_test.go +++ b/test/e2e/tests/app_test.go @@ -107,6 +107,7 @@ func TestApp_Tx(t *testing.T) { func TestApp_VoteExtensions(t *testing.T) { testNode(t, func(t *testing.T, node e2e.Node) { + t.Skip() client, err := node.Client() require.NoError(t, err) diff --git a/types/block.go b/types/block.go index de60f66ea..bf5dff51d 100644 --- a/types/block.go +++ b/types/block.go @@ -749,22 +749,23 @@ func (ecs ExtendedCommitSig) ValidateBasic() error { if len(ecs.Extension) > MaxVoteExtensionSize { return fmt.Errorf("vote extension is too big (max: %d)", MaxVoteExtensionSize) } - if len(ecs.ExtensionSignature) == 0 { - return errors.New("vote extension signature is missing") - } if len(ecs.ExtensionSignature) > MaxSignatureSize { return fmt.Errorf("vote extension signature is too big (max: %d)", MaxSignatureSize) } return nil } - // We expect there to not be any vote extension or vote extension signature - // on nil or absent votes. - if len(ecs.Extension) != 0 { - return fmt.Errorf("vote extension is present for commit sig with block ID flag %v", ecs.BlockIDFlag) + if len(ecs.ExtensionSignature) == 0 && len(ecs.Extension) != 0 { + return errors.New("vote extension signature absent on vote with extension") } - if len(ecs.ExtensionSignature) != 0 { - return fmt.Errorf("vote extension signature is present for commit sig with block ID flag %v", ecs.BlockIDFlag) + return nil +} + +// EnsureExtensions validates that a vote extensions signature is present for +// this ExtendedCommitSig. +func (ecs ExtendedCommitSig) EnsureExtension() error { + if ecs.BlockIDFlag == BlockIDFlagCommit && len(ecs.ExtensionSignature) == 0 { + return errors.New("vote extension data is missing") } return nil } @@ -908,6 +909,26 @@ func (commit *Commit) Hash() tmbytes.HexBytes { return commit.hash } +// WrappedExtendedCommit wraps a commit as an ExtendedCommit. +// The VoteExtension fields of the resulting value will by nil. +// Wrapping a Commit as an ExtendedCommit is useful when an API +// requires an ExtendedCommit wire type but does not +// need the VoteExtension data. +func (commit *Commit) WrappedExtendedCommit() *ExtendedCommit { + cs := make([]ExtendedCommitSig, len(commit.Signatures)) + for idx, s := range commit.Signatures { + cs[idx] = ExtendedCommitSig{ + CommitSig: s, + } + } + return &ExtendedCommit{ + Height: commit.Height, + Round: commit.Round, + BlockID: commit.BlockID, + ExtendedSignatures: cs, + } +} + // StringIndented returns a string representation of the commit. func (commit *Commit) StringIndented(indent string) string { if commit == nil { @@ -1005,17 +1026,33 @@ func (ec *ExtendedCommit) Clone() *ExtendedCommit { return &ecc } +// ToExtendedVoteSet constructs a VoteSet from the Commit and validator set. +// Panics if signatures from the ExtendedCommit can't be added to the voteset. +// Panics if any of the votes have invalid or absent vote extension data. +// Inverse of VoteSet.MakeExtendedCommit(). +func (ec *ExtendedCommit) ToExtendedVoteSet(chainID string, vals *ValidatorSet) *VoteSet { + voteSet := NewExtendedVoteSet(chainID, ec.Height, ec.Round, tmproto.PrecommitType, vals) + ec.addSigsToVoteSet(voteSet) + return voteSet +} + // ToVoteSet constructs a VoteSet from the Commit and validator set. -// Panics if signatures from the commit can't be added to the voteset. +// Panics if signatures from the ExtendedCommit can't be added to the voteset. // Inverse of VoteSet.MakeExtendedCommit(). func (ec *ExtendedCommit) ToVoteSet(chainID string, vals *ValidatorSet) *VoteSet { voteSet := NewVoteSet(chainID, ec.Height, ec.Round, tmproto.PrecommitType, vals) + ec.addSigsToVoteSet(voteSet) + return voteSet +} + +// addSigsToVoteSet adds all of the signature to voteSet. +func (ec *ExtendedCommit) addSigsToVoteSet(voteSet *VoteSet) { for idx, ecs := range ec.ExtendedSignatures { if ecs.BlockIDFlag == BlockIDFlagAbsent { continue // OK, some precommits can be missing. } vote := ec.GetExtendedVote(int32(idx)) - if err := vote.ValidateWithExtension(); err != nil { + if err := vote.ValidateBasic(); err != nil { panic(fmt.Errorf("failed to validate vote reconstructed from LastCommit: %w", err)) } added, err := voteSet.AddVote(vote) @@ -1023,12 +1060,58 @@ func (ec *ExtendedCommit) ToVoteSet(chainID string, vals *ValidatorSet) *VoteSet panic(fmt.Errorf("failed to reconstruct vote set from extended commit: %w", err)) } } +} + +// ToVoteSet constructs a VoteSet from the Commit and validator set. +// Panics if signatures from the commit can't be added to the voteset. +// Inverse of VoteSet.MakeCommit(). +func (commit *Commit) ToVoteSet(chainID string, vals *ValidatorSet) *VoteSet { + voteSet := NewVoteSet(chainID, commit.Height, commit.Round, tmproto.PrecommitType, vals) + for idx, cs := range commit.Signatures { + if cs.BlockIDFlag == BlockIDFlagAbsent { + continue // OK, some precommits can be missing. + } + vote := commit.GetVote(int32(idx)) + if err := vote.ValidateBasic(); err != nil { + panic(fmt.Errorf("failed to validate vote reconstructed from commit: %w", err)) + } + added, err := voteSet.AddVote(vote) + if !added || err != nil { + panic(fmt.Errorf("failed to reconstruct vote set from commit: %w", err)) + } + } return voteSet } -// StripExtensions converts an ExtendedCommit to a Commit by removing all vote +// EnsureExtensions validates that a vote extensions signature is present for +// every ExtendedCommitSig in the ExtendedCommit. +func (ec *ExtendedCommit) EnsureExtensions() error { + for _, ecs := range ec.ExtendedSignatures { + if err := ecs.EnsureExtension(); err != nil { + return err + } + } + return nil +} + +// StripExtensions removes all VoteExtension data from an ExtendedCommit. This +// is useful when dealing with an ExendedCommit but vote extension data is +// expected to be absent. +func (ec *ExtendedCommit) StripExtensions() bool { + stripped := false + for idx := range ec.ExtendedSignatures { + if len(ec.ExtendedSignatures[idx].Extension) > 0 || len(ec.ExtendedSignatures[idx].ExtensionSignature) > 0 { + stripped = true + } + ec.ExtendedSignatures[idx].Extension = nil + ec.ExtendedSignatures[idx].ExtensionSignature = nil + } + return stripped +} + +// ToCommit converts an ExtendedCommit to a Commit by removing all vote // extension-related fields. -func (ec *ExtendedCommit) StripExtensions() *Commit { +func (ec *ExtendedCommit) ToCommit() *Commit { cs := make([]CommitSig, len(ec.ExtendedSignatures)) for idx, ecs := range ec.ExtendedSignatures { cs[idx] = ecs.CommitSig diff --git a/types/block_test.go b/types/block_test.go index bd0526c39..b72ea083d 100644 --- a/types/block_test.go +++ b/types/block_test.go @@ -3,6 +3,7 @@ package types import ( // it is ok to use math/rand here: we do not need a cryptographically secure random // number generator here and we can run the tests a bit faster + "crypto/rand" "encoding/hex" "math" @@ -45,7 +46,7 @@ func TestBlockAddEvidence(t *testing.T) { require.NoError(t, err) evList := []Evidence{ev} - block := MakeBlock(h, txs, extCommit.StripExtensions(), evList) + block := MakeBlock(h, txs, extCommit.ToCommit(), evList) require.NotNil(t, block) require.Equal(t, 1, len(block.Evidence.Evidence)) require.NotNil(t, block.EvidenceHash) @@ -61,7 +62,7 @@ func TestBlockValidateBasic(t *testing.T) { voteSet, valSet, vals := randVoteSet(h-1, 1, tmproto.PrecommitType, 10, 1) extCommit, err := MakeExtCommit(lastID, h-1, 1, voteSet, vals, time.Now()) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() ev, err := NewMockDuplicateVoteEvidenceWithValidator(h, time.Now(), vals[0], "block-test-chain") require.NoError(t, err) @@ -139,7 +140,7 @@ func TestBlockMakePartSetWithEvidence(t *testing.T) { require.NoError(t, err) evList := []Evidence{ev} - partSet, err := MakeBlock(h, []Tx{Tx("Hello World")}, extCommit.StripExtensions(), evList).MakePartSet(512) + partSet, err := MakeBlock(h, []Tx{Tx("Hello World")}, extCommit.ToCommit(), evList).MakePartSet(512) require.NoError(t, err) assert.NotNil(t, partSet) @@ -159,7 +160,7 @@ func TestBlockHashesTo(t *testing.T) { require.NoError(t, err) evList := []Evidence{ev} - block := MakeBlock(h, []Tx{Tx("Hello World")}, extCommit.StripExtensions(), evList) + block := MakeBlock(h, []Tx{Tx("Hello World")}, extCommit.ToCommit(), evList) block.ValidatorsHash = valSet.Hash() assert.False(t, block.HashesTo([]byte{})) assert.False(t, block.HashesTo([]byte("something else"))) @@ -443,7 +444,7 @@ func randCommit(now time.Time) *Commit { if err != nil { panic(err) } - return commit.StripExtensions() + return commit.ToCommit() } func hexBytesFromString(s string) bytes.HexBytes { @@ -515,30 +516,133 @@ func TestBlockMaxDataBytesNoEvidence(t *testing.T) { } } +// TestVoteSetToExtendedCommit tests that the extended commit produced from a +// vote set contains the same vote information as the vote set. The test ensures +// that the MakeExtendedCommit method behaves as expected, whether vote extensions +// are present in the original votes or not. +func TestVoteSetToExtendedCommit(t *testing.T) { + for _, testCase := range []struct { + name string + includeExtension bool + }{ + { + name: "no extensions", + includeExtension: false, + }, + { + name: "with extensions", + includeExtension: true, + }, + } { + + t.Run(testCase.name, func(t *testing.T) { + blockID := makeBlockIDRandom() + + valSet, vals := RandValidatorSet(10, 1) + var voteSet *VoteSet + if testCase.includeExtension { + voteSet = NewExtendedVoteSet("test_chain_id", 3, 1, tmproto.PrecommitType, valSet) + } else { + voteSet = NewVoteSet("test_chain_id", 3, 1, tmproto.PrecommitType, valSet) + } + for i := 0; i < len(vals); i++ { + pubKey, err := vals[i].GetPubKey() + require.NoError(t, err) + vote := &Vote{ + ValidatorAddress: pubKey.Address(), + ValidatorIndex: int32(i), + Height: 3, + Round: 1, + Type: tmproto.PrecommitType, + BlockID: blockID, + Timestamp: time.Now(), + } + v := vote.ToProto() + err = vals[i].SignVote(voteSet.ChainID(), v) + require.NoError(t, err) + vote.Signature = v.Signature + if testCase.includeExtension { + vote.ExtensionSignature = v.ExtensionSignature + } + added, err := voteSet.AddVote(vote) + require.NoError(t, err) + require.True(t, added) + } + ec := voteSet.MakeExtendedCommit() + + for i := int32(0); int(i) < len(vals); i++ { + vote1 := voteSet.GetByIndex(i) + vote2 := ec.GetExtendedVote(i) + + vote1bz, err := vote1.ToProto().Marshal() + require.NoError(t, err) + vote2bz, err := vote2.ToProto().Marshal() + require.NoError(t, err) + assert.Equal(t, vote1bz, vote2bz) + } + }) + } +} + +// TestExtendedCommitToVoteSet tests that the vote set produced from an extended commit +// contains the same vote information as the extended commit. The test ensures +// that the ToVoteSet method behaves as expected, whether vote extensions +// are present in the original votes or not. func TestExtendedCommitToVoteSet(t *testing.T) { - lastID := makeBlockIDRandom() - h := int64(3) + for _, testCase := range []struct { + name string + includeExtension bool + }{ + { + name: "no extensions", + includeExtension: false, + }, + { + name: "with extensions", + includeExtension: true, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + lastID := makeBlockIDRandom() + h := int64(3) - voteSet, valSet, vals := randVoteSet(h-1, 1, tmproto.PrecommitType, 10, 1) - extCommit, err := MakeExtCommit(lastID, h-1, 1, voteSet, vals, time.Now()) - assert.NoError(t, err) + voteSet, valSet, vals := randVoteSet(h-1, 1, tmproto.PrecommitType, 10, 1) + extCommit, err := MakeExtCommit(lastID, h-1, 1, voteSet, vals, time.Now()) + assert.NoError(t, err) - chainID := voteSet.ChainID() - voteSet2 := extCommit.ToVoteSet(chainID, valSet) + if !testCase.includeExtension { + for i := 0; i < len(vals); i++ { + v := voteSet.GetByIndex(int32(i)) + v.Extension = nil + v.ExtensionSignature = nil + extCommit.ExtendedSignatures[i].Extension = nil + extCommit.ExtendedSignatures[i].ExtensionSignature = nil + } + } - for i := int32(0); int(i) < len(vals); i++ { - vote1 := voteSet.GetByIndex(i) - vote2 := voteSet2.GetByIndex(i) - vote3 := extCommit.GetExtendedVote(i) + chainID := voteSet.ChainID() + var voteSet2 *VoteSet + if testCase.includeExtension { + voteSet2 = extCommit.ToExtendedVoteSet(chainID, valSet) + } else { + voteSet2 = extCommit.ToVoteSet(chainID, valSet) + } - vote1bz, err := vote1.ToProto().Marshal() - require.NoError(t, err) - vote2bz, err := vote2.ToProto().Marshal() - require.NoError(t, err) - vote3bz, err := vote3.ToProto().Marshal() - require.NoError(t, err) - assert.Equal(t, vote1bz, vote2bz) - assert.Equal(t, vote1bz, vote3bz) + for i := int32(0); int(i) < len(vals); i++ { + vote1 := voteSet.GetByIndex(i) + vote2 := voteSet2.GetByIndex(i) + vote3 := extCommit.GetExtendedVote(i) + + vote1bz, err := vote1.ToProto().Marshal() + require.NoError(t, err) + vote2bz, err := vote2.ToProto().Marshal() + require.NoError(t, err) + vote3bz, err := vote3.ToProto().Marshal() + require.NoError(t, err) + assert.Equal(t, vote1bz, vote2bz) + assert.Equal(t, vote1bz, vote3bz) + } + }) } } @@ -590,7 +694,7 @@ func TestCommitToVoteSetWithVotesForNilBlock(t *testing.T) { if tc.valid { extCommit := voteSet.MakeExtendedCommit() // panics without > 2/3 valid votes assert.NotNil(t, extCommit) - err := valSet.VerifyCommit(voteSet.ChainID(), blockID, height-1, extCommit.StripExtensions()) + err := valSet.VerifyCommit(voteSet.ChainID(), blockID, height-1, extCommit.ToCommit()) assert.Nil(t, err) } else { assert.Panics(t, func() { voteSet.MakeExtendedCommit() }) diff --git a/types/evidence_test.go b/types/evidence_test.go index 363ebb11e..0b513116b 100644 --- a/types/evidence_test.go +++ b/types/evidence_test.go @@ -101,7 +101,7 @@ func TestLightClientAttackEvidenceBasic(t *testing.T) { blockID := makeBlockID(tmhash.Sum([]byte("blockhash")), math.MaxInt32, tmhash.Sum([]byte("partshash"))) extCommit, err := MakeExtCommit(blockID, height, 1, voteSet, privVals, defaultVoteTime) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() lcae := &LightClientAttackEvidence{ ConflictingBlock: &LightBlock{ @@ -163,7 +163,7 @@ func TestLightClientAttackEvidenceValidation(t *testing.T) { blockID := makeBlockID(header.Hash(), math.MaxInt32, tmhash.Sum([]byte("partshash"))) extCommit, err := MakeExtCommit(blockID, height, 1, voteSet, privVals, time.Now()) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() lcae := &LightClientAttackEvidence{ ConflictingBlock: &LightBlock{ diff --git a/types/params.go b/types/params.go index 246037d85..7dd032060 100644 --- a/types/params.go +++ b/types/params.go @@ -37,6 +37,7 @@ type ConsensusParams struct { Evidence EvidenceParams `json:"evidence"` Validator ValidatorParams `json:"validator"` Version VersionParams `json:"version"` + ABCI ABCIParams `json:"abci"` } // BlockParams define limits on the block size and gas plus minimum time @@ -63,6 +64,21 @@ type VersionParams struct { App uint64 `json:"app"` } +// ABCIParams configure ABCI functionality specific to the Application Blockchain +// Interface. +type ABCIParams struct { + VoteExtensionsEnableHeight int64 `json:"vote_extensions_enable_height"` +} + +// VoteExtensionsEnabled returns true if vote extensions are enabled at height h +// and false otherwise. +func (a ABCIParams) VoteExtensionsEnabled(h int64) bool { + if a.VoteExtensionsEnableHeight == 0 { + return false + } + return a.VoteExtensionsEnableHeight <= h +} + // DefaultConsensusParams returns a default ConsensusParams. func DefaultConsensusParams() *ConsensusParams { return &ConsensusParams{ @@ -70,6 +86,7 @@ func DefaultConsensusParams() *ConsensusParams { Evidence: DefaultEvidenceParams(), Validator: DefaultValidatorParams(), Version: DefaultVersionParams(), + ABCI: DefaultABCIParams(), } } @@ -104,6 +121,13 @@ func DefaultVersionParams() VersionParams { } } +func DefaultABCIParams() ABCIParams { + return ABCIParams{ + // When set to 0, vote extensions are not required. + VoteExtensionsEnableHeight: 0, + } +} + func IsValidPubkeyType(params ValidatorParams, pubkeyType string) bool { for i := 0; i < len(params.PubKeyTypes); i++ { if params.PubKeyTypes[i] == pubkeyType { diff --git a/types/validation_test.go b/types/validation_test.go index f1f349a78..c720bf23c 100644 --- a/types/validation_test.go +++ b/types/validation_test.go @@ -145,7 +145,7 @@ func TestValidatorSet_VerifyCommit_CheckAllSignatures(t *testing.T) { voteSet, valSet, vals := randVoteSet(h, 0, tmproto.PrecommitType, 4, 10) extCommit, err := MakeExtCommit(blockID, h, 0, voteSet, vals, time.Now()) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() require.NoError(t, valSet.VerifyCommit(chainID, blockID, h, commit)) // malleate 4th signature @@ -173,7 +173,7 @@ func TestValidatorSet_VerifyCommitLight_ReturnsAsSoonAsMajorityOfVotingPowerSign voteSet, valSet, vals := randVoteSet(h, 0, tmproto.PrecommitType, 4, 10) extCommit, err := MakeExtCommit(blockID, h, 0, voteSet, vals, time.Now()) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() require.NoError(t, valSet.VerifyCommit(chainID, blockID, h, commit)) // malleate 4th signature (3 signatures are enough for 2/3+) @@ -199,7 +199,7 @@ func TestValidatorSet_VerifyCommitLightTrusting_ReturnsAsSoonAsTrustLevelOfVotin voteSet, valSet, vals := randVoteSet(h, 0, tmproto.PrecommitType, 4, 10) extCommit, err := MakeExtCommit(blockID, h, 0, voteSet, vals, time.Now()) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() require.NoError(t, valSet.VerifyCommit(chainID, blockID, h, commit)) // malleate 3rd signature (2 signatures are enough for 1/3+ trust level) @@ -223,7 +223,7 @@ func TestValidatorSet_VerifyCommitLightTrusting(t *testing.T) { newValSet, _ = RandValidatorSet(2, 1) ) require.NoError(t, err) - commit := extCommit.StripExtensions() + commit := extCommit.ToCommit() testCases := []struct { valSet *ValidatorSet @@ -265,7 +265,7 @@ func TestValidatorSet_VerifyCommitLightTrustingErrorsOnOverflow(t *testing.T) { ) require.NoError(t, err) - err = valSet.VerifyCommitLightTrusting("test_chain_id", extCommit.StripExtensions(), + err = valSet.VerifyCommitLightTrusting("test_chain_id", extCommit.ToCommit(), tmmath.Fraction{Numerator: 25, Denominator: 55}) if assert.Error(t, err) { assert.Contains(t, err.Error(), "int64 overflow") diff --git a/types/vote.go b/types/vote.go index 13053f14f..e951f091b 100644 --- a/types/vote.go +++ b/types/vote.go @@ -27,7 +27,7 @@ var ( ErrVoteInvalidBlockHash = errors.New("invalid block hash") ErrVoteNonDeterministicSignature = errors.New("non-deterministic signature") ErrVoteNil = errors.New("nil vote") - ErrVoteInvalidExtension = errors.New("invalid vote extension") + ErrVoteExtensionAbsent = errors.New("vote extension absent") ) type ErrVoteConflictingVotes struct { @@ -112,6 +112,16 @@ func (vote *Vote) CommitSig() CommitSig { } } +// StripExtension removes any extension data from the vote. Useful if the +// chain has not enabled vote extensions. +// Returns true if extension data was present before stripping and false otherwise. +func (vote *Vote) StripExtension() bool { + stripped := len(vote.Extension) > 0 || len(vote.ExtensionSignature) > 0 + vote.Extension = nil + vote.ExtensionSignature = nil + return stripped +} + // ExtendedCommitSig attempts to construct an ExtendedCommitSig from this vote. // Panics if either the vote extension signature is missing or if the block ID // is not either empty or complete. @@ -120,13 +130,8 @@ func (vote *Vote) ExtendedCommitSig() ExtendedCommitSig { return NewExtendedCommitSigAbsent() } - cs := vote.CommitSig() - if vote.BlockID.IsComplete() && len(vote.ExtensionSignature) == 0 { - panic(fmt.Sprintf("Invalid vote %v - BlockID is complete but missing vote extension signature", vote)) - } - return ExtendedCommitSig{ - CommitSig: cs, + CommitSig: vote.CommitSig(), Extension: vote.Extension, ExtensionSignature: vote.ExtensionSignature, } @@ -230,11 +235,11 @@ func (vote *Vote) Verify(chainID string, pubKey crypto.PubKey) error { return err } -// VerifyWithExtension performs the same verification as Verify, but +// VerifyVoteAndExtension performs the same verification as Verify, but // additionally checks whether the vote extension signature corresponds to the // given chain ID and public key. We only verify vote extension signatures for // precommits. -func (vote *Vote) VerifyWithExtension(chainID string, pubKey crypto.PubKey) error { +func (vote *Vote) VerifyVoteAndExtension(chainID string, pubKey crypto.PubKey) error { v, err := vote.verifyAndReturnProto(chainID, pubKey) if err != nil { return err @@ -249,6 +254,20 @@ func (vote *Vote) VerifyWithExtension(chainID string, pubKey crypto.PubKey) erro return nil } +// VerifyExtension checks whether the vote extension signature corresponds to the +// given chain ID and public key. +func (vote *Vote) VerifyExtension(chainID string, pubKey crypto.PubKey) error { + if vote.Type != tmproto.PrecommitType || len(vote.BlockID.Hash) == 0 { + return nil + } + v := vote.ToProto() + extSignBytes := VoteExtensionSignBytes(chainID, v) + if !pubKey.VerifySignature(extSignBytes, vote.ExtensionSignature) { + return ErrVoteInvalidSignature + } + return nil +} + // ValidateBasic checks whether the vote is well-formed. It does not, however, // check vote extensions - for vote validation with vote extension validation, // use ValidateWithExtension. @@ -306,30 +325,34 @@ func (vote *Vote) ValidateBasic() error { } } - return nil -} - -// ValidateWithExtension performs the same validations as ValidateBasic, but -// additionally checks whether a vote extension signature is present. This -// function is used in places where vote extension signatures are expected. -func (vote *Vote) ValidateWithExtension() error { - if err := vote.ValidateBasic(); err != nil { - return err - } - - // We should always see vote extension signatures in non-nil precommits if vote.Type == tmproto.PrecommitType && len(vote.BlockID.Hash) != 0 { - if len(vote.ExtensionSignature) == 0 { - return errors.New("vote extension signature is missing") - } if len(vote.ExtensionSignature) > MaxSignatureSize { return fmt.Errorf("vote extension signature is too big (max: %d)", MaxSignatureSize) } + if len(vote.ExtensionSignature) == 0 && len(vote.Extension) != 0 { + return fmt.Errorf("vote extension signature absent on vote with extension") + } } return nil } +// EnsureExtension checks for the presence of extensions signature data +// on precommit vote types. +func (vote *Vote) EnsureExtension() error { + // We should always see vote extension signatures in non-nil precommits + if vote.Type != tmproto.PrecommitType { + return nil + } + if len(vote.BlockID.Hash) == 0 { + return nil + } + if len(vote.ExtensionSignature) > 0 { + return nil + } + return ErrVoteExtensionAbsent +} + // ToProto converts the handwritten type to proto generated type // return type, nil if everything converts safely, otherwise nil, error func (vote *Vote) ToProto() *tmproto.Vote { diff --git a/types/vote_set.go b/types/vote_set.go index efbbaa57e..9ed6df9e2 100644 --- a/types/vote_set.go +++ b/types/vote_set.go @@ -2,6 +2,7 @@ package types import ( "bytes" + "errors" "fmt" "strings" @@ -59,11 +60,12 @@ there's only a limited number of peers. NOTE: Assumes that the sum total of voting power does not exceed MaxUInt64. */ type VoteSet struct { - chainID string - height int64 - round int32 - signedMsgType tmproto.SignedMsgType - valSet *ValidatorSet + chainID string + height int64 + round int32 + signedMsgType tmproto.SignedMsgType + valSet *ValidatorSet + extensionsEnabled bool mtx tmsync.Mutex votesBitArray *bits.BitArray @@ -74,7 +76,8 @@ type VoteSet struct { peerMaj23s map[P2PID]BlockID // Maj23 for each peer } -// Constructs a new VoteSet struct used to accumulate votes for given height/round. +// NewVoteSet instantiates all fields of a new vote set. This constructor requires +// that no vote extension data be present on the votes that are added to the set. func NewVoteSet(chainID string, height int64, round int32, signedMsgType tmproto.SignedMsgType, valSet *ValidatorSet) *VoteSet { if height == 0 { @@ -95,6 +98,16 @@ func NewVoteSet(chainID string, height int64, round int32, } } +// NewExtendedVoteSet constructs a vote set with additional vote verification logic. +// The VoteSet constructed with NewExtendedVoteSet verifies the vote extension +// data for every vote added to the set. +func NewExtendedVoteSet(chainID string, height int64, round int32, + signedMsgType tmproto.SignedMsgType, valSet *ValidatorSet) *VoteSet { + vs := NewVoteSet(chainID, height, round, signedMsgType, valSet) + vs.extensionsEnabled = true + return vs +} + func (voteSet *VoteSet) ChainID() string { return voteSet.chainID } @@ -202,8 +215,17 @@ func (voteSet *VoteSet) addVote(vote *Vote) (added bool, err error) { } // Check signature. - if err := vote.VerifyWithExtension(voteSet.chainID, val.PubKey); err != nil { - return false, fmt.Errorf("failed to verify vote with ChainID %s and PubKey %s: %w", voteSet.chainID, val.PubKey, err) + if voteSet.extensionsEnabled { + if err := vote.VerifyVoteAndExtension(voteSet.chainID, val.PubKey); err != nil { + return false, fmt.Errorf("failed to verify vote with ChainID %s and PubKey %s: %w", voteSet.chainID, val.PubKey, err) + } + } else { + if err := vote.Verify(voteSet.chainID, val.PubKey); err != nil { + return false, fmt.Errorf("failed to verify vote with ChainID %s and PubKey %s: %w", voteSet.chainID, val.PubKey, err) + } + if len(vote.ExtensionSignature) > 0 || len(vote.Extension) > 0 { + return false, errors.New("unexpected vote extension data present in vote") + } } // Add vote and get conflicting vote if any. diff --git a/types/vote_set_test.go b/types/vote_set_test.go index 2246bb57a..d6862883d 100644 --- a/types/vote_set_test.go +++ b/types/vote_set_test.go @@ -475,6 +475,89 @@ func TestVoteSet_MakeCommit(t *testing.T) { } } +// TestVoteSet_VoteExtensionsEnabled tests that the vote set correctly validates +// vote extensions data when either required or not required. +func TestVoteSet_VoteExtensionsEnabled(t *testing.T) { + for _, tc := range []struct { + name string + requireExtensions bool + addExtension bool + exepectError bool + }{ + { + name: "no extension but expected", + requireExtensions: true, + addExtension: false, + exepectError: true, + }, + { + name: "invalid extensions but not expected", + requireExtensions: true, + addExtension: false, + exepectError: true, + }, + { + name: "no extension and not expected", + requireExtensions: false, + addExtension: false, + exepectError: false, + }, + { + name: "extension and expected", + requireExtensions: true, + addExtension: true, + exepectError: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + height, round := int64(1), int32(0) + valSet, privValidators := RandValidatorSet(5, 10) + var voteSet *VoteSet + if tc.requireExtensions { + voteSet = NewExtendedVoteSet("test_chain_id", height, round, tmproto.PrecommitType, valSet) + } else { + voteSet = NewVoteSet("test_chain_id", height, round, tmproto.PrecommitType, valSet) + } + + val0 := privValidators[0] + + val0p, err := val0.GetPubKey() + require.NoError(t, err) + val0Addr := val0p.Address() + blockHash := crypto.CRandBytes(32) + blockPartsTotal := uint32(123) + blockPartSetHeader := PartSetHeader{blockPartsTotal, crypto.CRandBytes(32)} + + vote := &Vote{ + ValidatorAddress: val0Addr, + ValidatorIndex: 0, + Height: height, + Round: round, + Type: tmproto.PrecommitType, + Timestamp: tmtime.Now(), + BlockID: BlockID{blockHash, blockPartSetHeader}, + } + v := vote.ToProto() + err = val0.SignVote(voteSet.ChainID(), v) + require.NoError(t, err) + vote.Signature = v.Signature + + if tc.addExtension { + vote.ExtensionSignature = v.ExtensionSignature + } + + added, err := voteSet.AddVote(vote) + if tc.exepectError { + require.Error(t, err) + require.False(t, added) + } else { + require.NoError(t, err) + require.True(t, added) + } + }) + } +} + // NOTE: privValidators are in order func randVoteSet( height int64, @@ -484,7 +567,7 @@ func randVoteSet( votingPower int64, ) (*VoteSet, *ValidatorSet, []PrivValidator) { valSet, privValidators := RandValidatorSet(numValidators, votingPower) - return NewVoteSet("test_chain_id", height, round, signedMsgType, valSet), valSet, privValidators + return NewExtendedVoteSet("test_chain_id", height, round, signedMsgType, valSet), valSet, privValidators } // Convenience: Return new vote with different validator address/index diff --git a/types/vote_test.go b/types/vote_test.go index 7de1932db..b28ffa955 100644 --- a/types/vote_test.go +++ b/types/vote_test.go @@ -240,9 +240,6 @@ func TestVoteExtension(t *testing.T) { privVal := NewMockPV() pk, err := privVal.GetPubKey() require.NoError(t, err) - blk := Block{} - ps, err := blk.MakePartSet(BlockPartSizeBytes) - require.NoError(t, err) vote := &Vote{ ValidatorAddress: pk.Address(), ValidatorIndex: 0, @@ -250,7 +247,7 @@ func TestVoteExtension(t *testing.T) { Round: round, Timestamp: tmtime.Now(), Type: tmproto.PrecommitType, - BlockID: BlockID{blk.Hash(), ps.Header()}, + BlockID: makeBlockIDRandom(), } v := vote.ToProto() @@ -260,7 +257,7 @@ func TestVoteExtension(t *testing.T) { if tc.includeSignature { vote.ExtensionSignature = v.ExtensionSignature } - err = vote.VerifyWithExtension("test_chain_id", pk) + err = vote.VerifyExtension("test_chain_id", pk) if tc.expectError { require.Error(t, err) } else { @@ -349,7 +346,7 @@ func TestValidVotes(t *testing.T) { signVote(t, privVal, "test_chain_id", tc.vote) tc.malleateVote(tc.vote) require.NoError(t, tc.vote.ValidateBasic(), "ValidateBasic for %s", tc.name) - require.NoError(t, tc.vote.ValidateWithExtension(), "ValidateWithExtension for %s", tc.name) + require.NoError(t, tc.vote.EnsureExtension(), "EnsureExtension for %s", tc.name) } } @@ -373,13 +370,13 @@ func TestInvalidVotes(t *testing.T) { signVote(t, privVal, "test_chain_id", prevote) tc.malleateVote(prevote) require.Error(t, prevote.ValidateBasic(), "ValidateBasic for %s in invalid prevote", tc.name) - require.Error(t, prevote.ValidateWithExtension(), "ValidateWithExtension for %s in invalid prevote", tc.name) + require.NoError(t, prevote.EnsureExtension(), "EnsureExtension for %s in invalid prevote", tc.name) precommit := examplePrecommit() signVote(t, privVal, "test_chain_id", precommit) tc.malleateVote(precommit) require.Error(t, precommit.ValidateBasic(), "ValidateBasic for %s in invalid precommit", tc.name) - require.Error(t, precommit.ValidateWithExtension(), "ValidateWithExtension for %s in invalid precommit", tc.name) + require.NoError(t, precommit.EnsureExtension(), "EnsureExtension for %s in invalid precommit", tc.name) } } @@ -398,7 +395,7 @@ func TestInvalidPrevotes(t *testing.T) { signVote(t, privVal, "test_chain_id", prevote) tc.malleateVote(prevote) require.Error(t, prevote.ValidateBasic(), "ValidateBasic for %s", tc.name) - require.Error(t, prevote.ValidateWithExtension(), "ValidateWithExtension for %s", tc.name) + require.NoError(t, prevote.EnsureExtension(), "EnsureExtension for %s", tc.name) } } @@ -413,18 +410,42 @@ func TestInvalidPrecommitExtensions(t *testing.T) { v.Extension = []byte("extension") v.ExtensionSignature = nil }}, - // TODO(thane): Re-enable once https://github.com/tendermint/tendermint/issues/8272 is resolved - //{"missing vote extension signature", func(v *Vote) { v.ExtensionSignature = nil }}, {"oversized vote extension signature", func(v *Vote) { v.ExtensionSignature = make([]byte, MaxSignatureSize+1) }}, } for _, tc := range testCases { precommit := examplePrecommit() signVote(t, privVal, "test_chain_id", precommit) tc.malleateVote(precommit) - // We don't expect an error from ValidateBasic, because it doesn't - // handle vote extensions. - require.NoError(t, precommit.ValidateBasic(), "ValidateBasic for %s", tc.name) - require.Error(t, precommit.ValidateWithExtension(), "ValidateWithExtension for %s", tc.name) + // ValidateBasic ensures that vote extensions, if present, are well formed + require.Error(t, precommit.ValidateBasic(), "ValidateBasic for %s", tc.name) + } +} + +func TestEnsureVoteExtension(t *testing.T) { + privVal := NewMockPV() + + testCases := []struct { + name string + malleateVote func(*Vote) + expectError bool + }{ + {"vote extension signature absent", func(v *Vote) { + v.Extension = nil + v.ExtensionSignature = nil + }, true}, + {"vote extension signature present", func(v *Vote) { + v.ExtensionSignature = []byte("extension signature") + }, false}, + } + for _, tc := range testCases { + precommit := examplePrecommit() + signVote(t, privVal, "test_chain_id", precommit) + tc.malleateVote(precommit) + if tc.expectError { + require.Error(t, precommit.EnsureExtension(), "EnsureExtension for %s", tc.name) + } else { + require.NoError(t, precommit.EnsureExtension(), "EnsureExtension for %s", tc.name) + } } }