diff --git a/light/client_benchmark_test.go b/light/client_benchmark_test.go index 72930928d..04ea6d1fc 100644 --- a/light/client_benchmark_test.go +++ b/light/client_benchmark_test.go @@ -10,8 +10,8 @@ import ( "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/light" "github.com/tendermint/tendermint/light/provider" - mockp "github.com/tendermint/tendermint/light/provider/mock" dbs "github.com/tendermint/tendermint/light/store/db" + "github.com/tendermint/tendermint/types" ) // NOTE: block is produced every minute. Make sure the verification time @@ -21,12 +21,50 @@ import ( // or -benchtime 100x. // // Remember that none of these benchmarks account for network latency. -var ( - benchmarkFullNode = mockp.New(genMockNode(chainID, 1000, 100, 1, bTime)) - genesisBlock, _ = benchmarkFullNode.LightBlock(context.Background(), 1) -) +var () + +type providerBenchmarkImpl struct { + currentHeight int64 + blocks map[int64]*types.LightBlock +} + +func newProviderBenchmarkImpl(headers map[int64]*types.SignedHeader, + vals map[int64]*types.ValidatorSet) provider.Provider { + impl := providerBenchmarkImpl{ + blocks: make(map[int64]*types.LightBlock, len(headers)), + } + for height, header := range headers { + if height > impl.currentHeight { + impl.currentHeight = height + } + impl.blocks[height] = &types.LightBlock{ + SignedHeader: header, + ValidatorSet: vals[height], + } + } + return &impl +} + +func (impl *providerBenchmarkImpl) LightBlock(ctx context.Context, height int64) (*types.LightBlock, error) { + if height == 0 { + return impl.blocks[impl.currentHeight], nil + } + lb, ok := impl.blocks[height] + if !ok { + return nil, provider.ErrLightBlockNotFound + } + return lb, nil +} + +func (impl *providerBenchmarkImpl) ReportEvidence(_ context.Context, _ types.Evidence) error { + panic("not implemented") +} func BenchmarkSequence(b *testing.B) { + headers, vals, _ := genLightBlocksWithKeys(chainID, 1000, 100, 1, bTime) + benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) + genesisBlock, _ := benchmarkFullNode.LightBlock(context.Background(), 1) + c, err := light.NewClient( context.Background(), chainID, @@ -55,6 +93,10 @@ func BenchmarkSequence(b *testing.B) { } func BenchmarkBisection(b *testing.B) { + headers, vals, _ := genLightBlocksWithKeys(chainID, 1000, 100, 1, bTime) + benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) + genesisBlock, _ := benchmarkFullNode.LightBlock(context.Background(), 1) + c, err := light.NewClient( context.Background(), chainID, @@ -82,7 +124,10 @@ func BenchmarkBisection(b *testing.B) { } func BenchmarkBackwards(b *testing.B) { + headers, vals, _ := genLightBlocksWithKeys(chainID, 1000, 100, 1, bTime) + benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) trustedBlock, _ := benchmarkFullNode.LightBlock(context.Background(), 0) + c, err := light.NewClient( context.Background(), chainID, diff --git a/light/client_test.go b/light/client_test.go index b826eaead..de18af111 100644 --- a/light/client_test.go +++ b/light/client_test.go @@ -3,11 +3,13 @@ package light_test import ( "context" "errors" + "fmt" "sync" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" dbm "github.com/tendermint/tm-db" @@ -16,7 +18,7 @@ import ( "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/light" "github.com/tendermint/tendermint/light/provider" - mockp "github.com/tendermint/tendermint/light/provider/mock" + provider_mocks "github.com/tendermint/tendermint/light/provider/mocks" dbs "github.com/tendermint/tendermint/light/store/db" "github.com/tendermint/tendermint/types" ) @@ -57,14 +59,9 @@ var ( // last header (3/3 signed) 3: h3, } - l1 = &types.LightBlock{SignedHeader: h1, ValidatorSet: vals} - fullNode = mockp.New( - chainID, - headerSet, - valSet, - ) - deadNode = mockp.NewDeadMock(chainID) - largeFullNode = mockp.New(genMockNode(chainID, 10, 3, 0, bTime)) + l1 = &types.LightBlock{SignedHeader: h1, ValidatorSet: vals} + l2 = &types.LightBlock{SignedHeader: h2, ValidatorSet: vals} + l3 = &types.LightBlock{SignedHeader: h3, ValidatorSet: vals} ) func TestValidateTrustOptions(t *testing.T) { @@ -113,11 +110,6 @@ func TestValidateTrustOptions(t *testing.T) { } -func TestMock(t *testing.T) { - l, _ := fullNode.LightBlock(ctx, 3) - assert.Equal(t, int64(3), l.Height) -} - func TestClient_SequentialVerification(t *testing.T) { newKeys := genPrivKeys(4) newVals := newKeys.ToValidators(10, 1) @@ -216,28 +208,22 @@ func TestClient_SequentialVerification(t *testing.T) { } for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { + testCase := tc + t.Run(testCase.name, func(t *testing.T) { + mockNode := mockNodeFromHeadersAndVals(testCase.otherHeaders, testCase.vals) + mockNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) c, err := light.NewClient( ctx, chainID, trustOptions, - mockp.New( - chainID, - tc.otherHeaders, - tc.vals, - ), - []provider.Provider{mockp.New( - chainID, - tc.otherHeaders, - tc.vals, - )}, + mockNode, + []provider.Provider{mockNode}, dbs.New(dbm.NewMemDB()), light.SequentialVerification(), light.Logger(log.TestingLogger()), ) - if tc.initErr { + if testCase.initErr { require.Error(t, err) return } @@ -245,11 +231,12 @@ func TestClient_SequentialVerification(t *testing.T) { require.NoError(t, err) _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(3*time.Hour)) - if tc.verifyErr { + if testCase.verifyErr { assert.Error(t, err) } else { assert.NoError(t, err) } + mockNode.AssertExpectations(t) }) } } @@ -343,20 +330,14 @@ func TestClient_SkippingVerification(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { + mockNode := mockNodeFromHeadersAndVals(tc.otherHeaders, tc.vals) + mockNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) c, err := light.NewClient( ctx, chainID, trustOptions, - mockp.New( - chainID, - tc.otherHeaders, - tc.vals, - ), - []provider.Provider{mockp.New( - chainID, - tc.otherHeaders, - tc.vals, - )}, + mockNode, + []provider.Provider{mockNode}, dbs.New(dbm.NewMemDB()), light.SkippingVerification(light.DefaultTrustLevel), light.Logger(log.TestingLogger()), @@ -382,8 +363,23 @@ func TestClient_SkippingVerification(t *testing.T) { // start from a large light block to make sure that the pivot height doesn't select a height outside // the appropriate range func TestClientLargeBisectionVerification(t *testing.T) { - veryLargeFullNode := mockp.New(genMockNode(chainID, 100, 3, 0, bTime)) - trustedLightBlock, err := veryLargeFullNode.LightBlock(ctx, 5) + numBlocks := int64(300) + mockHeaders, mockVals, _ := genLightBlocksWithKeys(chainID, numBlocks, 101, 2, bTime) + + lastBlock := &types.LightBlock{SignedHeader: mockHeaders[numBlocks], ValidatorSet: mockVals[numBlocks]} + mockNode := &provider_mocks.Provider{} + mockNode.On("LightBlock", mock.Anything, numBlocks). + Return(lastBlock, nil) + + mockNode.On("LightBlock", mock.Anything, int64(200)). + Return(&types.LightBlock{SignedHeader: mockHeaders[200], ValidatorSet: mockVals[200]}, nil) + + mockNode.On("LightBlock", mock.Anything, int64(256)). + Return(&types.LightBlock{SignedHeader: mockHeaders[256], ValidatorSet: mockVals[256]}, nil) + + mockNode.On("LightBlock", mock.Anything, int64(0)).Return(lastBlock, nil) + + trustedLightBlock, err := mockNode.LightBlock(ctx, int64(200)) require.NoError(t, err) c, err := light.NewClient( ctx, @@ -393,20 +389,25 @@ func TestClientLargeBisectionVerification(t *testing.T) { Height: trustedLightBlock.Height, Hash: trustedLightBlock.Hash(), }, - veryLargeFullNode, - []provider.Provider{veryLargeFullNode}, + mockNode, + []provider.Provider{mockNode}, dbs.New(dbm.NewMemDB()), light.SkippingVerification(light.DefaultTrustLevel), ) require.NoError(t, err) - h, err := c.Update(ctx, bTime.Add(100*time.Minute)) + h, err := c.Update(ctx, bTime.Add(300*time.Second)) assert.NoError(t, err) - h2, err := veryLargeFullNode.LightBlock(ctx, 100) + height, err := c.LastTrustedHeight() + require.NoError(t, err) + require.Equal(t, numBlocks, height) + h2, err := mockNode.LightBlock(ctx, numBlocks) require.NoError(t, err) assert.Equal(t, h, h2) + mockNode.AssertExpectations(t) } func TestClientBisectionBetweenTrustedHeaders(t *testing.T) { + mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) c, err := light.NewClient( ctx, chainID, @@ -415,8 +416,8 @@ func TestClientBisectionBetweenTrustedHeaders(t *testing.T) { Height: 1, Hash: h1.Hash(), }, - fullNode, - []provider.Provider{fullNode}, + mockFullNode, + []provider.Provider{mockFullNode}, dbs.New(dbm.NewMemDB()), light.SkippingVerification(light.DefaultTrustLevel), ) @@ -432,15 +433,18 @@ func TestClientBisectionBetweenTrustedHeaders(t *testing.T) { // verify using bisection the light block between the two trusted light blocks _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour)) assert.NoError(t, err) + mockFullNode.AssertExpectations(t) } func TestClient_Cleanup(t *testing.T) { + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) c, err := light.NewClient( ctx, chainID, trustOptions, - fullNode, - []provider.Provider{fullNode}, + mockFullNode, + []provider.Provider{mockFullNode}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) @@ -455,12 +459,14 @@ func TestClient_Cleanup(t *testing.T) { l, err := c.TrustedLightBlock(1) assert.Error(t, err) assert.Nil(t, l) + mockFullNode.AssertExpectations(t) } // trustedHeader.Height == options.Height func TestClientRestoresTrustedHeaderAfterStartup(t *testing.T) { // 1. options.Hash == trustedHeader.Hash - { + t.Run("hashes should match", func(t *testing.T) { + mockNode := &provider_mocks.Provider{} trustedStore := dbs.New(dbm.NewMemDB()) err := trustedStore.SaveLightBlock(l1) require.NoError(t, err) @@ -469,8 +475,8 @@ func TestClientRestoresTrustedHeaderAfterStartup(t *testing.T) { ctx, chainID, trustOptions, - fullNode, - []provider.Provider{fullNode}, + mockNode, + []provider.Provider{mockNode}, trustedStore, light.Logger(log.TestingLogger()), ) @@ -481,10 +487,11 @@ func TestClientRestoresTrustedHeaderAfterStartup(t *testing.T) { assert.NotNil(t, l) assert.Equal(t, l.Hash(), h1.Hash()) assert.Equal(t, l.ValidatorSet.Hash(), h1.ValidatorsHash.Bytes()) - } + mockNode.AssertExpectations(t) + }) // 2. options.Hash != trustedHeader.Hash - { + t.Run("hashes should not match", func(t *testing.T) { trustedStore := dbs.New(dbm.NewMemDB()) err := trustedStore.SaveLightBlock(l1) require.NoError(t, err) @@ -492,15 +499,7 @@ func TestClientRestoresTrustedHeaderAfterStartup(t *testing.T) { // header1 != h1 header1 := keys.GenSignedHeader(chainID, 1, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) - - primary := mockp.New( - chainID, - map[int64]*types.SignedHeader{ - // trusted header - 1: header1, - }, - valSet, - ) + mockNode := &provider_mocks.Provider{} c, err := light.NewClient( ctx, @@ -510,8 +509,8 @@ func TestClientRestoresTrustedHeaderAfterStartup(t *testing.T) { Height: 1, Hash: header1.Hash(), }, - primary, - []provider.Provider{primary}, + mockNode, + []provider.Provider{mockNode}, trustedStore, light.Logger(log.TestingLogger()), ) @@ -524,16 +523,21 @@ func TestClientRestoresTrustedHeaderAfterStartup(t *testing.T) { assert.Equal(t, l.Hash(), l1.Hash()) assert.NoError(t, l.ValidateBasic(chainID)) } - } + mockNode.AssertExpectations(t) + }) } func TestClient_Update(t *testing.T) { + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, int64(0)).Return(l3, nil) + mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) + mockFullNode.On("LightBlock", mock.Anything, int64(3)).Return(l3, nil) c, err := light.NewClient( ctx, chainID, trustOptions, - fullNode, - []provider.Provider{fullNode}, + mockFullNode, + []provider.Provider{mockFullNode}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) @@ -546,15 +550,19 @@ func TestClient_Update(t *testing.T) { assert.EqualValues(t, 3, l.Height) assert.NoError(t, l.ValidateBasic(chainID)) } + mockFullNode.AssertExpectations(t) } func TestClient_Concurrency(t *testing.T) { + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, int64(2)).Return(l2, nil) + mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) c, err := light.NewClient( ctx, chainID, trustOptions, - fullNode, - []provider.Provider{fullNode}, + mockFullNode, + []provider.Provider{mockFullNode}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) @@ -587,15 +595,21 @@ func TestClient_Concurrency(t *testing.T) { } wg.Wait() + mockFullNode.AssertExpectations(t) } func TestClientReplacesPrimaryWithWitnessIfPrimaryIsUnavailable(t *testing.T) { + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, mock.Anything).Return(l1, nil) + + mockDeadNode := &provider_mocks.Provider{} + mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrNoResponse) c, err := light.NewClient( ctx, chainID, trustOptions, - deadNode, - []provider.Provider{fullNode, fullNode}, + mockDeadNode, + []provider.Provider{mockFullNode, mockFullNode}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) @@ -605,16 +619,25 @@ func TestClientReplacesPrimaryWithWitnessIfPrimaryIsUnavailable(t *testing.T) { require.NoError(t, err) // the primary should no longer be the deadNode - assert.NotEqual(t, c.Primary(), deadNode) + assert.NotEqual(t, c.Primary(), mockDeadNode) // we should still have the dead node as a witness because it // hasn't repeatedly been unresponsive yet assert.Equal(t, 2, len(c.Witnesses())) + mockDeadNode.AssertExpectations(t) + mockFullNode.AssertExpectations(t) } func TestClient_BackwardsVerification(t *testing.T) { { - trustHeader, _ := largeFullNode.LightBlock(ctx, 6) + headers, vals, _ := genLightBlocksWithKeys(chainID, 9, 3, 0, bTime) + delete(headers, 1) + delete(headers, 2) + delete(vals, 1) + delete(vals, 2) + mockLargeFullNode := mockNodeFromHeadersAndVals(headers, vals) + trustHeader, _ := mockLargeFullNode.LightBlock(ctx, 6) + c, err := light.NewClient( ctx, chainID, @@ -623,8 +646,8 @@ func TestClient_BackwardsVerification(t *testing.T) { Height: trustHeader.Height, Hash: trustHeader.Hash(), }, - largeFullNode, - []provider.Provider{largeFullNode}, + mockLargeFullNode, + []provider.Provider{mockLargeFullNode}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) @@ -662,41 +685,36 @@ func TestClient_BackwardsVerification(t *testing.T) { // so expect error _, err = c.VerifyLightBlockAtHeight(ctx, 8, bTime.Add(12*time.Minute)) assert.Error(t, err) + mockLargeFullNode.AssertExpectations(t) } { testCases := []struct { - provider provider.Provider + headers map[int64]*types.SignedHeader + vals map[int64]*types.ValidatorSet }{ { // 7) provides incorrect height - mockp.New( - chainID, - map[int64]*types.SignedHeader{ - 1: h1, - 2: keys.GenSignedHeader(chainID, 1, bTime.Add(30*time.Minute), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), - 3: h3, - }, - valSet, - ), + headers: map[int64]*types.SignedHeader{ + 2: keys.GenSignedHeader(chainID, 1, bTime.Add(30*time.Minute), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), + 3: h3, + }, + vals: valSet, }, { // 8) provides incorrect hash - mockp.New( - chainID, - map[int64]*types.SignedHeader{ - 1: h1, - 2: keys.GenSignedHeader(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, - hash("app_hash2"), hash("cons_hash23"), hash("results_hash30"), 0, len(keys)), - 3: h3, - }, - valSet, - ), + headers: map[int64]*types.SignedHeader{ + 2: keys.GenSignedHeader(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, + hash("app_hash2"), hash("cons_hash23"), hash("results_hash30"), 0, len(keys)), + 3: h3, + }, + vals: valSet, }, } for idx, tc := range testCases { + mockNode := mockNodeFromHeadersAndVals(tc.headers, tc.vals) c, err := light.NewClient( ctx, chainID, @@ -705,8 +723,8 @@ func TestClient_BackwardsVerification(t *testing.T) { Height: 3, Hash: h3.Hash(), }, - tc.provider, - []provider.Provider{tc.provider}, + mockNode, + []provider.Provider{mockNode}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) @@ -714,6 +732,7 @@ func TestClient_BackwardsVerification(t *testing.T) { _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour).Add(1*time.Second)) assert.Error(t, err, idx) + mockNode.AssertExpectations(t) } } } @@ -723,60 +742,62 @@ func TestClient_NewClientFromTrustedStore(t *testing.T) { db := dbs.New(dbm.NewMemDB()) err := db.SaveLightBlock(l1) require.NoError(t, err) + mockNode := &provider_mocks.Provider{} c, err := light.NewClientFromTrustedStore( chainID, trustPeriod, - deadNode, - []provider.Provider{deadNode}, + mockNode, + []provider.Provider{mockNode}, db, ) require.NoError(t, err) - // 2) Check light block exists (deadNode is being used to ensure we're not getting - // it from primary) + // 2) Check light block exists h, err := c.TrustedLightBlock(1) assert.NoError(t, err) assert.EqualValues(t, l1.Height, h.Height) + mockNode.AssertExpectations(t) } func TestClientRemovesWitnessIfItSendsUsIncorrectHeader(t *testing.T) { // different headers hash then primary plus less than 1/3 signed (no fork) - badProvider1 := mockp.New( - chainID, - map[int64]*types.SignedHeader{ - 1: h1, - 2: keys.GenSignedHeaderLastBlockID(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, - hash("app_hash2"), hash("cons_hash"), hash("results_hash"), - len(keys), len(keys), types.BlockID{Hash: h1.Hash()}), - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - }, - ) - // header is empty - badProvider2 := mockp.New( - chainID, - map[int64]*types.SignedHeader{ - 1: h1, - 2: h2, - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - }, - ) + headers1 := map[int64]*types.SignedHeader{ + 1: h1, + 2: keys.GenSignedHeaderLastBlockID(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, + hash("app_hash2"), hash("cons_hash"), hash("results_hash"), + len(keys), len(keys), types.BlockID{Hash: h1.Hash()}), + } + vals1 := map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + } + mockBadNode1 := mockNodeFromHeadersAndVals(headers1, vals1) + mockBadNode1.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) - lb1, _ := badProvider1.LightBlock(ctx, 2) + // header is empty + headers2 := map[int64]*types.SignedHeader{ + 1: h1, + 2: h2, + } + vals2 := map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + } + mockBadNode2 := mockNodeFromHeadersAndVals(headers2, vals2) + mockBadNode2.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + + mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) + + lb1, _ := mockBadNode1.LightBlock(ctx, 2) require.NotEqual(t, lb1.Hash(), l1.Hash()) c, err := light.NewClient( ctx, chainID, trustOptions, - fullNode, - []provider.Provider{badProvider1, badProvider2}, + mockFullNode, + []provider.Provider{mockBadNode1, mockBadNode2}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) @@ -798,12 +819,13 @@ func TestClientRemovesWitnessIfItSendsUsIncorrectHeader(t *testing.T) { } // witness does not have a light block -> left in the list assert.EqualValues(t, 1, len(c.Witnesses())) + mockBadNode1.AssertExpectations(t) + mockBadNode2.AssertExpectations(t) } func TestClient_TrustedValidatorSet(t *testing.T) { differentVals, _ := factory.RandValidatorSet(10, 100) - badValSetNode := mockp.New( - chainID, + mockBadValSetNode := mockNodeFromHeadersAndVals( map[int64]*types.SignedHeader{ 1: h1, // 3/3 signed, but validator set at height 2 below is invalid -> witness @@ -811,21 +833,27 @@ func TestClient_TrustedValidatorSet(t *testing.T) { 2: keys.GenSignedHeaderLastBlockID(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, hash("app_hash2"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: h1.Hash()}), - 3: h3, }, map[int64]*types.ValidatorSet{ 1: vals, 2: differentVals, - 3: differentVals, + }) + mockFullNode := mockNodeFromHeadersAndVals( + map[int64]*types.SignedHeader{ + 1: h1, + 2: h2, }, - ) + map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + }) c, err := light.NewClient( ctx, chainID, trustOptions, - fullNode, - []provider.Provider{badValSetNode, fullNode}, + mockFullNode, + []provider.Provider{mockBadValSetNode, mockFullNode}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) @@ -835,15 +863,29 @@ func TestClient_TrustedValidatorSet(t *testing.T) { _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour).Add(1*time.Second)) assert.NoError(t, err) assert.Equal(t, 1, len(c.Witnesses())) + mockBadValSetNode.AssertExpectations(t) + mockFullNode.AssertExpectations(t) } func TestClientPrunesHeadersAndValidatorSets(t *testing.T) { + mockFullNode := mockNodeFromHeadersAndVals( + map[int64]*types.SignedHeader{ + 1: h1, + 3: h3, + 0: h3, + }, + map[int64]*types.ValidatorSet{ + 1: vals, + 3: vals, + 0: vals, + }) + c, err := light.NewClient( ctx, chainID, trustOptions, - fullNode, - []provider.Provider{fullNode}, + mockFullNode, + []provider.Provider{mockFullNode}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), light.PruningSize(1), @@ -858,6 +900,7 @@ func TestClientPrunesHeadersAndValidatorSets(t *testing.T) { _, err = c.TrustedLightBlock(1) assert.Error(t, err) + mockFullNode.AssertExpectations(t) } func TestClientEnsureValidHeadersAndValSets(t *testing.T) { @@ -869,86 +912,108 @@ func TestClientEnsureValidHeadersAndValSets(t *testing.T) { testCases := []struct { headers map[int64]*types.SignedHeader vals map[int64]*types.ValidatorSet - err bool + + errorToThrow error + errorHeight int64 + + err bool }{ { - headerSet, - valSet, - false, - }, - { - headerSet, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - 3: nil, - }, - true, - }, - { - map[int64]*types.SignedHeader{ + headers: map[int64]*types.SignedHeader{ 1: h1, - 2: h2, - 3: nil, + 3: h3, }, - valSet, - true, + vals: map[int64]*types.ValidatorSet{ + 1: vals, + 3: vals, + }, + err: false, }, { - headerSet, - map[int64]*types.ValidatorSet{ + headers: map[int64]*types.SignedHeader{ + 1: h1, + }, + vals: map[int64]*types.ValidatorSet{ + 1: vals, + }, + errorToThrow: provider.ErrBadLightBlock{Reason: errors.New("nil header or vals")}, + errorHeight: 3, + err: true, + }, + { + headers: map[int64]*types.SignedHeader{ + 1: h1, + }, + errorToThrow: provider.ErrBadLightBlock{Reason: errors.New("nil header or vals")}, + errorHeight: 3, + vals: valSet, + err: true, + }, + { + headers: map[int64]*types.SignedHeader{ + 1: h1, + 3: h3, + }, + vals: map[int64]*types.ValidatorSet{ 1: vals, - 2: vals, 3: emptyValSet, }, - true, + err: true, }, } - for _, tc := range testCases { - badNode := mockp.New( - chainID, - tc.headers, - tc.vals, - ) - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - badNode, - []provider.Provider{badNode, badNode}, - dbs.New(dbm.NewMemDB()), - ) - require.NoError(t, err) + for i, tc := range testCases { + testCase := tc + t.Run(fmt.Sprintf("case: %d", i), func(t *testing.T) { + mockBadNode := mockNodeFromHeadersAndVals(testCase.headers, testCase.vals) + if testCase.errorToThrow != nil { + mockBadNode.On("LightBlock", mock.Anything, testCase.errorHeight).Return(nil, testCase.errorToThrow) + } - _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) - if tc.err { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockBadNode, + []provider.Provider{mockBadNode, mockBadNode}, + dbs.New(dbm.NewMemDB()), + ) + require.NoError(t, err) + + _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) + if testCase.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + mockBadNode.AssertExpectations(t) + }) } } func TestClientHandlesContexts(t *testing.T) { - p := mockp.New(genMockNode(chainID, 100, 10, 1, bTime)) - genBlock, err := p.LightBlock(ctx, 1) - require.NoError(t, err) + mockNode := &provider_mocks.Provider{} + mockNode.On("LightBlock", + mock.MatchedBy(func(ctx context.Context) bool { return ctx.Err() == nil }), + int64(1)).Return(l1, nil) + mockNode.On("LightBlock", + mock.MatchedBy(func(ctx context.Context) bool { return ctx.Err() == context.DeadlineExceeded }), + mock.Anything).Return(nil, context.DeadlineExceeded) + + mockNode.On("LightBlock", + mock.MatchedBy(func(ctx context.Context) bool { return ctx.Err() == context.Canceled }), + mock.Anything).Return(nil, context.Canceled) // instantiate the light client with a timeout - ctxTimeOut, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + ctxTimeOut, cancel := context.WithTimeout(ctx, 1*time.Nanosecond) defer cancel() - _, err = light.NewClient( + _, err := light.NewClient( ctxTimeOut, chainID, - light.TrustOptions{ - Period: 24 * time.Hour, - Height: 1, - Hash: genBlock.Hash(), - }, - p, - []provider.Provider{p, p}, + trustOptions, + mockNode, + []provider.Provider{mockNode, mockNode}, dbs.New(dbm.NewMemDB()), ) require.Error(t, ctxTimeOut.Err()) @@ -959,19 +1024,15 @@ func TestClientHandlesContexts(t *testing.T) { c, err := light.NewClient( ctx, chainID, - light.TrustOptions{ - Period: 24 * time.Hour, - Height: 1, - Hash: genBlock.Hash(), - }, - p, - []provider.Provider{p, p}, + trustOptions, + mockNode, + []provider.Provider{mockNode, mockNode}, dbs.New(dbm.NewMemDB()), ) require.NoError(t, err) // verify a block with a timeout - ctxTimeOutBlock, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + ctxTimeOutBlock, cancel := context.WithTimeout(ctx, 1*time.Nanosecond) defer cancel() _, err = c.VerifyLightBlockAtHeight(ctxTimeOutBlock, 100, bTime.Add(100*time.Minute)) require.Error(t, ctxTimeOutBlock.Err()) @@ -980,11 +1041,11 @@ func TestClientHandlesContexts(t *testing.T) { // verify a block with a cancel ctxCancel, cancel := context.WithCancel(ctx) - defer cancel() - time.AfterFunc(10*time.Millisecond, cancel) + cancel() _, err = c.VerifyLightBlockAtHeight(ctxCancel, 100, bTime.Add(100*time.Minute)) require.Error(t, ctxCancel.Err()) require.Error(t, err) require.True(t, errors.Is(err, context.Canceled)) + mockNode.AssertExpectations(t) } diff --git a/light/detector_test.go b/light/detector_test.go index 48efd4130..0be4b6ab5 100644 --- a/light/detector_test.go +++ b/light/detector_test.go @@ -1,10 +1,12 @@ package light_test import ( + "bytes" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" dbm "github.com/tendermint/tm-db" @@ -12,7 +14,7 @@ import ( "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/light" "github.com/tendermint/tendermint/light/provider" - mockp "github.com/tendermint/tendermint/light/provider/mock" + provider_mocks "github.com/tendermint/tendermint/light/provider/mocks" dbs "github.com/tendermint/tendermint/light/store/db" "github.com/tendermint/tendermint/types" ) @@ -20,18 +22,17 @@ import ( func TestLightClientAttackEvidence_Lunatic(t *testing.T) { // primary performs a lunatic attack var ( - latestHeight = int64(10) + latestHeight = int64(3) valSize = 5 - divergenceHeight = int64(6) + divergenceHeight = int64(2) primaryHeaders = make(map[int64]*types.SignedHeader, latestHeight) primaryValidators = make(map[int64]*types.ValidatorSet, latestHeight) ) - witnessHeaders, witnessValidators, chainKeys := genMockNodeWithKeys(chainID, latestHeight, valSize, 2, bTime) - witness := mockp.New(chainID, witnessHeaders, witnessValidators) + witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(chainID, latestHeight, valSize, 2, bTime) + forgedKeys := chainKeys[divergenceHeight-1].ChangeKeys(3) // we change 3 out of the 5 validators (still 2/5 remain) forgedVals := forgedKeys.ToValidators(2, 0) - for height := int64(1); height <= latestHeight; height++ { if height < divergenceHeight { primaryHeaders[height] = witnessHeaders[height] @@ -42,7 +43,38 @@ func TestLightClientAttackEvidence_Lunatic(t *testing.T) { nil, forgedVals, forgedVals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(forgedKeys)) primaryValidators[height] = forgedVals } - primary := mockp.New(chainID, primaryHeaders, primaryValidators) + + // never called, delete it to make mockery asserts pass + delete(witnessHeaders, 2) + delete(primaryHeaders, 2) + + mockWitness := mockNodeFromHeadersAndVals(witnessHeaders, witnessValidators) + mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryValidators) + + mockWitness.On("ReportEvidence", mock.Anything, mock.MatchedBy(func(evidence types.Evidence) bool { + evAgainstPrimary := &types.LightClientAttackEvidence{ + // after the divergence height the valset doesn't change so we expect the evidence to be for the latest height + ConflictingBlock: &types.LightBlock{ + SignedHeader: primaryHeaders[latestHeight], + ValidatorSet: primaryValidators[latestHeight], + }, + CommonHeight: 1, + } + return bytes.Equal(evidence.Hash(), evAgainstPrimary.Hash()) + })).Return(nil) + + mockPrimary.On("ReportEvidence", mock.Anything, mock.MatchedBy(func(evidence types.Evidence) bool { + evAgainstWitness := &types.LightClientAttackEvidence{ + // when forming evidence against witness we learn that the canonical chain continued to change validator sets + // hence the conflicting block is at 7 + ConflictingBlock: &types.LightBlock{ + SignedHeader: witnessHeaders[divergenceHeight+1], + ValidatorSet: witnessValidators[divergenceHeight+1], + }, + CommonHeight: divergenceHeight - 1, + } + return bytes.Equal(evidence.Hash(), evAgainstWitness.Hash()) + })).Return(nil) c, err := light.NewClient( ctx, @@ -52,121 +84,132 @@ func TestLightClientAttackEvidence_Lunatic(t *testing.T) { Height: 1, Hash: primaryHeaders[1].Hash(), }, - primary, - []provider.Provider{witness}, + mockPrimary, + []provider.Provider{mockWitness}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) require.NoError(t, err) // Check verification returns an error. - _, err = c.VerifyLightBlockAtHeight(ctx, 10, bTime.Add(1*time.Hour)) + _, err = c.VerifyLightBlockAtHeight(ctx, latestHeight, bTime.Add(1*time.Hour)) if assert.Error(t, err) { assert.Equal(t, light.ErrLightClientAttack, err) } - // Check evidence was sent to both full nodes. - evAgainstPrimary := &types.LightClientAttackEvidence{ - // after the divergence height the valset doesn't change so we expect the evidence to be for height 10 - ConflictingBlock: &types.LightBlock{ - SignedHeader: primaryHeaders[10], - ValidatorSet: primaryValidators[10], - }, - CommonHeight: 4, - } - assert.True(t, witness.HasEvidence(evAgainstPrimary)) - - evAgainstWitness := &types.LightClientAttackEvidence{ - // when forming evidence against witness we learn that the canonical chain continued to change validator sets - // hence the conflicting block is at 7 - ConflictingBlock: &types.LightBlock{ - SignedHeader: witnessHeaders[7], - ValidatorSet: witnessValidators[7], - }, - CommonHeight: 4, - } - assert.True(t, primary.HasEvidence(evAgainstWitness)) + mockWitness.AssertExpectations(t) + mockPrimary.AssertExpectations(t) } func TestLightClientAttackEvidence_Equivocation(t *testing.T) { - verificationOptions := map[string]light.Option{ - "sequential": light.SequentialVerification(), - "skipping": light.SkippingVerification(light.DefaultTrustLevel), + cases := []struct { + name string + lightOption light.Option + unusedWitnessBlockHeights []int64 + unusedPrimaryBlockHeights []int64 + latestHeight int64 + divergenceHeight int64 + }{ + { + name: "sequential", + lightOption: light.SequentialVerification(), + unusedWitnessBlockHeights: []int64{4, 6}, + latestHeight: int64(5), + divergenceHeight: int64(3), + }, + { + name: "skipping", + lightOption: light.SkippingVerification(light.DefaultTrustLevel), + unusedWitnessBlockHeights: []int64{2, 4, 6}, + unusedPrimaryBlockHeights: []int64{2, 4, 6}, + latestHeight: int64(5), + divergenceHeight: int64(3), + }, } - for s, verificationOption := range verificationOptions { - t.Log("==> verification", s) - - // primary performs an equivocation attack - var ( - latestHeight = int64(10) - valSize = 5 - divergenceHeight = int64(6) - primaryHeaders = make(map[int64]*types.SignedHeader, latestHeight) - primaryValidators = make(map[int64]*types.ValidatorSet, latestHeight) - ) - // validators don't change in this network (however we still use a map just for convenience) - witnessHeaders, witnessValidators, chainKeys := genMockNodeWithKeys(chainID, latestHeight+2, valSize, 2, bTime) - witness := mockp.New(chainID, witnessHeaders, witnessValidators) - - for height := int64(1); height <= latestHeight; height++ { - if height < divergenceHeight { - primaryHeaders[height] = witnessHeaders[height] + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // primary performs an equivocation attack + var ( + valSize = 5 + primaryHeaders = make(map[int64]*types.SignedHeader, tc.latestHeight) + // validators don't change in this network (however we still use a map just for convenience) + primaryValidators = make(map[int64]*types.ValidatorSet, tc.latestHeight) + ) + witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(chainID, tc.latestHeight+1, valSize, 2, bTime) + for height := int64(1); height <= tc.latestHeight; height++ { + if height < tc.divergenceHeight { + primaryHeaders[height] = witnessHeaders[height] + primaryValidators[height] = witnessValidators[height] + continue + } + // we don't have a network partition so we will make 4/5 (greater than 2/3) malicious and vote again for + // a different block (which we do by adding txs) + primaryHeaders[height] = chainKeys[height].GenSignedHeader(chainID, height, + bTime.Add(time.Duration(height)*time.Minute), []types.Tx{[]byte("abcd")}, + witnessValidators[height], witnessValidators[height+1], hash("app_hash"), + hash("cons_hash"), hash("results_hash"), 0, len(chainKeys[height])-1) primaryValidators[height] = witnessValidators[height] - continue } - // we don't have a network partition so we will make 4/5 (greater than 2/3) malicious and vote again for - // a different block (which we do by adding txs) - primaryHeaders[height] = chainKeys[height].GenSignedHeader(chainID, height, - bTime.Add(time.Duration(height)*time.Minute), []types.Tx{[]byte("abcd")}, - witnessValidators[height], witnessValidators[height+1], hash("app_hash"), - hash("cons_hash"), hash("results_hash"), 0, len(chainKeys[height])-1) - primaryValidators[height] = witnessValidators[height] - } - primary := mockp.New(chainID, primaryHeaders, primaryValidators) - c, err := light.NewClient( - ctx, - chainID, - light.TrustOptions{ - Period: 4 * time.Hour, - Height: 1, - Hash: primaryHeaders[1].Hash(), - }, - primary, - []provider.Provider{witness}, - dbs.New(dbm.NewMemDB()), - light.Logger(log.TestingLogger()), - verificationOption, - ) - require.NoError(t, err) + for _, height := range tc.unusedWitnessBlockHeights { + delete(witnessHeaders, height) + } + mockWitness := mockNodeFromHeadersAndVals(witnessHeaders, witnessValidators) + for _, height := range tc.unusedPrimaryBlockHeights { + delete(primaryHeaders, height) + } + mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryValidators) - // Check verification returns an error. - _, err = c.VerifyLightBlockAtHeight(ctx, 10, bTime.Add(1*time.Hour)) - if assert.Error(t, err) { - assert.Equal(t, light.ErrLightClientAttack, err) - } + // Check evidence was sent to both full nodes. + // Common height should be set to the height of the divergent header in the instance + // of an equivocation attack and the validator sets are the same as what the witness has + mockWitness.On("ReportEvidence", mock.Anything, mock.MatchedBy(func(evidence types.Evidence) bool { + evAgainstPrimary := &types.LightClientAttackEvidence{ + ConflictingBlock: &types.LightBlock{ + SignedHeader: primaryHeaders[tc.divergenceHeight], + ValidatorSet: primaryValidators[tc.divergenceHeight], + }, + CommonHeight: tc.divergenceHeight, + } + return bytes.Equal(evidence.Hash(), evAgainstPrimary.Hash()) + })).Return(nil) + mockPrimary.On("ReportEvidence", mock.Anything, mock.MatchedBy(func(evidence types.Evidence) bool { + evAgainstWitness := &types.LightClientAttackEvidence{ + ConflictingBlock: &types.LightBlock{ + SignedHeader: witnessHeaders[tc.divergenceHeight], + ValidatorSet: witnessValidators[tc.divergenceHeight], + }, + CommonHeight: tc.divergenceHeight, + } + return bytes.Equal(evidence.Hash(), evAgainstWitness.Hash()) + })).Return(nil) - // Check evidence was sent to both full nodes. - // Common height should be set to the height of the divergent header in the instance - // of an equivocation attack and the validator sets are the same as what the witness has - evAgainstPrimary := &types.LightClientAttackEvidence{ - ConflictingBlock: &types.LightBlock{ - SignedHeader: primaryHeaders[divergenceHeight], - ValidatorSet: primaryValidators[divergenceHeight], - }, - CommonHeight: divergenceHeight, - } - assert.True(t, witness.HasEvidence(evAgainstPrimary)) + c, err := light.NewClient( + ctx, + chainID, + light.TrustOptions{ + Period: 4 * time.Hour, + Height: 1, + Hash: primaryHeaders[1].Hash(), + }, + mockPrimary, + []provider.Provider{mockWitness}, + dbs.New(dbm.NewMemDB()), + light.Logger(log.TestingLogger()), + tc.lightOption, + ) + require.NoError(t, err) - evAgainstWitness := &types.LightClientAttackEvidence{ - ConflictingBlock: &types.LightBlock{ - SignedHeader: witnessHeaders[divergenceHeight], - ValidatorSet: witnessValidators[divergenceHeight], - }, - CommonHeight: divergenceHeight, - } - assert.True(t, primary.HasEvidence(evAgainstWitness)) + // Check verification returns an error. + _, err = c.VerifyLightBlockAtHeight(ctx, tc.latestHeight, bTime.Add(1*time.Hour)) + if assert.Error(t, err) { + assert.Equal(t, light.ErrLightClientAttack, err) + } + + mockWitness.AssertExpectations(t) + mockPrimary.AssertExpectations(t) + }) } } @@ -182,7 +225,10 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { primaryValidators = make(map[int64]*types.ValidatorSet, forgedHeight) ) - witnessHeaders, witnessValidators, chainKeys := genMockNodeWithKeys(chainID, latestHeight, valSize, 2, bTime) + witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(chainID, latestHeight, valSize, 2, bTime) + for _, unusedHeader := range []int64{3, 5, 6, 8} { + delete(primaryHeaders, unusedHeader) + } // primary has the exact same headers except it forges one extra header in the future using keys from 2/5ths of // the validators @@ -204,15 +250,36 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { hash("results_hash"), 0, len(forgedKeys), ) + mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryValidators) + lastBlock, _ := mockPrimary.LightBlock(ctx, forgedHeight) + mockPrimary.On("LightBlock", mock.Anything, int64(0)).Return(lastBlock, nil) + mockPrimary.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) - witness := mockp.New(chainID, witnessHeaders, witnessValidators) - primary := mockp.New(chainID, primaryHeaders, primaryValidators) + /* + for _, unusedHeader := range []int64{3, 5, 6, 8} { + delete(witnessHeaders, unusedHeader) + } + */ + mockWitness := mockNodeFromHeadersAndVals(witnessHeaders, witnessValidators) + lastBlock, _ = mockWitness.LightBlock(ctx, latestHeight) + mockWitness.On("LightBlock", mock.Anything, int64(0)).Return(lastBlock, nil).Once() + mockWitness.On("LightBlock", mock.Anything, int64(12)).Return(nil, provider.ErrHeightTooHigh) - laggingWitness := witness.Copy("laggingWitness") + mockWitness.On("ReportEvidence", mock.Anything, mock.MatchedBy(func(evidence types.Evidence) bool { + // Check evidence was sent to the witness against the full node + evAgainstPrimary := &types.LightClientAttackEvidence{ + ConflictingBlock: &types.LightBlock{ + SignedHeader: primaryHeaders[forgedHeight], + ValidatorSet: primaryValidators[forgedHeight], + }, + CommonHeight: latestHeight, + } + return bytes.Equal(evidence.Hash(), evAgainstPrimary.Hash()) + })).Return(nil).Twice() // In order to perform the attack, the primary needs at least one accomplice as a witness to also // send the forged block - accomplice := primary + accomplice := mockPrimary c, err := light.NewClient( ctx, @@ -222,8 +289,8 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { Height: 1, Hash: primaryHeaders[1].Hash(), }, - primary, - []provider.Provider{witness, accomplice}, + mockPrimary, + []provider.Provider{mockWitness, accomplice}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), light.MaxClockDrift(1*time.Second), @@ -251,7 +318,7 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { } go func() { time.Sleep(2 * time.Second) - witness.AddLightBlock(newLb) + mockWitness.On("LightBlock", mock.Anything, int64(0)).Return(newLb, nil) }() // Now assert that verification returns an error. We craft the light clients time to be a little ahead of the chain @@ -261,26 +328,19 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { assert.Equal(t, light.ErrLightClientAttack, err) } - // Check evidence was sent to the witness against the full node - evAgainstPrimary := &types.LightClientAttackEvidence{ - ConflictingBlock: &types.LightBlock{ - SignedHeader: primaryHeaders[forgedHeight], - ValidatorSet: primaryValidators[forgedHeight], - }, - CommonHeight: latestHeight, - } - assert.True(t, witness.HasEvidence(evAgainstPrimary)) - // We attempt the same call but now the supporting witness has a block which should // immediately conflict in time with the primary _, err = c.VerifyLightBlockAtHeight(ctx, forgedHeight, bTime.Add(time.Duration(forgedHeight)*time.Minute)) if assert.Error(t, err) { assert.Equal(t, light.ErrLightClientAttack, err) } - assert.True(t, witness.HasEvidence(evAgainstPrimary)) // Lastly we test the unfortunate case where the light clients supporting witness doesn't update // in enough time + mockLaggingWitness := mockNodeFromHeadersAndVals(witnessHeaders, witnessValidators) + mockLaggingWitness.On("LightBlock", mock.Anything, int64(12)).Return(nil, provider.ErrHeightTooHigh) + lastBlock, _ = mockLaggingWitness.LightBlock(ctx, latestHeight) + mockLaggingWitness.On("LightBlock", mock.Anything, int64(0)).Return(lastBlock, nil) c, err = light.NewClient( ctx, chainID, @@ -289,8 +349,8 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { Height: 1, Hash: primaryHeaders[1].Hash(), }, - primary, - []provider.Provider{laggingWitness, accomplice}, + mockPrimary, + []provider.Provider{mockLaggingWitness, accomplice}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), light.MaxClockDrift(1*time.Second), @@ -300,17 +360,20 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { _, err = c.Update(ctx, bTime.Add(time.Duration(forgedHeight)*time.Minute)) assert.NoError(t, err) - + mockPrimary.AssertExpectations(t) + mockWitness.AssertExpectations(t) } // 1. Different nodes therefore a divergent header is produced. // => light client returns an error upon creation because primary and witness // have a different view. func TestClientDivergentTraces1(t *testing.T) { - primary := mockp.New(genMockNode(chainID, 10, 5, 2, bTime)) - firstBlock, err := primary.LightBlock(ctx, 1) + headers, vals, _ := genLightBlocksWithKeys(chainID, 1, 5, 2, bTime) + mockPrimary := mockNodeFromHeadersAndVals(headers, vals) + firstBlock, err := mockPrimary.LightBlock(ctx, 1) require.NoError(t, err) - witness := mockp.New(genMockNode(chainID, 10, 5, 2, bTime)) + headers, vals, _ = genLightBlocksWithKeys(chainID, 1, 5, 2, bTime) + mockWitness := mockNodeFromHeadersAndVals(headers, vals) _, err = light.NewClient( ctx, @@ -320,20 +383,25 @@ func TestClientDivergentTraces1(t *testing.T) { Hash: firstBlock.Hash(), Period: 4 * time.Hour, }, - primary, - []provider.Provider{witness}, + mockPrimary, + []provider.Provider{mockWitness}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) require.Error(t, err) assert.Contains(t, err.Error(), "does not match primary") + mockWitness.AssertExpectations(t) + mockPrimary.AssertExpectations(t) } // 2. Two out of three nodes don't respond but the third has a header that matches // => verification should be successful and all the witnesses should remain func TestClientDivergentTraces2(t *testing.T) { - primary := mockp.New(genMockNode(chainID, 10, 5, 2, bTime)) - firstBlock, err := primary.LightBlock(ctx, 1) + headers, vals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + mockPrimaryNode := mockNodeFromHeadersAndVals(headers, vals) + mockDeadNode := &provider_mocks.Provider{} + mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrNoResponse) + firstBlock, err := mockPrimaryNode.LightBlock(ctx, 1) require.NoError(t, err) c, err := light.NewClient( ctx, @@ -343,31 +411,33 @@ func TestClientDivergentTraces2(t *testing.T) { Hash: firstBlock.Hash(), Period: 4 * time.Hour, }, - primary, - []provider.Provider{deadNode, deadNode, primary}, + mockPrimaryNode, + []provider.Provider{mockDeadNode, mockDeadNode, mockPrimaryNode}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) require.NoError(t, err) - _, err = c.VerifyLightBlockAtHeight(ctx, 10, bTime.Add(1*time.Hour)) + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour)) assert.NoError(t, err) assert.Equal(t, 3, len(c.Witnesses())) + mockDeadNode.AssertExpectations(t) + mockPrimaryNode.AssertExpectations(t) } // 3. witness has the same first header, but different second header // => creation should succeed, but the verification should fail func TestClientDivergentTraces3(t *testing.T) { - _, primaryHeaders, primaryVals := genMockNode(chainID, 10, 5, 2, bTime) - primary := mockp.New(chainID, primaryHeaders, primaryVals) + primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) - firstBlock, err := primary.LightBlock(ctx, 1) + firstBlock, err := mockPrimary.LightBlock(ctx, 1) require.NoError(t, err) - _, mockHeaders, mockVals := genMockNode(chainID, 10, 5, 2, bTime) + mockHeaders, mockVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) mockHeaders[1] = primaryHeaders[1] mockVals[1] = primaryVals[1] - witness := mockp.New(chainID, mockHeaders, mockVals) + mockWitness := mockNodeFromHeadersAndVals(mockHeaders, mockVals) c, err := light.NewClient( ctx, @@ -377,33 +447,33 @@ func TestClientDivergentTraces3(t *testing.T) { Hash: firstBlock.Hash(), Period: 4 * time.Hour, }, - primary, - []provider.Provider{witness}, + mockPrimary, + []provider.Provider{mockWitness}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) require.NoError(t, err) - _, err = c.VerifyLightBlockAtHeight(ctx, 10, bTime.Add(1*time.Hour)) + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour)) assert.Error(t, err) assert.Equal(t, 1, len(c.Witnesses())) + mockWitness.AssertExpectations(t) + mockPrimary.AssertExpectations(t) } // 4. Witness has a divergent header but can not produce a valid trace to back it up. // It should be ignored func TestClientDivergentTraces4(t *testing.T) { - _, primaryHeaders, primaryVals := genMockNode(chainID, 10, 5, 2, bTime) - primary := mockp.New(chainID, primaryHeaders, primaryVals) + primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) - firstBlock, err := primary.LightBlock(ctx, 1) + firstBlock, err := mockPrimary.LightBlock(ctx, 1) require.NoError(t, err) - _, mockHeaders, mockVals := genMockNode(chainID, 10, 5, 2, bTime) - witness := primary.Copy("witness") - witness.AddLightBlock(&types.LightBlock{ - SignedHeader: mockHeaders[10], - ValidatorSet: mockVals[10], - }) + witnessHeaders, witnessVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + primaryHeaders[2] = witnessHeaders[2] + primaryVals[2] = witnessVals[2] + mockWitness := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) c, err := light.NewClient( ctx, @@ -413,14 +483,16 @@ func TestClientDivergentTraces4(t *testing.T) { Hash: firstBlock.Hash(), Period: 4 * time.Hour, }, - primary, - []provider.Provider{witness}, + mockPrimary, + []provider.Provider{mockWitness}, dbs.New(dbm.NewMemDB()), light.Logger(log.TestingLogger()), ) require.NoError(t, err) - _, err = c.VerifyLightBlockAtHeight(ctx, 10, bTime.Add(1*time.Hour)) + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour)) assert.Error(t, err) assert.Equal(t, 1, len(c.Witnesses())) + mockWitness.AssertExpectations(t) + mockPrimary.AssertExpectations(t) } diff --git a/light/helpers_test.go b/light/helpers_test.go index 2ca951913..a110c295d 100644 --- a/light/helpers_test.go +++ b/light/helpers_test.go @@ -3,10 +3,12 @@ package light_test import ( "time" + "github.com/stretchr/testify/mock" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/crypto/tmhash" tmtime "github.com/tendermint/tendermint/libs/time" + provider_mocks "github.com/tendermint/tendermint/light/provider/mocks" tmproto "github.com/tendermint/tendermint/proto/tendermint/types" "github.com/tendermint/tendermint/types" "github.com/tendermint/tendermint/version" @@ -169,12 +171,12 @@ func (pkz privKeys) ChangeKeys(delta int) privKeys { return newKeys.Extend(delta) } -// Generates the header and validator set to create a full entire mock node with blocks to height ( -// blockSize) and with variation in validator sets. BlockIntervals are in per minute. +// genLightBlocksWithKeys generates the header and validator set to create +// blocks to height. BlockIntervals are in per minute. // NOTE: Expected to have a large validator set size ~ 100 validators. -func genMockNodeWithKeys( +func genLightBlocksWithKeys( chainID string, - blockSize int64, + numBlocks int64, valSize int, valVariation float32, bTime time.Time) ( @@ -183,9 +185,9 @@ func genMockNodeWithKeys( map[int64]privKeys) { var ( - headers = make(map[int64]*types.SignedHeader, blockSize) - valset = make(map[int64]*types.ValidatorSet, blockSize+1) - keymap = make(map[int64]privKeys, blockSize+1) + headers = make(map[int64]*types.SignedHeader, numBlocks) + valset = make(map[int64]*types.ValidatorSet, numBlocks+1) + keymap = make(map[int64]privKeys, numBlocks+1) keys = genPrivKeys(valSize) totalVariation = valVariation valVariationInt int @@ -207,12 +209,12 @@ func genMockNodeWithKeys( valset[1] = keys.ToValidators(2, 0) keys = newKeys - for height := int64(2); height <= blockSize; height++ { + for height := int64(2); height <= numBlocks; height++ { totalVariation += valVariation valVariationInt = int(totalVariation) totalVariation = -float32(valVariationInt) newKeys = keys.ChangeKeys(valVariationInt) - currentHeader = keys.GenSignedHeaderLastBlockID(chainID, height, bTime.Add(time.Duration(height)*time.Minute), + currentHeader = keys.GenSignedHeaderLastBlockID(chainID, height, bTime.Add(time.Duration(height)*time.Second), nil, keys.ToValidators(2, 0), newKeys.ToValidators(2, 0), hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: lastHeader.Hash()}) @@ -226,17 +228,14 @@ func genMockNodeWithKeys( return headers, valset, keymap } -func genMockNode( - chainID string, - blockSize int64, - valSize int, - valVariation float32, - bTime time.Time) ( - string, - map[int64]*types.SignedHeader, - map[int64]*types.ValidatorSet) { - headers, valset, _ := genMockNodeWithKeys(chainID, blockSize, valSize, valVariation, bTime) - return chainID, headers, valset +func mockNodeFromHeadersAndVals(headers map[int64]*types.SignedHeader, + vals map[int64]*types.ValidatorSet) *provider_mocks.Provider { + mockNode := &provider_mocks.Provider{} + for i, header := range headers { + lb := &types.LightBlock{SignedHeader: header, ValidatorSet: vals[i]} + mockNode.On("LightBlock", mock.Anything, i).Return(lb, nil) + } + return mockNode } func hash(s string) []byte { diff --git a/light/provider/mock/deadmock.go b/light/provider/mock/deadmock.go deleted file mode 100644 index 6045e45f6..000000000 --- a/light/provider/mock/deadmock.go +++ /dev/null @@ -1,30 +0,0 @@ -package mock - -import ( - "context" - "fmt" - - "github.com/tendermint/tendermint/light/provider" - "github.com/tendermint/tendermint/types" -) - -type deadMock struct { - id string -} - -// NewDeadMock creates a mock provider that always errors. id is used in case of multiple providers. -func NewDeadMock(id string) provider.Provider { - return &deadMock{id: id} -} - -func (p *deadMock) String() string { - return fmt.Sprintf("DeadMock-%s", p.id) -} - -func (p *deadMock) LightBlock(_ context.Context, height int64) (*types.LightBlock, error) { - return nil, provider.ErrNoResponse -} - -func (p *deadMock) ReportEvidence(_ context.Context, ev types.Evidence) error { - return provider.ErrNoResponse -} diff --git a/light/provider/mock/mock.go b/light/provider/mock/mock.go deleted file mode 100644 index fcb8a6fa4..000000000 --- a/light/provider/mock/mock.go +++ /dev/null @@ -1,125 +0,0 @@ -package mock - -import ( - "context" - "errors" - "fmt" - "strings" - "sync" - "time" - - "github.com/tendermint/tendermint/light/provider" - "github.com/tendermint/tendermint/types" -) - -type Mock struct { - id string - - mtx sync.Mutex - headers map[int64]*types.SignedHeader - vals map[int64]*types.ValidatorSet - evidenceToReport map[string]types.Evidence // hash => evidence - latestHeight int64 -} - -var _ provider.Provider = (*Mock)(nil) - -// New creates a mock provider with the given set of headers and validator -// sets. -func New(id string, headers map[int64]*types.SignedHeader, vals map[int64]*types.ValidatorSet) *Mock { - height := int64(0) - for h := range headers { - if h > height { - height = h - } - } - return &Mock{ - id: id, - headers: headers, - vals: vals, - evidenceToReport: make(map[string]types.Evidence), - latestHeight: height, - } -} - -func (p *Mock) String() string { - var headers strings.Builder - for _, h := range p.headers { - fmt.Fprintf(&headers, " %d:%X", h.Height, h.Hash()) - } - - var vals strings.Builder - for _, v := range p.vals { - fmt.Fprintf(&vals, " %X", v.Hash()) - } - - return fmt.Sprintf("Mock{id: %s, headers: %s, vals: %v}", p.id, headers.String(), vals.String()) -} - -func (p *Mock) LightBlock(ctx context.Context, height int64) (*types.LightBlock, error) { - p.mtx.Lock() - defer p.mtx.Unlock() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(10 * time.Millisecond): - } - - var lb *types.LightBlock - - if height > p.latestHeight { - return nil, provider.ErrHeightTooHigh - } - - if height == 0 && len(p.headers) > 0 { - height = p.latestHeight - } - - if _, ok := p.headers[height]; ok { - sh := p.headers[height] - vals := p.vals[height] - lb = &types.LightBlock{ - SignedHeader: sh, - ValidatorSet: vals, - } - } - if lb == nil { - return nil, provider.ErrLightBlockNotFound - } - if lb.SignedHeader == nil || lb.ValidatorSet == nil { - return nil, provider.ErrBadLightBlock{Reason: errors.New("nil header or vals")} - } - if err := lb.ValidateBasic(lb.ChainID); err != nil { - return nil, provider.ErrBadLightBlock{Reason: err} - } - return lb, nil -} - -func (p *Mock) ReportEvidence(_ context.Context, ev types.Evidence) error { - p.evidenceToReport[string(ev.Hash())] = ev - return nil -} - -func (p *Mock) HasEvidence(ev types.Evidence) bool { - _, ok := p.evidenceToReport[string(ev.Hash())] - return ok -} - -func (p *Mock) AddLightBlock(lb *types.LightBlock) { - p.mtx.Lock() - defer p.mtx.Unlock() - - if err := lb.ValidateBasic(lb.ChainID); err != nil { - panic(fmt.Sprintf("unable to add light block, err: %v", err)) - } - p.headers[lb.Height] = lb.SignedHeader - p.vals[lb.Height] = lb.ValidatorSet - if lb.Height > p.latestHeight { - p.latestHeight = lb.Height - } -} - -func (p *Mock) Copy(id string) *Mock { - return New(id, p.headers, p.vals) -} diff --git a/light/provider/mocks/provider.go b/light/provider/mocks/provider.go new file mode 100644 index 000000000..5a58d6b32 --- /dev/null +++ b/light/provider/mocks/provider.go @@ -0,0 +1,53 @@ +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + types "github.com/tendermint/tendermint/types" +) + +// Provider is an autogenerated mock type for the Provider type +type Provider struct { + mock.Mock +} + +// LightBlock provides a mock function with given fields: ctx, height +func (_m *Provider) LightBlock(ctx context.Context, height int64) (*types.LightBlock, error) { + ret := _m.Called(ctx, height) + + var r0 *types.LightBlock + if rf, ok := ret.Get(0).(func(context.Context, int64) *types.LightBlock); ok { + r0 = rf(ctx, height) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.LightBlock) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, height) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReportEvidence provides a mock function with given fields: _a0, _a1 +func (_m *Provider) ReportEvidence(_a0 context.Context, _a1 types.Evidence) error { + ret := _m.Called(_a0, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.Evidence) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +}