diff --git a/light/client_test.go b/light/client_test.go index 73c79afba..b826eaead 100644 --- a/light/client_test.go +++ b/light/client_test.go @@ -2,6 +2,7 @@ package light_test import ( "context" + "errors" "sync" "testing" "time" @@ -929,3 +930,61 @@ func TestClientEnsureValidHeadersAndValSets(t *testing.T) { } } + +func TestClientHandlesContexts(t *testing.T) { + p := mockp.New(genMockNode(chainID, 100, 10, 1, bTime)) + genBlock, err := p.LightBlock(ctx, 1) + require.NoError(t, err) + + // instantiate the light client with a timeout + ctxTimeOut, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel() + _, err = light.NewClient( + ctxTimeOut, + chainID, + light.TrustOptions{ + Period: 24 * time.Hour, + Height: 1, + Hash: genBlock.Hash(), + }, + p, + []provider.Provider{p, p}, + dbs.New(dbm.NewMemDB()), + ) + require.Error(t, ctxTimeOut.Err()) + require.Error(t, err) + require.True(t, errors.Is(err, context.DeadlineExceeded)) + + // instantiate the client for real + c, err := light.NewClient( + ctx, + chainID, + light.TrustOptions{ + Period: 24 * time.Hour, + Height: 1, + Hash: genBlock.Hash(), + }, + p, + []provider.Provider{p, p}, + dbs.New(dbm.NewMemDB()), + ) + require.NoError(t, err) + + // verify a block with a timeout + ctxTimeOutBlock, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel() + _, err = c.VerifyLightBlockAtHeight(ctxTimeOutBlock, 100, bTime.Add(100*time.Minute)) + require.Error(t, ctxTimeOutBlock.Err()) + require.Error(t, err) + require.True(t, errors.Is(err, context.DeadlineExceeded)) + + // verify a block with a cancel + ctxCancel, cancel := context.WithCancel(ctx) + defer cancel() + time.AfterFunc(10*time.Millisecond, cancel) + _, err = c.VerifyLightBlockAtHeight(ctxCancel, 100, bTime.Add(100*time.Minute)) + require.Error(t, ctxCancel.Err()) + require.Error(t, err) + require.True(t, errors.Is(err, context.Canceled)) + +} diff --git a/light/detector.go b/light/detector.go index 873ebe859..4d7301cb0 100644 --- a/light/detector.go +++ b/light/detector.go @@ -78,7 +78,10 @@ func (c *Client) detectDivergence(ctx context.Context, primaryTrace []*types.Lig "witness", c.witnesses[e.WitnessIndex], "err", err) witnessesToRemove = append(witnessesToRemove, e.WitnessIndex) default: - c.logger.Debug("error in light block request to witness", "err", err) + if errors.Is(e, context.Canceled) || errors.Is(e, context.DeadlineExceeded) { + return e + } + c.logger.Info("error in light block request to witness", "err", err) } } @@ -115,7 +118,7 @@ func (c *Client) compareNewHeaderWithWitness(ctx context.Context, errc chan erro // the witness hasn't been helpful in comparing headers, we mark the response and continue // comparing with the rest of the witnesses - case provider.ErrNoResponse, provider.ErrLightBlockNotFound: + case provider.ErrNoResponse, provider.ErrLightBlockNotFound, context.DeadlineExceeded, context.Canceled: errc <- err return diff --git a/light/provider/mock/mock.go b/light/provider/mock/mock.go index 939232404..fcb8a6fa4 100644 --- a/light/provider/mock/mock.go +++ b/light/provider/mock/mock.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" "sync" + "time" "github.com/tendermint/tendermint/light/provider" "github.com/tendermint/tendermint/types" @@ -55,10 +56,16 @@ func (p *Mock) String() string { return fmt.Sprintf("Mock{id: %s, headers: %s, vals: %v}", p.id, headers.String(), vals.String()) } -func (p *Mock) LightBlock(_ context.Context, height int64) (*types.LightBlock, error) { +func (p *Mock) LightBlock(ctx context.Context, height int64) (*types.LightBlock, error) { p.mtx.Lock() defer p.mtx.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(10 * time.Millisecond): + } + var lb *types.LightBlock if height > p.latestHeight {