diff --git a/abci/client/client.go b/abci/client/client.go index a38c7f81b..9725f8838 100644 --- a/abci/client/client.go +++ b/abci/client/client.go @@ -87,9 +87,15 @@ type ReqRes struct { *sync.WaitGroup *types.Response // Not set atomically, so be sure to use WaitGroup. - mtx tmsync.Mutex - done bool // Gets set to true once *after* WaitGroup.Done(). - cb func(*types.Response) // A single callback that may be set. + mtx tmsync.Mutex + + // callbackInvoked as a variable to track if the callback was already + // invoked during the regular execution of the request. This variable + // allows clients to set the callback simultaneously without potentially + // invoking the callback twice by accident, once when 'SetCallback' is + // called and once during the normal request. + callbackInvoked bool + cb func(*types.Response) // A single callback that may be set. } func NewReqRes(req *types.Request) *ReqRes { @@ -98,8 +104,8 @@ func NewReqRes(req *types.Request) *ReqRes { WaitGroup: waitGroup1(), Response: nil, - done: false, - cb: nil, + callbackInvoked: false, + cb: nil, } } @@ -109,7 +115,7 @@ func NewReqRes(req *types.Request) *ReqRes { func (r *ReqRes) SetCallback(cb func(res *types.Response)) { r.mtx.Lock() - if r.done { + if r.callbackInvoked { r.mtx.Unlock() cb(r.Response) return @@ -128,6 +134,7 @@ func (r *ReqRes) InvokeCallback() { if r.cb != nil { r.cb(r.Response) } + r.callbackInvoked = true } // GetCallback returns the configured callback of the ReqRes object which may be @@ -142,13 +149,6 @@ func (r *ReqRes) GetCallback() func(*types.Response) { return r.cb } -// SetDone marks the ReqRes object as done. -func (r *ReqRes) SetDone() { - r.mtx.Lock() - r.done = true - r.mtx.Unlock() -} - func waitGroup1() (wg *sync.WaitGroup) { wg = &sync.WaitGroup{} wg.Add(1) diff --git a/abci/client/grpc_client.go b/abci/client/grpc_client.go index 990aa5bb4..049910bea 100644 --- a/abci/client/grpc_client.go +++ b/abci/client/grpc_client.go @@ -72,7 +72,6 @@ func (cli *grpcClient) OnStart() error { cli.mtx.Lock() defer cli.mtx.Unlock() - reqres.SetDone() reqres.Done() // Notify client listener if set @@ -81,9 +80,7 @@ func (cli *grpcClient) OnStart() error { } // Notify reqRes listener if set - if cb := reqres.GetCallback(); cb != nil { - cb(reqres.Response) - } + reqres.InvokeCallback() } for reqres := range cli.chReqRes { if reqres != nil { diff --git a/abci/client/local_client.go b/abci/client/local_client.go index 701108a3c..33773e936 100644 --- a/abci/client/local_client.go +++ b/abci/client/local_client.go @@ -348,12 +348,13 @@ func (app *localClient) ApplySnapshotChunkSync( func (app *localClient) callback(req *types.Request, res *types.Response) *ReqRes { app.Callback(req, res) - return newLocalReqRes(req, res) + rr := newLocalReqRes(req, res) + rr.callbackInvoked = true + return rr } func newLocalReqRes(req *types.Request, res *types.Response) *ReqRes { reqRes := NewReqRes(req) reqRes.Response = res - reqRes.SetDone() return reqRes } diff --git a/abci/client/socket_client_test.go b/abci/client/socket_client_test.go index 53ba7b672..e01f60f92 100644 --- a/abci/client/socket_client_test.go +++ b/abci/client/socket_client_test.go @@ -3,6 +3,7 @@ package abciclient_test import ( "context" "fmt" + "sync" "testing" "time" @@ -125,3 +126,73 @@ func (slowApp) BeginBlock(req types.RequestBeginBlock) types.ResponseBeginBlock time.Sleep(200 * time.Millisecond) return types.ResponseBeginBlock{} } + +// TestCallbackInvokedWhenSetLaet ensures that the callback is invoked when +// set after the client completes the call into the app. Currently this +// test relies on the callback being allowed to be invoked twice if set multiple +// times, once when set early and once when set late. +func TestCallbackInvokedWhenSetLate(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + app := blockedABCIApplication{ + wg: wg, + } + _, c := setupClientServer(t, app) + reqRes, err := c.CheckTxAsync(context.Background(), types.RequestCheckTx{}) + require.NoError(t, err) + + done := make(chan struct{}) + cb := func(_ *types.Response) { + close(done) + } + reqRes.SetCallback(cb) + app.wg.Done() + <-done + + var called bool + cb = func(_ *types.Response) { + called = true + } + reqRes.SetCallback(cb) + require.True(t, called) +} + +type blockedABCIApplication struct { + wg *sync.WaitGroup + types.BaseApplication +} + +func (b blockedABCIApplication) CheckTx(r types.RequestCheckTx) types.ResponseCheckTx { + b.wg.Wait() + return b.BaseApplication.CheckTx(r) +} + +// TestCallbackInvokedWhenSetEarly ensures that the callback is invoked when +// set before the client completes the call into the app. +func TestCallbackInvokedWhenSetEarly(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + app := blockedABCIApplication{ + wg: wg, + } + _, c := setupClientServer(t, app) + reqRes, err := c.CheckTxAsync(context.Background(), types.RequestCheckTx{}) + require.NoError(t, err) + + done := make(chan struct{}) + cb := func(_ *types.Response) { + close(done) + } + reqRes.SetCallback(cb) + app.wg.Done() + + called := func() bool { + select { + case <-done: + return true + default: + return false + } + } + require.Eventually(t, called, time.Second, time.Millisecond*25) +} diff --git a/internal/consensus/common_test.go b/internal/consensus/common_test.go index c3f93d8c1..0477c8b0c 100644 --- a/internal/consensus/common_test.go +++ b/internal/consensus/common_test.go @@ -663,6 +663,39 @@ func ensurePrevote(voteCh <-chan tmpubsub.Message, height int64, round int32) { ensureVote(voteCh, height, round, tmproto.PrevoteType) } +func ensurePrevoteMatch(t *testing.T, voteCh <-chan tmpubsub.Message, height int64, round int32, hash []byte) { + t.Helper() + ensureVoteMatch(t, voteCh, height, round, hash, tmproto.PrevoteType) +} + +func ensurePrecommitMatch(t *testing.T, voteCh <-chan tmpubsub.Message, height int64, round int32, hash []byte) { + t.Helper() + ensureVoteMatch(t, voteCh, height, round, hash, tmproto.PrecommitType) +} + +func ensureVoteMatch(t *testing.T, voteCh <-chan tmpubsub.Message, height int64, round int32, hash []byte, voteType tmproto.SignedMsgType) { + t.Helper() + select { + case <-time.After(ensureTimeout): + t.Fatal("Timeout expired while waiting for NewVote event") + case msg := <-voteCh: + voteEvent, ok := msg.Data().(types.EventDataVote) + require.True(t, ok, "expected a EventDataVote, got %T. Wrong subscription channel?", + msg.Data()) + + vote := voteEvent.Vote + require.Equal(t, height, vote.Height) + require.Equal(t, round, vote.Round) + + require.Equal(t, voteType, vote.Type) + if hash == nil { + require.Nil(t, vote.BlockID.Hash, "Expected prevote to be for nil, got %X", vote.BlockID.Hash) + } else { + require.True(t, bytes.Equal(vote.BlockID.Hash, hash), "Expected prevote to be for %X, got %X", hash, vote.BlockID.Hash) + } + } +} + func ensureVote(voteCh <-chan tmpubsub.Message, height int64, round int32, voteType tmproto.SignedMsgType) { select { diff --git a/internal/consensus/state_test.go b/internal/consensus/state_test.go index 65fc6318e..0e447855b 100644 --- a/internal/consensus/state_test.go +++ b/internal/consensus/state_test.go @@ -243,8 +243,7 @@ func TestStateBadProposal(t *testing.T) { ensureProposal(proposalCh, height, round, blockID) // wait for prevote - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], nil) + ensurePrevoteMatch(t, voteCh, height, round, nil) // add bad prevote from vs2 and wait for it signAddVotes(config, cs1, tmproto.PrevoteType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) @@ -308,8 +307,7 @@ func TestStateOversizedBlock(t *testing.T) { // and then should send nil prevote and precommit regardless of whether other validators prevote and // precommit on it - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], nil) + ensurePrevoteMatch(t, voteCh, height, round, nil) signAddVotes(config, cs1, tmproto.PrevoteType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) ensurePrevote(voteCh, height, round) ensurePrecommit(voteCh, height, round) @@ -352,8 +350,7 @@ func TestStateFullRound1(t *testing.T) { ensureNewProposal(propCh, height, round) propBlockHash := cs.GetRoundState().ProposalBlock.Hash() - ensurePrevote(voteCh, height, round) // wait for prevote - validatePrevote(t, cs, round, vss[0], propBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, propBlockHash) ensurePrecommit(voteCh, height, round) // wait for precommit @@ -376,8 +373,8 @@ func TestStateFullRoundNil(t *testing.T) { cs.enterPrevote(height, round) cs.startRoutines(4) - ensurePrevote(voteCh, height, round) // prevote - ensurePrecommit(voteCh, height, round) // precommit + ensurePrevoteMatch(t, voteCh, height, round, nil) // prevote + ensurePrecommitMatch(t, voteCh, height, round, nil) // precommit // should prevote and precommit nil validatePrevoteAndPrecommit(t, cs, round, -1, vss[0], nil, nil) @@ -502,10 +499,8 @@ func TestStateLockNoPOL(t *testing.T) { panic("Expected proposal block to be nil") } - // wait to finish prevote - ensurePrevote(voteCh, height, round) - // we should have prevoted our locked block - validatePrevote(t, cs1, round, vss[0], rs.LockedBlock.Hash()) + // wait to finish prevote and ensure we have prevoted our locked block + ensurePrevoteMatch(t, voteCh, height, round, rs.LockedBlock.Hash()) // add a conflicting prevote from the other validator signAddVotes(config, cs1, tmproto.PrevoteType, hash, rs.LockedBlock.MakePartSet(partSize).Header(), vs2) @@ -548,8 +543,7 @@ func TestStateLockNoPOL(t *testing.T) { rs.LockedBlock)) } - ensurePrevote(voteCh, height, round) // prevote - validatePrevote(t, cs1, round, vss[0], rs.LockedBlock.Hash()) + ensurePrevoteMatch(t, voteCh, height, round, rs.LockedBlock.Hash()) signAddVotes(config, cs1, tmproto.PrevoteType, hash, rs.ProposalBlock.MakePartSet(partSize).Header(), vs2) ensurePrevote(voteCh, height, round) @@ -594,9 +588,8 @@ func TestStateLockNoPOL(t *testing.T) { } ensureNewProposal(proposalCh, height, round) - ensurePrevote(voteCh, height, round) // prevote // prevote for locked block (not proposal) - validatePrevote(t, cs1, 3, vss[0], cs1.LockedBlock.Hash()) + ensurePrevoteMatch(t, voteCh, height, round, cs1.LockedBlock.Hash()) // prevote for proposed block signAddVotes(config, cs1, tmproto.PrevoteType, propBlock.Hash(), propBlock.MakePartSet(partSize).Header(), vs2) @@ -704,8 +697,7 @@ func TestStateLockPOLRelock(t *testing.T) { ensureNewProposal(proposalCh, height, round) // go to prevote, node should prevote for locked block (not the new proposal) - this is relocking - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], theBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, theBlockHash) // now lets add prevotes from everyone else for the new block signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) @@ -757,8 +749,7 @@ func TestStateLockPOLUnlock(t *testing.T) { theBlockHash := rs.ProposalBlock.Hash() theBlockParts := rs.ProposalBlockParts.Header() - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], theBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, theBlockHash) signAddVotes(config, cs1, tmproto.PrevoteType, theBlockHash, theBlockParts, vs2, vs3, vs4) @@ -796,8 +787,7 @@ func TestStateLockPOLUnlock(t *testing.T) { ensureNewProposal(proposalCh, height, round) // go to prevote, prevote for locked block (not proposal) - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], lockedBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, lockedBlockHash) // now lets add prevotes from everyone else for nil (a polka!) signAddVotes(config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs2, vs3, vs4) @@ -888,8 +878,7 @@ func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) { // now we're on a new round but v1 misses the proposal // go to prevote, node should prevote for locked block (not the new proposal) - this is relocking - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], firstBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, firstBlockHash) // now lets add prevotes from everyone else for the new block signAddVotes(config, cs1, tmproto.PrevoteType, secondBlockHash, secondBlockParts.Header(), vs2, vs3, vs4) @@ -933,9 +922,7 @@ func TestStateLockPOLUnlockOnUnknownBlock(t *testing.T) { t.Fatal(err) } - ensurePrevote(voteCh, height, round) - // we are no longer locked to the first block so we should be able to prevote - validatePrevote(t, cs1, round, vss[0], thirdPropBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, thirdPropBlockHash) signAddVotes(config, cs1, tmproto.PrevoteType, thirdPropBlockHash, thirdPropBlockParts.Header(), vs2, vs3, vs4) @@ -975,8 +962,7 @@ func TestStateLockPOLSafety1(t *testing.T) { rs := cs1.GetRoundState() propBlock := rs.ProposalBlock - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlock.Hash()) + ensurePrevoteMatch(t, voteCh, height, round, propBlock.Hash()) // the others sign a polka but we don't see it prevotes := signVotes(config, tmproto.PrevoteType, @@ -1022,8 +1008,7 @@ func TestStateLockPOLSafety1(t *testing.T) { t.Logf("new prop hash %v", fmt.Sprintf("%X", propBlockHash)) // go to prevote, prevote for proposal block - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, propBlockHash) // now we see the others prevote for it, so we should lock on it signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs2, vs3, vs4) @@ -1049,10 +1034,8 @@ func TestStateLockPOLSafety1(t *testing.T) { // timeout of propose ensureNewTimeout(timeoutProposeCh, height, round, cs1.config.Propose(round).Nanoseconds()) - // finish prevote - ensurePrevote(voteCh, height, round) - // we should prevote what we're locked on - validatePrevote(t, cs1, round, vss[0], propBlockHash) + // finish prevote and vote for the block we're locked on + ensurePrevoteMatch(t, voteCh, height, round, propBlockHash) newStepCh := subscribe(cs1.eventBus, types.EventQueryNewRoundStep) @@ -1119,8 +1102,7 @@ func TestStateLockPOLSafety2(t *testing.T) { } ensureNewProposal(proposalCh, height, round) - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash1) + ensurePrevoteMatch(t, voteCh, height, round, propBlockHash1) signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash1, propBlockParts1.Header(), vs2, vs3, vs4) @@ -1162,9 +1144,7 @@ func TestStateLockPOLSafety2(t *testing.T) { ensureNewProposal(proposalCh, height, round) ensureNoNewUnlock(unlockCh) - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash1) - + ensurePrevoteMatch(t, voteCh, height, round, propBlockHash1) } // 4 vals. @@ -1201,8 +1181,7 @@ func TestProposeValidBlock(t *testing.T) { propBlock := rs.ProposalBlock propBlockHash := propBlock.Hash() - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, propBlockHash) // the others sign a polka signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash, propBlock.MakePartSet(partSize).Header(), vs2, vs3, vs4) @@ -1225,8 +1204,7 @@ func TestProposeValidBlock(t *testing.T) { // timeout of propose ensureNewTimeout(timeoutProposeCh, height, round, cs1.config.Propose(round).Nanoseconds()) - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, propBlockHash) signAddVotes(config, cs1, tmproto.PrevoteType, nil, types.PartSetHeader{}, vs2, vs3, vs4) @@ -1294,8 +1272,7 @@ func TestSetValidBlockOnDelayedPrevote(t *testing.T) { propBlockHash := propBlock.Hash() propBlockParts := propBlock.MakePartSet(partSize) - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], propBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, propBlockHash) // vs2 send prevote for propBlock signAddVotes(config, cs1, tmproto.PrevoteType, propBlockHash, propBlockParts.Header(), vs2) @@ -1358,8 +1335,7 @@ func TestSetValidBlockOnDelayedProposal(t *testing.T) { ensureNewTimeout(timeoutProposeCh, height, round, cs1.config.Propose(round).Nanoseconds()) - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], nil) + ensurePrevoteMatch(t, voteCh, height, round, nil) prop, propBlock := decideProposal(cs1, vs2, vs2.Height, vs2.Round+1) propBlockHash := propBlock.Hash() @@ -1445,8 +1421,7 @@ func TestWaitingTimeoutProposeOnNewRound(t *testing.T) { ensureNewTimeout(timeoutWaitCh, height, round, cs1.config.Propose(round).Nanoseconds()) - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], nil) + ensurePrevoteMatch(t, voteCh, height, round, nil) } // 4 vals, 3 Precommits for nil from the higher round. @@ -1515,8 +1490,7 @@ func TestWaitTimeoutProposeOnNilPolkaForTheCurrentRound(t *testing.T) { ensureNewTimeout(timeoutProposeCh, height, round, cs1.config.Propose(round).Nanoseconds()) - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], nil) + ensurePrevoteMatch(t, voteCh, height, round, nil) } // What we want: @@ -1645,8 +1619,7 @@ func TestStartNextHeightCorrectlyAfterTimeout(t *testing.T) { theBlockHash := rs.ProposalBlock.Hash() theBlockParts := rs.ProposalBlockParts.Header() - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], theBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, theBlockHash) signAddVotes(config, cs1, tmproto.PrevoteType, theBlockHash, theBlockParts, vs2, vs3, vs4) @@ -1708,8 +1681,7 @@ func TestResetTimeoutPrecommitUponNewHeight(t *testing.T) { theBlockHash := rs.ProposalBlock.Hash() theBlockParts := rs.ProposalBlockParts.Header() - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], theBlockHash) + ensurePrevoteMatch(t, voteCh, height, round, theBlockHash) signAddVotes(config, cs1, tmproto.PrevoteType, theBlockHash, theBlockParts, vs2, vs3, vs4) @@ -1881,8 +1853,7 @@ func TestStateHalt1(t *testing.T) { */ // go to prevote, prevote for locked block - ensurePrevote(voteCh, height, round) - validatePrevote(t, cs1, round, vss[0], rs.LockedBlock.Hash()) + ensurePrevoteMatch(t, voteCh, height, round, rs.LockedBlock.Hash()) // now we receive the precommit from the previous round addVotes(cs1, precommit4) diff --git a/internal/mempool/v0/clist_mempool_test.go b/internal/mempool/v0/clist_mempool_test.go index 8e8253fe4..61ec543ef 100644 --- a/internal/mempool/v0/clist_mempool_test.go +++ b/internal/mempool/v0/clist_mempool_test.go @@ -248,13 +248,13 @@ func TestMempoolUpdateDoesNotPanicWhenApplicationMissedTx(t *testing.T) { for _, tx := range txs { reqRes := abciclient.NewReqRes(abci.ToRequestCheckTx(abci.RequestCheckTx{Tx: tx})) reqRes.Response = abci.ToResponseCheckTx(abci.ResponseCheckTx{Code: abci.CodeTypeOK}) - // SetDone allows the ReqRes to process its callback synchronously. - // This simulates the Response being ready for the client immediately. - reqRes.SetDone() mockClient.On("CheckTxAsync", mock.Anything, mock.Anything).Return(reqRes, nil) err := mp.CheckTx(context.Background(), tx, nil, mempool.TxInfo{}) require.NoError(t, err) + + // ensure that the callback that the mempool sets on the ReqRes is run. + reqRes.InvokeCallback() } // Calling update to remove the first transaction from the mempool.