diff --git a/CHANGELOG.md b/CHANGELOG.md index d6d780cd3..31349a7cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -174,6 +174,21 @@ Special thanks to external contributors on this release: @JayT106, - [cmd/tendermint/commands] [\#6623](https://github.com/tendermint/tendermint/pull/6623) replace `$HOME/.some/test/dir` with `t.TempDir` (@tanyabouman) - [statesync] \6807 Implement P2P state provider as an alternative to RPC (@cmwaters) +## v0.34.15 + +Special thanks to external contributors on this release: @thanethomson + +### BUG FIXES + +- [\#7368](https://github.com/tendermint/tendermint/issues/7368) cmd: add integration test for rollback functionality (@cmwaters). +- [\#7309](https://github.com/tendermint/tendermint/issues/7309) pubsub: Report a non-nil error when shutting down (fixes #7306). +- [\#7057](https://github.com/tendermint/tendermint/pull/7057) Import Postgres driver support for the psql indexer (@creachadair). +- [\#7106](https://github.com/tendermint/tendermint/pull/7106) Revert mutex change to ABCI Clients (@tychoish). + +### IMPROVEMENTS + +- [config] [\#7230](https://github.com/tendermint/tendermint/issues/7230) rpc: Add experimental config params to allow for subscription buffer size control (@thanethomson). + ## v0.34.14 This release backports the `rollback` feature to allow recovery in the event of an incorrect app hash. diff --git a/docs/architecture/README.md b/docs/architecture/README.md index 7256e5038..1a97addfa 100644 --- a/docs/architecture/README.md +++ b/docs/architecture/README.md @@ -101,3 +101,4 @@ Note the context/background should be written in the present tense. - [ADR-057: RPC](./adr-057-RPC.md) - [ADR-069: Node Initialization](./adr-069-flexible-node-initialization.md) - [ADR-071: Proposer-Based Timestamps](adr-071-proposer-based-timestamps.md) +- [ADR-074: Migrate Timeout Parameters to Consensus Parameters](./adr-074-timeout-params.md) diff --git a/docs/architecture/adr-074-timeout-params.md b/docs/architecture/adr-074-timeout-params.md new file mode 100644 index 000000000..e3e1a1800 --- /dev/null +++ b/docs/architecture/adr-074-timeout-params.md @@ -0,0 +1,203 @@ +# ADR 74: Migrate Timeout Parameters to Consensus Parameters + +## Changelog + +- 03-Jan-2022: Initial draft (@williambanfield) +- 13-Jan-2022: Updated to indicate work on upgrade path needed (@williambanfield) + +## Status + +Proposed + +## Context + +### Background + +Tendermint's consensus timeout parameters are currently configured locally by each validator +in the validator's [config.toml][config-toml]. +This means that the validators on a Tendermint network may have different timeouts +from each other. There is no reason for validators on the same network to configure +different timeout values. Proper functioning of the Tendermint consensus algorithm +relies on these parameters being uniform across validators. + +The configurable values are as follows: + +* `TimeoutPropose` + * How long the consensus algorithm waits for a proposal block before issuing a prevote. + * If no prevote arrives by `TimeoutPropose`, then the consensus algorithm will issue a nil prevote. +* `TimeoutProposeDelta` + * How much the `TimeoutPropose` grows each round. +* `TimeoutPrevote` + * How long the consensus algorithm waits after receiving +2/3 prevotes with + no quorum for a value before issuing a precommit for nil. + (See the [arXiv paper][arxiv-paper], Algorithm 1, Line 34) +* `TimeoutPrevoteDelta` + * How much the `TimeoutPrevote` increases with each round. +* `TimeoutPrecommit` + * How long the consensus algorithm waits after receiving +2/3 precommits that + do not have a quorum for a value before entering the next round. + (See the [arXiv paper][arxiv-paper], Algorithm 1, Line 47) +* `TimeoutPrecommitDelta` + * How much the `TimeoutPrecommit` increases with each round. +* `TimeoutCommit` + * How long the consensus algorithm waits after committing a block but before starting the new height. + * This gives a validator a chance to receive slow precommits. +* `SkipTimeoutCommit` + * Make progress as soon as the node has 100% of the precommits. + + +### Overview of Change + +We will consolidate the timeout parameters and migrate them from the node-local +`config.toml` file into the network-global consensus parameters. + +The 8 timeout parameters will be consolidated down to 6. These will be as follows: + +* `TimeoutPropose` + * Same as current `TimeoutPropose`. +* `TimeoutProposeDelta` + * Same as current `TimeoutProposeDelta`. +* `TimeoutVote` + * How long validators wait for votes in both the prevote + and precommit phase of the consensus algorithm. This parameter subsumes + the current `TimeoutPrevote` and `TimeoutPrecommit` parameters. +* `TimeoutVoteDelta` + * How much the `TimeoutVote` will grow each successive round. + This parameter subsumes the current `TimeoutPrevoteDelta` and `TimeoutPrecommitDelta` + parameters. +* `TimeoutCommit` + * Same as current `TimeoutCommit`. +* `EnableTimeoutCommitBypass` + * Same as current `SkipTimeoutCommit`, renamed for clarity. + +A safe default will be provided by Tendermint for each of these parameters and +networks will be able to update the parameters as they see fit. Local updates +to these parameters will no longer be possible; instead, the application will control +updating the parameters. Applications using the Cosmos SDK will be automatically be +able to change the values of these consensus parameters [via a governance proposal][cosmos-sdk-consensus-params]. + +This change is low-risk. While parameters are locally configurable, many running chains +do not change them from their default values. For example, initializing +a node on Osmosis, Terra, and the Cosmos Hub using the their `init` command produces +a `config.toml` with Tendermint's default values for these parameters. + +### Why this parameter consolidation? + +Reducing the number of parameters is good for UX. Fewer superfluous parameters makes +running and operating a Tendermint network less confusing. + +The Prevote and Precommit messages are both similar sizes, require similar amounts +of processing so there is no strong need for them to be configured separately. + +The `TimeoutPropose` parameter governs how long Tendermint will wait for the proposed +block to be gossiped. Blocks are much larger than votes and therefore tend to be +gossiped much more slowly. It therefore makes sense to keep `TimeoutPropose` and +the `TimeoutProposeDelta` as parameters separate from the vote timeouts. + +`TimeoutCommit` is used by chains to ensure that the network waits for the votes from +slower validators before proceeding to the next height. Without this timeout, the votes +from slower validators would consistently not be included in blocks and those validators +would not be counted as 'up' from the chain's perspective. Being down damages a validator's +reputation and causes potential stakers to think twice before delegating to that validator. + +`TimeoutCommit` also prevents the network from producing the next height as soon as validators +on the fastest hardware with a summed voting power of +2/3 of the network's total have +completed execution of the block. Allowing the network to proceed as soon as the fastest ++2/3 completed execution would have a cumulative effect over heights, eventually +leaving slower validators unable to participate in consensus at all. `TimeoutCommit` +therefore allows networks to have greater variability in hardware. Additional +discussion of this can be found in [tendermint issue 5911][tendermint-issue-5911-comment] +and [spec issue 359][spec-issue-359]. + +## Alternative Approaches + +### Hardcode the parameters + +Many Tendermint networks run on similar cloud-hosted infrastructure. Therefore, +they have similar bandwidth and machine resources. The timings for propagating votes +and blocks are likely to be reasonably similar across networks. As a result, the +timeout parameters are good candidates for being hardcoded. Hardcoding the timeouts +in Tendermint would mean entirely removing these parameters from any configuration +that could be altered by either an application or a node operator. Instead, +Tendermint would ship with a set of timeouts and all applications using Tendermint +would use this exact same set of values. + +While Tendermint nodes often run with similar bandwidth and on similar cloud-hosted +machines, there are enough points of variability to make configuring +consensus timeouts meaningful. Namely, Tendermint network topologies are likely to be +very different from chain to chain. Additionally, applications may vary greatly in +how long the `Commit` phase may take. Applications that perform more work during `Commit` +require a longer `TimeoutCommit` to allow the application to complete its work +and be prepared for the next height. + +## Decision + +The decision has been made to implement this work, with the caveat that the +specific mechanism for introducing the new parameters to chains is still ongoing. + +## Detailed Design + +### New Consensus Parameters + +A new `TimeoutParams` `message` will be added to the [params.proto file][consensus-params-proto]. +This message will have the following form: + +```proto +message TimeoutParams { + google.protobuf.Duration propose = 1; + google.protobuf.Duration propose_delta = 2; + google.protobuf.Duration vote = 3; + google.protobuf.Duration vote_delta = 4; + google.protobuf.Duration commit = 5; + bool enable_commit_timeout_bypass = 6; +} +``` + +This new message will be added as a field into the [`ConsensusParams` +message][consensus-params-proto]. The same default values that are [currently +set for these parameters][current-timeout-defaults] in the local configuration +file will be used as the defaults for these new consensus parameters in the +[consensus parameter defaults][default-consensus-params]. + +The new consensus parameters will be subject to the same +[validity rules][time-param-validation] as the current configuration values, +namely, each value must be non-negative. + +### Migration + +The new `ConsensusParameters` will be added during an upcoming release. In this +release, the old `config.toml` parameters will cease to control the timeouts and +an error will be logged on nodes that continue to specify these values. The specific +mechanism by which these parameters will added to a chain is being discussed in +[RFC-009][rfc-009] and will be decided ahead of the next release. + +The specific mechanism for adding these parameters depends on work related to +[soft upgrades][soft-upgrades], which is still ongoing. + +## Consequences + +### Positive + +* Timeout parameters will be equal across all of the validators in a Tendermint network. +* Remove superfluous timeout parameters. + +### Negative + +### Neutral + +* Timeout parameters require consensus to change. + +## References + +[conseusus-params-proto]: https://github.com/tendermint/spec/blob/a00de7199f5558cdd6245bbbcd1d8405ccfb8129/proto/tendermint/types/params.proto#L11 +[hashed-params]: https://github.com/tendermint/tendermint/blob/7cdf560173dee6773b80d1c574a06489d4c394fe/types/params.go#L49 +[default-consensus-params]: https://github.com/tendermint/tendermint/blob/7cdf560173dee6773b80d1c574a06489d4c394fe/types/params.go#L79 +[current-timeout-defaults]: https://github.com/tendermint/tendermint/blob/7cdf560173dee6773b80d1c574a06489d4c394fe/config/config.go#L955 +[config-toml]: https://github.com/tendermint/tendermint/blob/5cc980698a3402afce76b26693ab54b8f67f038b/config/toml.go#L425-L440 +[cosmos-sdk-consensus-params]: https://github.com/cosmos/cosmos-sdk/issues/6197 +[time-param-validation]: https://github.com/tendermint/tendermint/blob/7cdf560173dee6773b80d1c574a06489d4c394fe/config/config.go#L1038 +[tendermint-issue-5911-comment]: https://github.com/tendermint/tendermint/issues/5911#issuecomment-973560381 +[spec-issue-359]: https://github.com/tendermint/spec/issues/359 +[arxiv-paper]: https://arxiv.org/pdf/1807.04938.pdf +[soft-upgrades]: https://github.com/tendermint/spec/pull/222 +[rfc-009]: https://github.com/tendermint/tendermint/pull/7524 diff --git a/docs/rfc/rfc-009-consensus-parameter-upgrades.md b/docs/rfc/rfc-009-consensus-parameter-upgrades.md new file mode 100644 index 000000000..60be878df --- /dev/null +++ b/docs/rfc/rfc-009-consensus-parameter-upgrades.md @@ -0,0 +1,128 @@ +# RFC 009 : Consensus Parameter Upgrade Considerations + +## Changelog + +- 06-Jan-2011: Initial draft (@williambanfield). + +## Abstract + +This document discusses the challenges of adding additional consensus parameters +to Tendermint and proposes a few solutions that can enable addition of consensus +parameters in a backwards-compatible way. + +## Background + +This section provides an overview of the issues of adding consensus parameters +to Tendermint. + +### Hash Compatibility + +Tendermint produces a hash of a subset of the consensus parameters. The values +that are hashed currently are the `BlockMaxGas` and the `BlockMaxSize`. These +are currently in the [HashedParams struct][hashed-params]. This hash is included +in the block and validators use it to validate that their local view of the consensus +parameters matches what the rest of the network is configured with. + +Any new consensus parameters added to Tendermint should be included in this +hash. This presents a challenge for verification of historical blocks when consensus +parameters are added. If a network produced blocks with a version of Tendermint that +did not yet have the new consensus parameters, the parameter hash it produced will +not reference the new parameters. Any nodes joining the network with the newer +version of Tendermint will have the new consensus parameters. Tendermint will need +to handle this case so that new versions of Tendermint with new consensus parameters +can still validate old blocks correctly without having to do anything overly complex +or hacky. + +### Allowing Developer-Defined Values and the `EndBlock` Problem + +When new consensus parameters are added, application developers may wish to set +values for them so that the developer-defined values may be used as soon as the +software upgrades. We do not currently have a clean mechanism for handling this. + +Consensus parameter updates are communicated from the application to Tendermint +within `EndBlock` of some height `H` and take effect at the next height, `H+1`. +This means that for updates that add a consensus parameter, there is a single +height where the new parameters cannot take effect. The parameters did not exist +in the version of the software that emitted the `EndBlock` response for height `H-1`, +so they cannot take effect at height `H`. The first height that the updated params +can take effect is height `H+1`. As of now, height `H` must run with the defaults. + +## Discussion + +### Hash Compatibility + +This section discusses possible solutions to the problem of maintaining backwards-compatibility +of hashed parameters while adding new parameters. + +#### Never Hash Defaults + +One solution to the problem of backwards-compatibility is to never include parameters +in the hash if the are using the default value. This means that blocks produced +before the parameters existed will have implicitly been created with the defaults. +This works because any software with newer versions of Tendermint must be using the +defaults for new parameters when validating old blocks since the defaults can not +have been updated until a height at which the parameters existed. + +#### Only Update HashedParams on Hash-Breaking Releases + +An alternate solution to never hashing defaults is to not update the hashed +parameters on non-hash-breaking releases. This means that when new consensus +parameters are added to Tendermint, there may be a release that makes use of the +parameters but does not verify that they are the same across all validators by +referencing them in the hash. This seems reasonably safe given the fact that +only a very far subset of the consensus parameters are currently verified at all. + +#### Version The Consensus Parameter Hash Scheme + +The upcoming work on [soft upgrades](https://github.com/tendermint/spec/pull/222) +proposes applying different hashing rules depending on the active block version. +The consensus parameter hash could be versioned in the same way. When different +block versions are used, a different set of consensus parameters will be included +in the hash. + +### Developer Defined Values + +This section discusses possible solutions to the problem of allowing application +developers to define values for the new parameters during the upgrade that adds +the parameters. + +#### Using `InitChain` for New Values + +One solution to the problem of allowing application developers to define values +for new consensus parameters is to call the `InitChain` ABCI method on application +startup and fetch the value for any new consensus parameters. The [response object][init-chain-response] +contains a field for `ConsensusParameter` updates so this may serve as a natural place +to put this logic. + +This poses a few difficulties. Nodes replaying old blocks while running new +software do not ever call `InitChain` after the initial time. They will therefore +not have a way to determine that the parameters changed at some height by using a +call to `InitChain`. The `EndBlock` response is how parameter changes at a height +are currently communicated to Tendermint and conflating these cases seems risky. + +#### Force Defaults For Single Height + +An alternate option is to not use `InitChain` and instead require chains to use the +default values of the new parameters for a single height. + +As documented in the upcoming [ADR-74][adr-74], popular chains often simply use the default +values. Additionally, great care is being taken to ensure that logic governed by upcoming +consensus parameters is not liveness-breaking. This means that, at worst-case, +chains will experience a single slow height while waiting for the new values to +by applied. + +#### Add a new `UpgradeChain` method + +An additional method for allowing chains to update the consensus parameters that +do not yet exist is to add a new `UpgradeChain` method to `ABCI`. The upgrade chain +method would be called when the chain detects that the version of block that it +is about to produce does not match the previous block. This method would be called +after `EndBlock` and would return the set of consensus parameters to use at the +next height. It would therefore give an application the chance to set the new +consensus parameters before running a height with these new parameter. + +### References + +[hashed-params]: https://github.com/tendermint/tendermint/blob/0ae974e63911804d4a2007bd8a9b3ad81d6d2a90/types/params.go#L49 +[init-chain-response]: https://github.com/tendermint/tendermint/blob/0ae974e63911804d4a2007bd8a9b3ad81d6d2a90/abci/types/types.pb.go#L1616 +[adr-74]: https://github.com/tendermint/tendermint/pull/7503 diff --git a/internal/consensus/replay_test.go b/internal/consensus/replay_test.go index 4a8335515..496565c78 100644 --- a/internal/consensus/replay_test.go +++ b/internal/consensus/replay_test.go @@ -769,7 +769,7 @@ func testHandshakeReplay( privVal, err := privval.LoadFilePV(cfg.PrivValidator.KeyFile(), cfg.PrivValidator.StateFile()) require.NoError(t, err) - wal, err := NewWAL(logger, walFile) + wal, err := NewWAL(ctx, logger, walFile) require.NoError(t, err) err = wal.Start(ctx) require.NoError(t, err) diff --git a/internal/consensus/state.go b/internal/consensus/state.go index d3d886907..356f9d57e 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -488,7 +488,7 @@ func (cs *State) Wait() { // OpenWAL opens a file to log all consensus messages and timeouts for // deterministic accountability. func (cs *State) OpenWAL(ctx context.Context, walFile string) (WAL, error) { - wal, err := NewWAL(cs.logger.With("wal", walFile), walFile) + wal, err := NewWAL(ctx, cs.logger.With("wal", walFile), walFile) if err != nil { cs.logger.Error("failed to open WAL", "file", walFile, "err", err) return nil, err diff --git a/internal/consensus/wal.go b/internal/consensus/wal.go index 59086e712..36993e762 100644 --- a/internal/consensus/wal.go +++ b/internal/consensus/wal.go @@ -90,13 +90,13 @@ var _ WAL = &BaseWAL{} // NewWAL returns a new write-ahead logger based on `baseWAL`, which implements // WAL. It's flushed and synced to disk every 2s and once when stopped. -func NewWAL(logger log.Logger, walFile string, groupOptions ...func(*auto.Group)) (*BaseWAL, error) { +func NewWAL(ctx context.Context, logger log.Logger, walFile string, groupOptions ...func(*auto.Group)) (*BaseWAL, error) { err := tmos.EnsureDir(filepath.Dir(walFile), 0700) if err != nil { return nil, fmt.Errorf("failed to ensure WAL directory is in place: %w", err) } - group, err := auto.OpenGroup(logger, walFile, groupOptions...) + group, err := auto.OpenGroup(ctx, logger, walFile, groupOptions...) if err != nil { return nil, err } diff --git a/internal/consensus/wal_test.go b/internal/consensus/wal_test.go index b52c41b9f..f686fece6 100644 --- a/internal/consensus/wal_test.go +++ b/internal/consensus/wal_test.go @@ -33,7 +33,7 @@ func TestWALTruncate(t *testing.T) { // defaultHeadSizeLimit(10M) is hard to simulate. // this magic number 1 * time.Millisecond make RotateFile check frequently. // defaultGroupCheckDuration(5s) is hard to simulate. - wal, err := NewWAL(logger, walFile, + wal, err := NewWAL(ctx, logger, walFile, autofile.GroupHeadSizeLimit(4096), autofile.GroupCheckDuration(1*time.Millisecond), ) @@ -103,7 +103,7 @@ func TestWALWrite(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - wal, err := NewWAL(log.TestingLogger(), walFile) + wal, err := NewWAL(ctx, log.TestingLogger(), walFile) require.NoError(t, err) err = wal.Start(ctx) require.NoError(t, err) @@ -144,7 +144,7 @@ func TestWALSearchForEndHeight(t *testing.T) { } walFile := tempWALWithData(t, walBody) - wal, err := NewWAL(logger, walFile) + wal, err := NewWAL(ctx, logger, walFile) require.NoError(t, err) h := int64(3) @@ -163,13 +163,13 @@ func TestWALSearchForEndHeight(t *testing.T) { } func TestWALPeriodicSync(t *testing.T) { - walDir := t.TempDir() - walFile := filepath.Join(walDir, "wal") - wal, err := NewWAL(log.TestingLogger(), walFile, autofile.GroupCheckDuration(1*time.Millisecond)) - ctx, cancel := context.WithCancel(context.Background()) defer cancel() + walDir := t.TempDir() + walFile := filepath.Join(walDir, "wal") + wal, err := NewWAL(ctx, log.TestingLogger(), walFile, autofile.GroupCheckDuration(1*time.Millisecond)) + require.NoError(t, err) wal.SetFlushInterval(walTestFlushInterval) diff --git a/internal/libs/autofile/autofile.go b/internal/libs/autofile/autofile.go index 10cc04a28..0bc9a63a3 100644 --- a/internal/libs/autofile/autofile.go +++ b/internal/libs/autofile/autofile.go @@ -1,6 +1,7 @@ package autofile import ( + "context" "os" "os/signal" "path/filepath" @@ -57,7 +58,7 @@ type AutoFile struct { // OpenAutoFile creates an AutoFile in the path (with random ID). If there is // an error, it will be of type *PathError or *ErrPermissionsChanged (if file's // permissions got changed (should be 0600)). -func OpenAutoFile(path string) (*AutoFile, error) { +func OpenAutoFile(ctx context.Context, path string) (*AutoFile, error) { var err error path, err = filepath.Abs(path) if err != nil { @@ -78,12 +79,17 @@ func OpenAutoFile(path string) (*AutoFile, error) { af.hupc = make(chan os.Signal, 1) signal.Notify(af.hupc, syscall.SIGHUP) go func() { - for range af.hupc { - _ = af.closeFile() + for { + select { + case <-af.hupc: + _ = af.closeFile() + case <-ctx.Done(): + return + } } }() - go af.closeFileRoutine() + go af.closeFileRoutine(ctx) return af, nil } @@ -99,9 +105,12 @@ func (af *AutoFile) Close() error { return af.closeFile() } -func (af *AutoFile) closeFileRoutine() { +func (af *AutoFile) closeFileRoutine(ctx context.Context) { for { select { + case <-ctx.Done(): + _ = af.closeFile() + return case <-af.closeTicker.C: _ = af.closeFile() case <-af.closeTickerStopc: diff --git a/internal/libs/autofile/autofile_test.go b/internal/libs/autofile/autofile_test.go index 479a239cb..9864ed82a 100644 --- a/internal/libs/autofile/autofile_test.go +++ b/internal/libs/autofile/autofile_test.go @@ -1,6 +1,7 @@ package autofile import ( + "context" "os" "path/filepath" "syscall" @@ -12,6 +13,9 @@ import ( ) func TestSIGHUP(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + origDir, err := os.Getwd() require.NoError(t, err) t.Cleanup(func() { @@ -30,7 +34,7 @@ func TestSIGHUP(t *testing.T) { // Create an AutoFile in the temporary directory name := "sighup_test" - af, err := OpenAutoFile(name) + af, err := OpenAutoFile(ctx, name) require.NoError(t, err) require.True(t, filepath.IsAbs(af.Path)) @@ -104,13 +108,16 @@ func TestSIGHUP(t *testing.T) { // } func TestAutoFileSize(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // First, create an AutoFile writing to a tempfile dir f, err := os.CreateTemp("", "sighup_test") require.NoError(t, err) require.NoError(t, f.Close()) // Here is the actual AutoFile. - af, err := OpenAutoFile(f.Name()) + af, err := OpenAutoFile(ctx, f.Name()) require.NoError(t, err) // 1. Empty file diff --git a/internal/libs/autofile/cmd/logjack.go b/internal/libs/autofile/cmd/logjack.go index c246871dc..a9f6cf766 100644 --- a/internal/libs/autofile/cmd/logjack.go +++ b/internal/libs/autofile/cmd/logjack.go @@ -48,7 +48,7 @@ func main() { } // Open Group - group, err := auto.OpenGroup(log.NewNopLogger(), headPath, auto.GroupHeadSizeLimit(chopSize), auto.GroupTotalSizeLimit(limitSize)) + group, err := auto.OpenGroup(ctx, log.NewNopLogger(), headPath, auto.GroupHeadSizeLimit(chopSize), auto.GroupTotalSizeLimit(limitSize)) if err != nil { fmt.Printf("logjack couldn't create output file %v\n", headPath) os.Exit(1) diff --git a/internal/libs/autofile/group.go b/internal/libs/autofile/group.go index b8bbb78bd..0ffc2f04c 100644 --- a/internal/libs/autofile/group.go +++ b/internal/libs/autofile/group.go @@ -80,12 +80,12 @@ type Group struct { // OpenGroup creates a new Group with head at headPath. It returns an error if // it fails to open head file. -func OpenGroup(logger log.Logger, headPath string, groupOptions ...func(*Group)) (*Group, error) { +func OpenGroup(ctx context.Context, logger log.Logger, headPath string, groupOptions ...func(*Group)) (*Group, error) { dir, err := filepath.Abs(filepath.Dir(headPath)) if err != nil { return nil, err } - head, err := OpenAutoFile(headPath) + head, err := OpenAutoFile(ctx, headPath) if err != nil { return nil, err } diff --git a/internal/libs/autofile/group_test.go b/internal/libs/autofile/group_test.go index c4e068af9..328201780 100644 --- a/internal/libs/autofile/group_test.go +++ b/internal/libs/autofile/group_test.go @@ -1,6 +1,7 @@ package autofile import ( + "context" "io" "os" "path/filepath" @@ -14,14 +15,14 @@ import ( tmrand "github.com/tendermint/tendermint/libs/rand" ) -func createTestGroupWithHeadSizeLimit(t *testing.T, logger log.Logger, headSizeLimit int64) *Group { +func createTestGroupWithHeadSizeLimit(ctx context.Context, t *testing.T, logger log.Logger, headSizeLimit int64) *Group { testID := tmrand.Str(12) testDir := "_test_" + testID err := tmos.EnsureDir(testDir, 0700) require.NoError(t, err, "Error creating dir") headPath := testDir + "/myfile" - g, err := OpenGroup(logger, headPath, GroupHeadSizeLimit(headSizeLimit)) + g, err := OpenGroup(ctx, logger, headPath, GroupHeadSizeLimit(headSizeLimit)) require.NoError(t, err, "Error opening Group") require.NotEqual(t, nil, g, "Failed to create Group") @@ -43,9 +44,12 @@ func assertGroupInfo(t *testing.T, gInfo GroupInfo, minIndex, maxIndex int, tota } func TestCheckHeadSizeLimit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := log.TestingLogger() - g := createTestGroupWithHeadSizeLimit(t, logger, 1000*1000) + g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 1000*1000) // At first, there are no files. assertGroupInfo(t, g.ReadGroupInfo(), 0, 0, 0, 0) @@ -114,7 +118,9 @@ func TestCheckHeadSizeLimit(t *testing.T) { func TestRotateFile(t *testing.T) { logger := log.TestingLogger() - g := createTestGroupWithHeadSizeLimit(t, logger, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) // Create a different temporary directory and move into it, to make sure // relative paths are resolved at Group creation @@ -180,7 +186,10 @@ func TestRotateFile(t *testing.T) { func TestWrite(t *testing.T) { logger := log.TestingLogger() - g := createTestGroupWithHeadSizeLimit(t, logger, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) written := []byte("Medusa") _, err := g.Write(written) @@ -205,7 +214,10 @@ func TestWrite(t *testing.T) { func TestGroupReaderRead(t *testing.T) { logger := log.TestingLogger() - g := createTestGroupWithHeadSizeLimit(t, logger, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) professor := []byte("Professor Monster") _, err := g.Write(professor) @@ -240,7 +252,10 @@ func TestGroupReaderRead(t *testing.T) { func TestGroupReaderRead2(t *testing.T) { logger := log.TestingLogger() - g := createTestGroupWithHeadSizeLimit(t, logger, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) professor := []byte("Professor Monster") _, err := g.Write(professor) @@ -276,7 +291,10 @@ func TestGroupReaderRead2(t *testing.T) { func TestMinIndex(t *testing.T) { logger := log.TestingLogger() - g := createTestGroupWithHeadSizeLimit(t, logger, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) assert.Zero(t, g.MinIndex(), "MinIndex should be zero at the beginning") @@ -286,7 +304,10 @@ func TestMinIndex(t *testing.T) { func TestMaxIndex(t *testing.T) { logger := log.TestingLogger() - g := createTestGroupWithHeadSizeLimit(t, logger, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + g := createTestGroupWithHeadSizeLimit(ctx, t, logger, 0) assert.Zero(t, g.MaxIndex(), "MaxIndex should be zero at the beginning") diff --git a/internal/statesync/dispatcher.go b/internal/statesync/dispatcher.go index 2e476c25d..3f3e2a117 100644 --- a/internal/statesync/dispatcher.go +++ b/internal/statesync/dispatcher.go @@ -107,7 +107,7 @@ func (d *Dispatcher) dispatch(ctx context.Context, peer types.NodeID, height int // Respond allows the underlying process which receives requests on the // requestCh to respond with the respective light block. A nil response is used to // represent that the receiver of the request does not have a light block at that height. -func (d *Dispatcher) Respond(lb *tmproto.LightBlock, peer types.NodeID) error { +func (d *Dispatcher) Respond(ctx context.Context, lb *tmproto.LightBlock, peer types.NodeID) error { d.mtx.Lock() defer d.mtx.Unlock() @@ -121,8 +121,12 @@ func (d *Dispatcher) Respond(lb *tmproto.LightBlock, peer types.NodeID) error { // If lb is nil we take that to mean that the peer didn't have the requested light // block and thus pass on the nil to the caller. if lb == nil { - answerCh <- nil - return nil + select { + case answerCh <- nil: + return nil + case <-ctx.Done(): + return ctx.Err() + } } block, err := types.LightBlockFromProto(lb) @@ -130,8 +134,12 @@ func (d *Dispatcher) Respond(lb *tmproto.LightBlock, peer types.NodeID) error { return err } - answerCh <- block - return nil + select { + case <-ctx.Done(): + return ctx.Err() + case answerCh <- block: + return nil + } } // Close shuts down the dispatcher and cancels any pending calls awaiting responses. @@ -139,9 +147,11 @@ func (d *Dispatcher) Respond(lb *tmproto.LightBlock, peer types.NodeID) error { func (d *Dispatcher) Close() { d.mtx.Lock() defer d.mtx.Unlock() - for peer, call := range d.calls { + for peer := range d.calls { delete(d.calls, peer) - close(call) + // don't close the channel here as it's closed in + // other handlers, and would otherwise get garbage + // collected. } } diff --git a/internal/statesync/dispatcher_test.go b/internal/statesync/dispatcher_test.go index 918c6ec9e..65c517be4 100644 --- a/internal/statesync/dispatcher_test.go +++ b/internal/statesync/dispatcher_test.go @@ -80,7 +80,7 @@ func TestDispatcherReturnsNoBlock(t *testing.T) { go func() { <-chans.Out - require.NoError(t, d.Respond(nil, peer)) + require.NoError(t, d.Respond(ctx, nil, peer)) cancel() }() @@ -309,7 +309,7 @@ func handleRequests(ctx context.Context, t *testing.T, d *Dispatcher, ch chan p2 peer := request.To resp := mockLBResp(ctx, t, peer, int64(height), time.Now()) block, _ := resp.block.ToProto() - require.NoError(t, d.Respond(block, resp.peer)) + require.NoError(t, d.Respond(ctx, block, resp.peer)) case <-ctx.Done(): return } diff --git a/internal/statesync/reactor.go b/internal/statesync/reactor.go index 78c9d8360..6ca0cb6b7 100644 --- a/internal/statesync/reactor.go +++ b/internal/statesync/reactor.go @@ -740,7 +740,10 @@ func (r *Reactor) handleLightBlockMessage(ctx context.Context, envelope *p2p.Env height = msg.LightBlock.SignedHeader.Header.Height } r.logger.Info("received light block response", "peer", envelope.From, "height", height) - if err := r.dispatcher.Respond(msg.LightBlock, envelope.From); err != nil { + if err := r.dispatcher.Respond(ctx, msg.LightBlock, envelope.From); err != nil { + if errors.Is(err, context.Canceled) { + return err + } r.logger.Error("error processing light block response", "err", err, "height", height) } diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index 161d8699a..ee1fc8c31 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -161,7 +161,7 @@ func setup( } } - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() var err error rts.reactor, err = NewReactor( diff --git a/light/client_benchmark_test.go b/light/client_benchmark_test.go index 1097fa233..59eb79766 100644 --- a/light/client_benchmark_test.go +++ b/light/client_benchmark_test.go @@ -2,6 +2,7 @@ package light_test import ( "context" + "errors" "testing" "time" @@ -57,14 +58,14 @@ func (impl *providerBenchmarkImpl) LightBlock(ctx context.Context, height int64) } func (impl *providerBenchmarkImpl) ReportEvidence(_ context.Context, _ types.Evidence) error { - panic("not implemented") + return errors.New("not implemented") } func BenchmarkSequence(b *testing.B) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - headers, vals, _ := genLightBlocksWithKeys(chainID, 1000, 100, 1, bTime) + headers, vals, _ := genLightBlocksWithKeys(b, chainID, 1000, 100, 1, bTime) benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) genesisBlock, _ := benchmarkFullNode.LightBlock(ctx, 1) @@ -101,7 +102,7 @@ func BenchmarkBisection(b *testing.B) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - headers, vals, _ := genLightBlocksWithKeys(chainID, 1000, 100, 1, bTime) + headers, vals, _ := genLightBlocksWithKeys(b, chainID, 1000, 100, 1, bTime) benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) genesisBlock, _ := benchmarkFullNode.LightBlock(ctx, 1) @@ -137,7 +138,7 @@ func BenchmarkBackwards(b *testing.B) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - headers, vals, _ := genLightBlocksWithKeys(chainID, 1000, 100, 1, bTime) + headers, vals, _ := genLightBlocksWithKeys(b, chainID, 1000, 100, 1, bTime) benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) trustedBlock, _ := benchmarkFullNode.LightBlock(ctx, 0) diff --git a/light/client_test.go b/light/client_test.go index 1bea73c3e..b20946fb9 100644 --- a/light/client_test.go +++ b/light/client_test.go @@ -26,1089 +26,1089 @@ const ( chainID = "test" ) -var ( - keys = genPrivKeys(4) - vals = keys.ToValidators(20, 10) - bTime, _ = time.Parse(time.RFC3339, "2006-01-02T15:04:05Z") - h1 = keys.GenSignedHeader(chainID, 1, bTime, nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) - // 3/3 signed - h2 = keys.GenSignedHeaderLastBlockID(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: h1.Hash()}) - // 3/3 signed - h3 = keys.GenSignedHeaderLastBlockID(chainID, 3, bTime.Add(1*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: h2.Hash()}) - trustPeriod = 4 * time.Hour - trustOptions = light.TrustOptions{ - Period: 4 * time.Hour, - Height: 1, - Hash: h1.Hash(), - } - valSet = map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - 3: vals, - 4: vals, - } - headerSet = map[int64]*types.SignedHeader{ - 1: h1, - // interim header (3/3 signed) - 2: h2, - // last header (3/3 signed) - 3: h3, - } - l1 = &types.LightBlock{SignedHeader: h1, ValidatorSet: vals} - l2 = &types.LightBlock{SignedHeader: h2, ValidatorSet: vals} - l3 = &types.LightBlock{SignedHeader: h3, ValidatorSet: vals} -) +var bTime time.Time -func TestValidateTrustOptions(t *testing.T) { - testCases := []struct { - err bool - to light.TrustOptions - }{ - { - false, - trustOptions, - }, - { - true, - light.TrustOptions{ - Period: -1 * time.Hour, - Height: 1, - Hash: h1.Hash(), - }, - }, - { - true, - light.TrustOptions{ - Period: 1 * time.Hour, - Height: 0, - Hash: h1.Hash(), - }, - }, - { - true, - light.TrustOptions{ - Period: 1 * time.Hour, - Height: 1, - Hash: []byte("incorrect hash"), - }, - }, +func init() { + var err error + bTime, err = time.Parse(time.RFC3339, "2006-01-02T15:04:05Z") + if err != nil { + panic(err) } +} - for _, tc := range testCases { - err := tc.to.ValidateBasic() - if tc.err { - assert.Error(t, err) - } else { - assert.NoError(t, err) +func TestClient(t *testing.T) { + var ( + keys = genPrivKeys(4) + vals = keys.ToValidators(20, 10) + trustPeriod = 4 * time.Hour + + valSet = map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + 3: vals, + 4: vals, } - } -} - -func TestClient_SequentialVerification(t *testing.T) { - newKeys := genPrivKeys(4) - newVals := newKeys.ToValidators(10, 1) - differentVals, _ := factory.ValidatorSet(t, 10, 100) - - testCases := []struct { - name string - otherHeaders map[int64]*types.SignedHeader // all except ^ - vals map[int64]*types.ValidatorSet - initErr bool - verifyErr bool - }{ - { - "good", - headerSet, - valSet, - false, - false, - }, - { - "bad: different first header", - map[int64]*types.SignedHeader{ - // different header - 1: keys.GenSignedHeader(chainID, 1, bTime.Add(1*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), - }, - map[int64]*types.ValidatorSet{ - 1: vals, - }, - true, - false, - }, - { - "bad: no first signed header", - map[int64]*types.SignedHeader{}, - map[int64]*types.ValidatorSet{ - 1: differentVals, - }, - true, - true, - }, - { - "bad: different first validator set", - map[int64]*types.SignedHeader{ - 1: h1, - }, - map[int64]*types.ValidatorSet{ - 1: differentVals, - }, - true, - true, - }, - { - "bad: 1/3 signed interim header", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - // interim header (1/3 signed) - 2: keys.GenSignedHeader(chainID, 2, bTime.Add(1*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), - // last header (3/3 signed) - 3: keys.GenSignedHeader(chainID, 3, bTime.Add(2*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), - }, - valSet, - false, - true, - }, - { - "bad: 1/3 signed last header", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - // interim header (3/3 signed) - 2: keys.GenSignedHeader(chainID, 2, bTime.Add(1*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), - // last header (1/3 signed) - 3: keys.GenSignedHeader(chainID, 3, bTime.Add(2*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), - }, - valSet, - false, - true, - }, - { - "bad: different validator set at height 3", - headerSet, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - 3: newVals, - }, - false, - true, - }, - } - - for _, tc := range testCases { - testCase := tc - t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - logger := log.NewTestingLogger(t) - - mockNode := mockNodeFromHeadersAndVals(testCase.otherHeaders, testCase.vals) - mockNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockNode, - []provider.Provider{mockNode}, - dbs.New(dbm.NewMemDB()), - light.SequentialVerification(), - light.Logger(logger), - ) - - if testCase.initErr { - require.Error(t, err) - return - } - - require.NoError(t, err) - - _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(3*time.Hour)) - if testCase.verifyErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - mockNode.AssertExpectations(t) - }) - } -} - -func TestClient_SkippingVerification(t *testing.T) { - // required for 2nd test case - newKeys := genPrivKeys(4) - newVals := newKeys.ToValidators(10, 1) - - // 1/3+ of vals, 2/3- of newVals - transitKeys := keys.Extend(3) - transitVals := transitKeys.ToValidators(10, 1) - - testCases := []struct { - name string - otherHeaders map[int64]*types.SignedHeader // all except ^ - vals map[int64]*types.ValidatorSet - initErr bool - verifyErr bool - }{ - { - "good", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - // last header (3/3 signed) - 3: h3, - }, - valSet, - false, - false, - }, - { - "good, but val set changes by 2/3 (1/3 of vals is still present)", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - 3: transitKeys.GenSignedHeader(chainID, 3, bTime.Add(2*time.Hour), nil, transitVals, transitVals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(transitKeys)), - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - 3: transitVals, - }, - false, - false, - }, - { - "good, but val set changes 100% at height 2", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - // interim header (3/3 signed) - 2: keys.GenSignedHeader(chainID, 2, bTime.Add(1*time.Hour), nil, vals, newVals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), - // last header (0/4 of the original val set signed) - 3: newKeys.GenSignedHeader(chainID, 3, bTime.Add(2*time.Hour), nil, newVals, newVals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(newKeys)), - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - 3: newVals, - }, - false, - false, - }, - { - "bad: last header signed by newVals, interim header has no signers", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - // last header (0/4 of the original val set signed) - 2: keys.GenSignedHeader(chainID, 2, bTime.Add(1*time.Hour), nil, vals, newVals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, 0), - // last header (0/4 of the original val set signed) - 3: newKeys.GenSignedHeader(chainID, 3, bTime.Add(2*time.Hour), nil, newVals, newVals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(newKeys)), - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - 3: newVals, - }, - false, - true, - }, - } - - bctx, bcancel := context.WithCancel(context.Background()) - defer bcancel() - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(bctx) - defer cancel() - logger := log.NewTestingLogger(t) - - mockNode := mockNodeFromHeadersAndVals(tc.otherHeaders, tc.vals) - mockNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockNode, - []provider.Provider{mockNode}, - dbs.New(dbm.NewMemDB()), - light.SkippingVerification(light.DefaultTrustLevel), - light.Logger(logger), - ) - if tc.initErr { - require.Error(t, err) - return - } - - require.NoError(t, err) - - _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(3*time.Hour)) - if tc.verifyErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } - -} - -// start from a large light block to make sure that the pivot height doesn't select a height outside -// the appropriate range -func TestClientLargeBisectionVerification(t *testing.T) { - numBlocks := int64(300) - mockHeaders, mockVals, _ := genLightBlocksWithKeys(chainID, numBlocks, 101, 2, bTime) - - lastBlock := &types.LightBlock{SignedHeader: mockHeaders[numBlocks], ValidatorSet: mockVals[numBlocks]} - mockNode := &provider_mocks.Provider{} - mockNode.On("LightBlock", mock.Anything, numBlocks). - Return(lastBlock, nil) - - mockNode.On("LightBlock", mock.Anything, int64(200)). - Return(&types.LightBlock{SignedHeader: mockHeaders[200], ValidatorSet: mockVals[200]}, nil) - - mockNode.On("LightBlock", mock.Anything, int64(256)). - Return(&types.LightBlock{SignedHeader: mockHeaders[256], ValidatorSet: mockVals[256]}, nil) - - mockNode.On("LightBlock", mock.Anything, int64(0)).Return(lastBlock, nil) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - trustedLightBlock, err := mockNode.LightBlock(ctx, int64(200)) - require.NoError(t, err) - c, err := light.NewClient( - ctx, - chainID, - light.TrustOptions{ - Period: 4 * time.Hour, - Height: trustedLightBlock.Height, - Hash: trustedLightBlock.Hash(), - }, - mockNode, - []provider.Provider{mockNode}, - dbs.New(dbm.NewMemDB()), - light.SkippingVerification(light.DefaultTrustLevel), - ) - require.NoError(t, err) - h, err := c.Update(ctx, bTime.Add(300*time.Minute)) - assert.NoError(t, err) - height, err := c.LastTrustedHeight() - require.NoError(t, err) - require.Equal(t, numBlocks, height) - h2, err := mockNode.LightBlock(ctx, numBlocks) - require.NoError(t, err) - assert.Equal(t, h, h2) - mockNode.AssertExpectations(t) -} - -func TestClientBisectionBetweenTrustedHeaders(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) - c, err := light.NewClient( - ctx, - chainID, - light.TrustOptions{ + h1 = keys.GenSignedHeader(t, chainID, 1, bTime, nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) + // 3/3 signed + h2 = keys.GenSignedHeaderLastBlockID(t, chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: h1.Hash()}) + // 3/3 signed + h3 = keys.GenSignedHeaderLastBlockID(t, chainID, 3, bTime.Add(1*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: h2.Hash()}) + trustOptions = light.TrustOptions{ Period: 4 * time.Hour, Height: 1, Hash: h1.Hash(), - }, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.SkippingVerification(light.DefaultTrustLevel), + } + headerSet = map[int64]*types.SignedHeader{ + 1: h1, + // interim header (3/3 signed) + 2: h2, + // last header (3/3 signed) + 3: h3, + } + l1 = &types.LightBlock{SignedHeader: h1, ValidatorSet: vals} + l2 = &types.LightBlock{SignedHeader: h2, ValidatorSet: vals} + l3 = &types.LightBlock{SignedHeader: h3, ValidatorSet: vals} ) - require.NoError(t, err) + t.Run("ValidateTrustOptions", func(t *testing.T) { + testCases := []struct { + err bool + to light.TrustOptions + }{ + { + false, + trustOptions, + }, + { + true, + light.TrustOptions{ + Period: -1 * time.Hour, + Height: 1, + Hash: h1.Hash(), + }, + }, + { + true, + light.TrustOptions{ + Period: 1 * time.Hour, + Height: 0, + Hash: h1.Hash(), + }, + }, + { + true, + light.TrustOptions{ + Period: 1 * time.Hour, + Height: 1, + Hash: []byte("incorrect hash"), + }, + }, + } - _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) - require.NoError(t, err) + for idx, tc := range testCases { + t.Run(fmt.Sprint(idx), func(t *testing.T) { + err := tc.to.ValidateBasic() + if tc.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } + }) + t.Run("SequentialVerification", func(t *testing.T) { + newKeys := genPrivKeys(4) + newVals := newKeys.ToValidators(10, 1) + differentVals, _ := factory.ValidatorSet(t, 10, 100) - // confirm that the client already doesn't have the light block - _, err = c.TrustedLightBlock(2) - require.Error(t, err) + testCases := []struct { + name string + otherHeaders map[int64]*types.SignedHeader // all except ^ + vals map[int64]*types.ValidatorSet + initErr bool + verifyErr bool + }{ + { + name: "good", + otherHeaders: headerSet, + vals: valSet, + initErr: false, + verifyErr: false, + }, + { + "bad: different first header", + map[int64]*types.SignedHeader{ + // different header + 1: keys.GenSignedHeader(t, chainID, 1, bTime.Add(1*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), + }, + map[int64]*types.ValidatorSet{ + 1: vals, + }, + true, + false, + }, + { + "bad: no first signed header", + map[int64]*types.SignedHeader{}, + map[int64]*types.ValidatorSet{ + 1: differentVals, + }, + true, + true, + }, + { + "bad: different first validator set", + map[int64]*types.SignedHeader{ + 1: h1, + }, + map[int64]*types.ValidatorSet{ + 1: differentVals, + }, + true, + true, + }, + { + "bad: 1/3 signed interim header", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + // interim header (1/3 signed) + 2: keys.GenSignedHeader(t, chainID, 2, bTime.Add(1*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), + // last header (3/3 signed) + 3: keys.GenSignedHeader(t, chainID, 3, bTime.Add(2*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), + }, + valSet, + false, + true, + }, + { + "bad: 1/3 signed last header", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + // interim header (3/3 signed) + 2: keys.GenSignedHeader(t, chainID, 2, bTime.Add(1*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), + // last header (1/3 signed) + 3: keys.GenSignedHeader(t, chainID, 3, bTime.Add(2*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), + }, + valSet, + false, + true, + }, + { + "bad: different validator set at height 3", + headerSet, + map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + 3: newVals, + }, + false, + true, + }, + } - // verify using bisection the light block between the two trusted light blocks - _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour)) - assert.NoError(t, err) - mockFullNode.AssertExpectations(t) -} + for _, tc := range testCases { + testCase := tc + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -func TestClient_Cleanup(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - logger := log.NewTestingLogger(t) + logger := log.NewTestingLogger(t) - mockFullNode := &provider_mocks.Provider{} - mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) - _, err = c.TrustedLightBlock(1) - require.NoError(t, err) + mockNode := mockNodeFromHeadersAndVals(testCase.otherHeaders, testCase.vals) + mockNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockNode, + []provider.Provider{mockNode}, + dbs.New(dbm.NewMemDB()), + light.SequentialVerification(), + light.Logger(logger), + ) - err = c.Cleanup() - require.NoError(t, err) + if testCase.initErr { + require.Error(t, err) + return + } - // Check no light blocks exist after Cleanup. - l, err := c.TrustedLightBlock(1) - assert.Error(t, err) - assert.Nil(t, l) - mockFullNode.AssertExpectations(t) -} + require.NoError(t, err) -// trustedHeader.Height == options.Height -func TestClientRestoresTrustedHeaderAfterStartup(t *testing.T) { - bctx, bcancel := context.WithCancel(context.Background()) - defer bcancel() + _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(3*time.Hour)) + if testCase.verifyErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + mockNode.AssertExpectations(t) + }) + } - // 1. options.Hash == trustedHeader.Hash - t.Run("hashes should match", func(t *testing.T) { - ctx, cancel := context.WithCancel(bctx) + }) + t.Run("SkippingVerification", func(t *testing.T) { + // required for 2nd test case + newKeys := genPrivKeys(4) + newVals := newKeys.ToValidators(10, 1) + + // 1/3+ of vals, 2/3- of newVals + transitKeys := keys.Extend(3) + transitVals := transitKeys.ToValidators(10, 1) + + testCases := []struct { + name string + otherHeaders map[int64]*types.SignedHeader // all except ^ + vals map[int64]*types.ValidatorSet + initErr bool + verifyErr bool + }{ + { + "good", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + // last header (3/3 signed) + 3: h3, + }, + valSet, + false, + false, + }, + { + "good, but val set changes by 2/3 (1/3 of vals is still present)", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + 3: transitKeys.GenSignedHeader(t, chainID, 3, bTime.Add(2*time.Hour), nil, transitVals, transitVals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(transitKeys)), + }, + map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + 3: transitVals, + }, + false, + false, + }, + { + "good, but val set changes 100% at height 2", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + // interim header (3/3 signed) + 2: keys.GenSignedHeader(t, chainID, 2, bTime.Add(1*time.Hour), nil, vals, newVals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), + // last header (0/4 of the original val set signed) + 3: newKeys.GenSignedHeader(t, chainID, 3, bTime.Add(2*time.Hour), nil, newVals, newVals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(newKeys)), + }, + map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + 3: newVals, + }, + false, + false, + }, + { + "bad: last header signed by newVals, interim header has no signers", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + // last header (0/4 of the original val set signed) + 2: keys.GenSignedHeader(t, chainID, 2, bTime.Add(1*time.Hour), nil, vals, newVals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, 0), + // last header (0/4 of the original val set signed) + 3: newKeys.GenSignedHeader(t, chainID, 3, bTime.Add(2*time.Hour), nil, newVals, newVals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(newKeys)), + }, + map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + 3: newVals, + }, + false, + true, + }, + } + + bctx, bcancel := context.WithCancel(context.Background()) + defer bcancel() + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(bctx) + defer cancel() + logger := log.NewTestingLogger(t) + + mockNode := mockNodeFromHeadersAndVals(tc.otherHeaders, tc.vals) + mockNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockNode, + []provider.Provider{mockNode}, + dbs.New(dbm.NewMemDB()), + light.SkippingVerification(light.DefaultTrustLevel), + light.Logger(logger), + ) + if tc.initErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + + _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(3*time.Hour)) + if tc.verifyErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } + + }) + t.Run("LargeBisectionVerification", func(t *testing.T) { + // start from a large light block to make sure that the pivot height doesn't select a height outside + // the appropriate range + + numBlocks := int64(300) + mockHeaders, mockVals, _ := genLightBlocksWithKeys(t, chainID, numBlocks, 101, 2, bTime) + + lastBlock := &types.LightBlock{SignedHeader: mockHeaders[numBlocks], ValidatorSet: mockVals[numBlocks]} + mockNode := &provider_mocks.Provider{} + mockNode.On("LightBlock", mock.Anything, numBlocks). + Return(lastBlock, nil) + + mockNode.On("LightBlock", mock.Anything, int64(200)). + Return(&types.LightBlock{SignedHeader: mockHeaders[200], ValidatorSet: mockVals[200]}, nil) + + mockNode.On("LightBlock", mock.Anything, int64(256)). + Return(&types.LightBlock{SignedHeader: mockHeaders[256], ValidatorSet: mockVals[256]}, nil) + + mockNode.On("LightBlock", mock.Anything, int64(0)).Return(lastBlock, nil) + + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := log.NewTestingLogger(t) - - mockNode := &provider_mocks.Provider{} - trustedStore := dbs.New(dbm.NewMemDB()) - err := trustedStore.SaveLightBlock(l1) + trustedLightBlock, err := mockNode.LightBlock(ctx, int64(200)) require.NoError(t, err) - c, err := light.NewClient( ctx, chainID, - trustOptions, + light.TrustOptions{ + Period: 4 * time.Hour, + Height: trustedLightBlock.Height, + Hash: trustedLightBlock.Hash(), + }, mockNode, []provider.Provider{mockNode}, - trustedStore, - light.Logger(logger), + dbs.New(dbm.NewMemDB()), + light.SkippingVerification(light.DefaultTrustLevel), ) require.NoError(t, err) - - l, err := c.TrustedLightBlock(1) + h, err := c.Update(ctx, bTime.Add(300*time.Minute)) assert.NoError(t, err) - assert.NotNil(t, l) - assert.Equal(t, l.Hash(), h1.Hash()) - assert.Equal(t, l.ValidatorSet.Hash(), h1.ValidatorsHash.Bytes()) + height, err := c.LastTrustedHeight() + require.NoError(t, err) + require.Equal(t, numBlocks, height) + h2, err := mockNode.LightBlock(ctx, numBlocks) + require.NoError(t, err) + assert.Equal(t, h, h2) mockNode.AssertExpectations(t) }) - - // 2. options.Hash != trustedHeader.Hash - t.Run("hashes should not match", func(t *testing.T) { - ctx, cancel := context.WithCancel(bctx) + t.Run("BisectionBetweenTrustedHeaders", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - trustedStore := dbs.New(dbm.NewMemDB()) - err := trustedStore.SaveLightBlock(l1) - require.NoError(t, err) - - logger := log.NewTestingLogger(t) - - // header1 != h1 - header1 := keys.GenSignedHeader(chainID, 1, bTime.Add(1*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) - mockNode := &provider_mocks.Provider{} - + mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) c, err := light.NewClient( ctx, chainID, light.TrustOptions{ Period: 4 * time.Hour, Height: 1, - Hash: header1.Hash(), + Hash: h1.Hash(), }, - mockNode, - []provider.Provider{mockNode}, - trustedStore, - light.Logger(logger), + mockFullNode, + []provider.Provider{mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.SkippingVerification(light.DefaultTrustLevel), ) require.NoError(t, err) - l, err := c.TrustedLightBlock(1) + _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) + require.NoError(t, err) + + // confirm that the client already doesn't have the light block + _, err = c.TrustedLightBlock(2) + require.Error(t, err) + + // verify using bisection the light block between the two trusted light blocks + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour)) assert.NoError(t, err) - if assert.NotNil(t, l) { - // client take the trusted store and ignores the trusted options - assert.Equal(t, l.Hash(), l1.Hash()) - assert.NoError(t, l.ValidateBasic(chainID)) - } - mockNode.AssertExpectations(t) + mockFullNode.AssertExpectations(t) }) -} - -func TestClient_Update(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mockFullNode := &provider_mocks.Provider{} - mockFullNode.On("LightBlock", mock.Anything, int64(0)).Return(l3, nil) - mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) - mockFullNode.On("LightBlock", mock.Anything, int64(3)).Return(l3, nil) - - logger := log.NewTestingLogger(t) - - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) - - // should result in downloading & verifying header #3 - l, err := c.Update(ctx, bTime.Add(2*time.Hour)) - assert.NoError(t, err) - if assert.NotNil(t, l) { - assert.EqualValues(t, 3, l.Height) - assert.NoError(t, l.ValidateBasic(chainID)) - } - mockFullNode.AssertExpectations(t) -} - -func TestClient_Concurrency(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - logger := log.NewTestingLogger(t) - - mockFullNode := &provider_mocks.Provider{} - mockFullNode.On("LightBlock", mock.Anything, int64(2)).Return(l2, nil) - mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) - - _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) - require.NoError(t, err) - - var wg sync.WaitGroup - for i := 0; i < 100; i++ { - wg.Add(1) - go func() { - defer wg.Done() - - // NOTE: Cleanup, Stop, VerifyLightBlockAtHeight and Verify are not supposed - // to be concurrently safe. - - assert.Equal(t, chainID, c.ChainID()) - - _, err := c.LastTrustedHeight() - assert.NoError(t, err) - - _, err = c.FirstTrustedHeight() - assert.NoError(t, err) - - l, err := c.TrustedLightBlock(1) - assert.NoError(t, err) - assert.NotNil(t, l) - }() - } - - wg.Wait() - mockFullNode.AssertExpectations(t) -} - -func TestClient_AddProviders(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mockFullNode := mockNodeFromHeadersAndVals(map[int64]*types.SignedHeader{ - 1: h1, - 2: h2, - }, valSet) - logger := log.NewTestingLogger(t) - - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) - - closeCh := make(chan struct{}) - go func() { - // run verification concurrently to make sure it doesn't dead lock - _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) - require.NoError(t, err) - close(closeCh) - }() - - // NOTE: the light client doesn't check uniqueness of providers - c.AddProvider(mockFullNode) - require.Len(t, c.Witnesses(), 2) - select { - case <-closeCh: - case <-time.After(5 * time.Second): - t.Fatal("concurent light block verification failed to finish in 5s") - } - mockFullNode.AssertExpectations(t) -} - -func TestClientReplacesPrimaryWithWitnessIfPrimaryIsUnavailable(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mockFullNode := &provider_mocks.Provider{} - mockFullNode.On("LightBlock", mock.Anything, mock.Anything).Return(l1, nil) - - mockDeadNode := &provider_mocks.Provider{} - mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrNoResponse) - - logger := log.NewTestingLogger(t) - - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockDeadNode, - []provider.Provider{mockDeadNode, mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - - require.NoError(t, err) - _, err = c.Update(ctx, bTime.Add(2*time.Hour)) - require.NoError(t, err) - - // the primary should no longer be the deadNode - assert.NotEqual(t, c.Primary(), mockDeadNode) - - // we should still have the dead node as a witness because it - // hasn't repeatedly been unresponsive yet - assert.Equal(t, 2, len(c.Witnesses())) - mockDeadNode.AssertExpectations(t) - mockFullNode.AssertExpectations(t) -} - -func TestClientReplacesPrimaryWithWitnessIfPrimaryDoesntHaveBlock(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mockFullNode := &provider_mocks.Provider{} - mockFullNode.On("LightBlock", mock.Anything, mock.Anything).Return(l1, nil) - - logger := log.NewTestingLogger(t) - - mockDeadNode := &provider_mocks.Provider{} - mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockDeadNode, - []provider.Provider{mockDeadNode, mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) - _, err = c.Update(ctx, bTime.Add(2*time.Hour)) - require.NoError(t, err) - - // we should still have the dead node as a witness because it - // hasn't repeatedly been unresponsive yet - assert.Equal(t, 2, len(c.Witnesses())) - mockDeadNode.AssertExpectations(t) - mockFullNode.AssertExpectations(t) -} - -func TestClient_BackwardsVerification(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - logger := log.NewTestingLogger(t) - - { - headers, vals, _ := genLightBlocksWithKeys(chainID, 9, 3, 0, bTime) - delete(headers, 1) - delete(headers, 2) - delete(vals, 1) - delete(vals, 2) - mockLargeFullNode := mockNodeFromHeadersAndVals(headers, vals) - trustHeader, _ := mockLargeFullNode.LightBlock(ctx, 6) + t.Run("Cleanup", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := log.NewTestingLogger(t) + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) c, err := light.NewClient( ctx, chainID, - light.TrustOptions{ - Period: 4 * time.Minute, - Height: trustHeader.Height, - Hash: trustHeader.Hash(), - }, - mockLargeFullNode, - []provider.Provider{mockLargeFullNode}, + trustOptions, + mockFullNode, + []provider.Provider{mockFullNode}, dbs.New(dbm.NewMemDB()), light.Logger(logger), ) require.NoError(t, err) - - // 1) verify before the trusted header using backwards => expect no error - h, err := c.VerifyLightBlockAtHeight(ctx, 5, bTime.Add(6*time.Minute)) - require.NoError(t, err) - if assert.NotNil(t, h) { - assert.EqualValues(t, 5, h.Height) - } - - // 2) untrusted header is expired but trusted header is not => expect no error - h, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(8*time.Minute)) - assert.NoError(t, err) - assert.NotNil(t, h) - - // 3) already stored headers should return the header without error - h, err = c.VerifyLightBlockAtHeight(ctx, 5, bTime.Add(6*time.Minute)) - assert.NoError(t, err) - assert.NotNil(t, h) - - // 4a) First verify latest header - _, err = c.VerifyLightBlockAtHeight(ctx, 9, bTime.Add(9*time.Minute)) + _, err = c.TrustedLightBlock(1) require.NoError(t, err) - // 4b) Verify backwards using bisection => expect no error - _, err = c.VerifyLightBlockAtHeight(ctx, 7, bTime.Add(9*time.Minute)) - assert.NoError(t, err) - // shouldn't have verified this header in the process - _, err = c.TrustedLightBlock(8) - assert.Error(t, err) - - // 5) Try bisection method, but closest header (at 7) has expired - // so expect error - _, err = c.VerifyLightBlockAtHeight(ctx, 8, bTime.Add(12*time.Minute)) - assert.Error(t, err) - mockLargeFullNode.AssertExpectations(t) - - } - { - // 8) provides incorrect hash - headers := map[int64]*types.SignedHeader{ - 2: keys.GenSignedHeader(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, - hash("app_hash2"), hash("cons_hash23"), hash("results_hash30"), 0, len(keys)), - 3: h3, - } - vals := valSet - mockNode := mockNodeFromHeadersAndVals(headers, vals) - c, err := light.NewClient( - ctx, - chainID, - light.TrustOptions{ - Period: 1 * time.Hour, - Height: 3, - Hash: h3.Hash(), - }, - mockNode, - []provider.Provider{mockNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) + err = c.Cleanup() require.NoError(t, err) - _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour).Add(1*time.Second)) + // Check no light blocks exist after Cleanup. + l, err := c.TrustedLightBlock(1) assert.Error(t, err) - mockNode.AssertExpectations(t) - } -} + assert.Nil(t, l) + mockFullNode.AssertExpectations(t) + }) + t.Run("RestoresTrustedHeaderAfterStartup", func(t *testing.T) { + // trustedHeader.Height == options.Height -func TestClient_NewClientFromTrustedStore(t *testing.T) { - // 1) Initiate DB and fill with a "trusted" header - db := dbs.New(dbm.NewMemDB()) - err := db.SaveLightBlock(l1) - require.NoError(t, err) - mockNode := &provider_mocks.Provider{} + bctx, bcancel := context.WithCancel(context.Background()) + defer bcancel() - c, err := light.NewClientFromTrustedStore( - chainID, - trustPeriod, - mockNode, - []provider.Provider{mockNode}, - db, - ) - require.NoError(t, err) - - // 2) Check light block exists - h, err := c.TrustedLightBlock(1) - assert.NoError(t, err) - assert.EqualValues(t, l1.Height, h.Height) - mockNode.AssertExpectations(t) -} - -func TestClientRemovesWitnessIfItSendsUsIncorrectHeader(t *testing.T) { - logger := log.NewTestingLogger(t) - - // different headers hash then primary plus less than 1/3 signed (no fork) - headers1 := map[int64]*types.SignedHeader{ - 1: h1, - 2: keys.GenSignedHeaderLastBlockID(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, - hash("app_hash2"), hash("cons_hash"), hash("results_hash"), - len(keys), len(keys), types.BlockID{Hash: h1.Hash()}), - } - vals1 := map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - } - mockBadNode1 := mockNodeFromHeadersAndVals(headers1, vals1) - mockBadNode1.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) - - // header is empty - headers2 := map[int64]*types.SignedHeader{ - 1: h1, - 2: h2, - } - vals2 := map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - } - mockBadNode2 := mockNodeFromHeadersAndVals(headers2, vals2) - mockBadNode2.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) - - mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - lb1, _ := mockBadNode1.LightBlock(ctx, 2) - require.NotEqual(t, lb1.Hash(), l1.Hash()) - - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockBadNode1, mockBadNode2}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - // witness should have behaved properly -> no error - require.NoError(t, err) - assert.EqualValues(t, 2, len(c.Witnesses())) - - // witness behaves incorrectly -> removed from list, no error - l, err := c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) - assert.NoError(t, err) - assert.EqualValues(t, 1, len(c.Witnesses())) - // light block should still be verified - assert.EqualValues(t, 2, l.Height) - - // remaining witnesses don't have light block -> error - _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) - if assert.Error(t, err) { - assert.Equal(t, light.ErrFailedHeaderCrossReferencing, err) - } - // witness does not have a light block -> left in the list - assert.EqualValues(t, 1, len(c.Witnesses())) - mockBadNode1.AssertExpectations(t) - mockBadNode2.AssertExpectations(t) -} - -func TestClient_TrustedValidatorSet(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - logger := log.NewTestingLogger(t) - - differentVals, _ := factory.ValidatorSet(t, 10, 100) - mockBadValSetNode := mockNodeFromHeadersAndVals( - map[int64]*types.SignedHeader{ - 1: h1, - // 3/3 signed, but validator set at height 2 below is invalid -> witness - // should be removed. - 2: keys.GenSignedHeaderLastBlockID(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, - hash("app_hash2"), hash("cons_hash"), hash("results_hash"), - 0, len(keys), types.BlockID{Hash: h1.Hash()}), - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: differentVals, - }) - mockFullNode := mockNodeFromHeadersAndVals( - map[int64]*types.SignedHeader{ - 1: h1, - 2: h2, - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - }) - - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockBadValSetNode, mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) - assert.Equal(t, 2, len(c.Witnesses())) - - _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour).Add(1*time.Second)) - assert.NoError(t, err) - assert.Equal(t, 1, len(c.Witnesses())) - mockBadValSetNode.AssertExpectations(t) - mockFullNode.AssertExpectations(t) -} - -func TestClientPrunesHeadersAndValidatorSets(t *testing.T) { - mockFullNode := mockNodeFromHeadersAndVals( - map[int64]*types.SignedHeader{ - 1: h1, - 3: h3, - 0: h3, - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 3: vals, - 0: vals, - }) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - logger := log.NewTestingLogger(t) - - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - light.PruningSize(1), - ) - require.NoError(t, err) - _, err = c.TrustedLightBlock(1) - require.NoError(t, err) - - h, err := c.Update(ctx, bTime.Add(2*time.Hour)) - require.NoError(t, err) - require.Equal(t, int64(3), h.Height) - - _, err = c.TrustedLightBlock(1) - assert.Error(t, err) - mockFullNode.AssertExpectations(t) -} - -func TestClientEnsureValidHeadersAndValSets(t *testing.T) { - emptyValSet := &types.ValidatorSet{ - Validators: nil, - Proposer: nil, - } - - testCases := []struct { - headers map[int64]*types.SignedHeader - vals map[int64]*types.ValidatorSet - - errorToThrow error - errorHeight int64 - - err bool - }{ - { - headers: map[int64]*types.SignedHeader{ - 1: h1, - 3: h3, - }, - vals: map[int64]*types.ValidatorSet{ - 1: vals, - 3: vals, - }, - err: false, - }, - { - headers: map[int64]*types.SignedHeader{ - 1: h1, - }, - vals: map[int64]*types.ValidatorSet{ - 1: vals, - }, - errorToThrow: provider.ErrBadLightBlock{Reason: errors.New("nil header or vals")}, - errorHeight: 3, - err: true, - }, - { - headers: map[int64]*types.SignedHeader{ - 1: h1, - }, - errorToThrow: provider.ErrBadLightBlock{Reason: errors.New("nil header or vals")}, - errorHeight: 3, - vals: valSet, - err: true, - }, - { - headers: map[int64]*types.SignedHeader{ - 1: h1, - 3: h3, - }, - vals: map[int64]*types.ValidatorSet{ - 1: vals, - 3: emptyValSet, - }, - err: true, - }, - } - - for i, tc := range testCases { - testCase := tc - t.Run(fmt.Sprintf("case: %d", i), func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + // 1. options.Hash == trustedHeader.Hash + t.Run("hashes should match", func(t *testing.T) { + ctx, cancel := context.WithCancel(bctx) defer cancel() - mockBadNode := mockNodeFromHeadersAndVals(testCase.headers, testCase.vals) - if testCase.errorToThrow != nil { - mockBadNode.On("LightBlock", mock.Anything, testCase.errorHeight).Return(nil, testCase.errorToThrow) - } + logger := log.NewTestingLogger(t) + + mockNode := &provider_mocks.Provider{} + trustedStore := dbs.New(dbm.NewMemDB()) + err := trustedStore.SaveLightBlock(l1) + require.NoError(t, err) c, err := light.NewClient( ctx, chainID, trustOptions, - mockBadNode, - []provider.Provider{mockBadNode, mockBadNode}, - dbs.New(dbm.NewMemDB()), + mockNode, + []provider.Provider{mockNode}, + trustedStore, + light.Logger(logger), ) require.NoError(t, err) - _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) - if testCase.err { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - mockBadNode.AssertExpectations(t) + l, err := c.TrustedLightBlock(1) + assert.NoError(t, err) + assert.NotNil(t, l) + assert.Equal(t, l.Hash(), h1.Hash()) + assert.Equal(t, l.ValidatorSet.Hash(), h1.ValidatorsHash.Bytes()) + mockNode.AssertExpectations(t) }) - } + + // 2. options.Hash != trustedHeader.Hash + t.Run("hashes should not match", func(t *testing.T) { + ctx, cancel := context.WithCancel(bctx) + defer cancel() + + trustedStore := dbs.New(dbm.NewMemDB()) + err := trustedStore.SaveLightBlock(l1) + require.NoError(t, err) + + logger := log.NewTestingLogger(t) + + // header1 != h1 + header1 := keys.GenSignedHeader(t, chainID, 1, bTime.Add(1*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) + mockNode := &provider_mocks.Provider{} + + c, err := light.NewClient( + ctx, + chainID, + light.TrustOptions{ + Period: 4 * time.Hour, + Height: 1, + Hash: header1.Hash(), + }, + mockNode, + []provider.Provider{mockNode}, + trustedStore, + light.Logger(logger), + ) + require.NoError(t, err) + + l, err := c.TrustedLightBlock(1) + assert.NoError(t, err) + if assert.NotNil(t, l) { + // client take the trusted store and ignores the trusted options + assert.Equal(t, l.Hash(), l1.Hash()) + assert.NoError(t, l.ValidateBasic(chainID)) + } + mockNode.AssertExpectations(t) + }) + }) + t.Run("Update", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, int64(0)).Return(l3, nil) + mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) + mockFullNode.On("LightBlock", mock.Anything, int64(3)).Return(l3, nil) + + logger := log.NewTestingLogger(t) + + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockFullNode, + []provider.Provider{mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + require.NoError(t, err) + + // should result in downloading & verifying header #3 + l, err := c.Update(ctx, bTime.Add(2*time.Hour)) + assert.NoError(t, err) + if assert.NotNil(t, l) { + assert.EqualValues(t, 3, l.Height) + assert.NoError(t, l.ValidateBasic(chainID)) + } + mockFullNode.AssertExpectations(t) + }) + + t.Run("Concurrency", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := log.NewTestingLogger(t) + + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, int64(2)).Return(l2, nil) + mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockFullNode, + []provider.Provider{mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + require.NoError(t, err) + + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) + require.NoError(t, err) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + // NOTE: Cleanup, Stop, VerifyLightBlockAtHeight and Verify are not supposed + // to be concurrently safe. + + assert.Equal(t, chainID, c.ChainID()) + + _, err := c.LastTrustedHeight() + assert.NoError(t, err) + + _, err = c.FirstTrustedHeight() + assert.NoError(t, err) + + l, err := c.TrustedLightBlock(1) + assert.NoError(t, err) + assert.NotNil(t, l) + }() + } + + wg.Wait() + mockFullNode.AssertExpectations(t) + }) + t.Run("AddProviders", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockFullNode := mockNodeFromHeadersAndVals(map[int64]*types.SignedHeader{ + 1: h1, + 2: h2, + }, valSet) + logger := log.NewTestingLogger(t) + + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockFullNode, + []provider.Provider{mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + require.NoError(t, err) + + closeCh := make(chan struct{}) + go func() { + // run verification concurrently to make sure it doesn't dead lock + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) + require.NoError(t, err) + close(closeCh) + }() + + // NOTE: the light client doesn't check uniqueness of providers + c.AddProvider(mockFullNode) + require.Len(t, c.Witnesses(), 2) + select { + case <-closeCh: + case <-time.After(5 * time.Second): + t.Fatal("concurent light block verification failed to finish in 5s") + } + mockFullNode.AssertExpectations(t) + }) + t.Run("ReplacesPrimaryWithWitnessIfPrimaryIsUnavailable", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, mock.Anything).Return(l1, nil) + + mockDeadNode := &provider_mocks.Provider{} + mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrNoResponse) + + logger := log.NewTestingLogger(t) + + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockDeadNode, + []provider.Provider{mockDeadNode, mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + + require.NoError(t, err) + _, err = c.Update(ctx, bTime.Add(2*time.Hour)) + require.NoError(t, err) + + // the primary should no longer be the deadNode + assert.NotEqual(t, c.Primary(), mockDeadNode) + + // we should still have the dead node as a witness because it + // hasn't repeatedly been unresponsive yet + assert.Equal(t, 2, len(c.Witnesses())) + mockDeadNode.AssertExpectations(t) + mockFullNode.AssertExpectations(t) + }) + t.Run("ReplacesPrimaryWithWitnessIfPrimaryDoesntHaveBlock", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, mock.Anything).Return(l1, nil) + + logger := log.NewTestingLogger(t) + + mockDeadNode := &provider_mocks.Provider{} + mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockDeadNode, + []provider.Provider{mockDeadNode, mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + require.NoError(t, err) + _, err = c.Update(ctx, bTime.Add(2*time.Hour)) + require.NoError(t, err) + + // we should still have the dead node as a witness because it + // hasn't repeatedly been unresponsive yet + assert.Equal(t, 2, len(c.Witnesses())) + mockDeadNode.AssertExpectations(t) + mockFullNode.AssertExpectations(t) + }) + t.Run("BackwardsVerification", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := log.NewTestingLogger(t) + + { + headers, vals, _ := genLightBlocksWithKeys(t, chainID, 9, 3, 0, bTime) + delete(headers, 1) + delete(headers, 2) + delete(vals, 1) + delete(vals, 2) + mockLargeFullNode := mockNodeFromHeadersAndVals(headers, vals) + trustHeader, _ := mockLargeFullNode.LightBlock(ctx, 6) + + c, err := light.NewClient( + ctx, + chainID, + light.TrustOptions{ + Period: 4 * time.Minute, + Height: trustHeader.Height, + Hash: trustHeader.Hash(), + }, + mockLargeFullNode, + []provider.Provider{mockLargeFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + require.NoError(t, err) + + // 1) verify before the trusted header using backwards => expect no error + h, err := c.VerifyLightBlockAtHeight(ctx, 5, bTime.Add(6*time.Minute)) + require.NoError(t, err) + if assert.NotNil(t, h) { + assert.EqualValues(t, 5, h.Height) + } + + // 2) untrusted header is expired but trusted header is not => expect no error + h, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(8*time.Minute)) + assert.NoError(t, err) + assert.NotNil(t, h) + + // 3) already stored headers should return the header without error + h, err = c.VerifyLightBlockAtHeight(ctx, 5, bTime.Add(6*time.Minute)) + assert.NoError(t, err) + assert.NotNil(t, h) + + // 4a) First verify latest header + _, err = c.VerifyLightBlockAtHeight(ctx, 9, bTime.Add(9*time.Minute)) + require.NoError(t, err) + + // 4b) Verify backwards using bisection => expect no error + _, err = c.VerifyLightBlockAtHeight(ctx, 7, bTime.Add(9*time.Minute)) + assert.NoError(t, err) + // shouldn't have verified this header in the process + _, err = c.TrustedLightBlock(8) + assert.Error(t, err) + + // 5) Try bisection method, but closest header (at 7) has expired + // so expect error + _, err = c.VerifyLightBlockAtHeight(ctx, 8, bTime.Add(12*time.Minute)) + assert.Error(t, err) + mockLargeFullNode.AssertExpectations(t) + + } + { + // 8) provides incorrect hash + headers := map[int64]*types.SignedHeader{ + 2: keys.GenSignedHeader(t, chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, + hash("app_hash2"), hash("cons_hash23"), hash("results_hash30"), 0, len(keys)), + 3: h3, + } + vals := valSet + mockNode := mockNodeFromHeadersAndVals(headers, vals) + c, err := light.NewClient( + ctx, + chainID, + light.TrustOptions{ + Period: 1 * time.Hour, + Height: 3, + Hash: h3.Hash(), + }, + mockNode, + []provider.Provider{mockNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + require.NoError(t, err) + + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour).Add(1*time.Second)) + assert.Error(t, err) + mockNode.AssertExpectations(t) + } + }) + t.Run("NewClientFromTrustedStore", func(t *testing.T) { + // 1) Initiate DB and fill with a "trusted" header + db := dbs.New(dbm.NewMemDB()) + err := db.SaveLightBlock(l1) + require.NoError(t, err) + mockNode := &provider_mocks.Provider{} + + c, err := light.NewClientFromTrustedStore( + chainID, + trustPeriod, + mockNode, + []provider.Provider{mockNode}, + db, + ) + require.NoError(t, err) + + // 2) Check light block exists + h, err := c.TrustedLightBlock(1) + assert.NoError(t, err) + assert.EqualValues(t, l1.Height, h.Height) + mockNode.AssertExpectations(t) + }) + t.Run("RemovesWitnessIfItSendsUsIncorrectHeader", func(t *testing.T) { + logger := log.NewTestingLogger(t) + + // different headers hash then primary plus less than 1/3 signed (no fork) + headers1 := map[int64]*types.SignedHeader{ + 1: h1, + 2: keys.GenSignedHeaderLastBlockID(t, chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, + hash("app_hash2"), hash("cons_hash"), hash("results_hash"), + len(keys), len(keys), types.BlockID{Hash: h1.Hash()}), + } + vals1 := map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + } + mockBadNode1 := mockNodeFromHeadersAndVals(headers1, vals1) + mockBadNode1.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + + // header is empty + headers2 := map[int64]*types.SignedHeader{ + 1: h1, + 2: h2, + } + vals2 := map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + } + mockBadNode2 := mockNodeFromHeadersAndVals(headers2, vals2) + mockBadNode2.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + + mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + lb1, _ := mockBadNode1.LightBlock(ctx, 2) + require.NotEqual(t, lb1.Hash(), l1.Hash()) + + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockFullNode, + []provider.Provider{mockBadNode1, mockBadNode2}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + // witness should have behaved properly -> no error + require.NoError(t, err) + assert.EqualValues(t, 2, len(c.Witnesses())) + + // witness behaves incorrectly -> removed from list, no error + l, err := c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(c.Witnesses())) + // light block should still be verified + assert.EqualValues(t, 2, l.Height) + + // remaining witnesses don't have light block -> error + _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) + if assert.Error(t, err) { + assert.Equal(t, light.ErrFailedHeaderCrossReferencing, err) + } + // witness does not have a light block -> left in the list + assert.EqualValues(t, 1, len(c.Witnesses())) + mockBadNode1.AssertExpectations(t) + mockBadNode2.AssertExpectations(t) + }) + t.Run("TrustedValidatorSet", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewTestingLogger(t) + + differentVals, _ := factory.ValidatorSet(t, 10, 100) + mockBadValSetNode := mockNodeFromHeadersAndVals( + map[int64]*types.SignedHeader{ + 1: h1, + // 3/3 signed, but validator set at height 2 below is invalid -> witness + // should be removed. + 2: keys.GenSignedHeaderLastBlockID(t, chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, + hash("app_hash2"), hash("cons_hash"), hash("results_hash"), + 0, len(keys), types.BlockID{Hash: h1.Hash()}), + }, + map[int64]*types.ValidatorSet{ + 1: vals, + 2: differentVals, + }) + mockFullNode := mockNodeFromHeadersAndVals( + map[int64]*types.SignedHeader{ + 1: h1, + 2: h2, + }, + map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + }) + + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockFullNode, + []provider.Provider{mockBadValSetNode, mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + require.NoError(t, err) + assert.Equal(t, 2, len(c.Witnesses())) + + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour).Add(1*time.Second)) + assert.NoError(t, err) + assert.Equal(t, 1, len(c.Witnesses())) + mockBadValSetNode.AssertExpectations(t) + mockFullNode.AssertExpectations(t) + }) + t.Run("PrunesHeadersAndValidatorSets", func(t *testing.T) { + mockFullNode := mockNodeFromHeadersAndVals( + map[int64]*types.SignedHeader{ + 1: h1, + 3: h3, + 0: h3, + }, + map[int64]*types.ValidatorSet{ + 1: vals, + 3: vals, + 0: vals, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := log.NewTestingLogger(t) + + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockFullNode, + []provider.Provider{mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + light.PruningSize(1), + ) + require.NoError(t, err) + _, err = c.TrustedLightBlock(1) + require.NoError(t, err) + + h, err := c.Update(ctx, bTime.Add(2*time.Hour)) + require.NoError(t, err) + require.Equal(t, int64(3), h.Height) + + _, err = c.TrustedLightBlock(1) + assert.Error(t, err) + mockFullNode.AssertExpectations(t) + }) + t.Run("EnsureValidHeadersAndValSets", func(t *testing.T) { + emptyValSet := &types.ValidatorSet{ + Validators: nil, + Proposer: nil, + } + + testCases := []struct { + headers map[int64]*types.SignedHeader + vals map[int64]*types.ValidatorSet + + errorToThrow error + errorHeight int64 + + err bool + }{ + { + headers: map[int64]*types.SignedHeader{ + 1: h1, + 3: h3, + }, + vals: map[int64]*types.ValidatorSet{ + 1: vals, + 3: vals, + }, + err: false, + }, + { + headers: map[int64]*types.SignedHeader{ + 1: h1, + }, + vals: map[int64]*types.ValidatorSet{ + 1: vals, + }, + errorToThrow: provider.ErrBadLightBlock{Reason: errors.New("nil header or vals")}, + errorHeight: 3, + err: true, + }, + { + headers: map[int64]*types.SignedHeader{ + 1: h1, + }, + errorToThrow: provider.ErrBadLightBlock{Reason: errors.New("nil header or vals")}, + errorHeight: 3, + vals: valSet, + err: true, + }, + { + headers: map[int64]*types.SignedHeader{ + 1: h1, + 3: h3, + }, + vals: map[int64]*types.ValidatorSet{ + 1: vals, + 3: emptyValSet, + }, + err: true, + }, + } + + for i, tc := range testCases { + testCase := tc + t.Run(fmt.Sprintf("case: %d", i), func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockBadNode := mockNodeFromHeadersAndVals(testCase.headers, testCase.vals) + if testCase.errorToThrow != nil { + mockBadNode.On("LightBlock", mock.Anything, testCase.errorHeight).Return(nil, testCase.errorToThrow) + } + + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockBadNode, + []provider.Provider{mockBadNode, mockBadNode}, + dbs.New(dbm.NewMemDB()), + ) + require.NoError(t, err) + + _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) + if testCase.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + mockBadNode.AssertExpectations(t) + }) + } + }) } diff --git a/light/detector_test.go b/light/detector_test.go index f61d7f116..84b6f210c 100644 --- a/light/detector_test.go +++ b/light/detector_test.go @@ -35,7 +35,7 @@ func TestLightClientAttackEvidence_Lunatic(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(chainID, latestHeight, valSize, 2, bTime) + witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(t, chainID, latestHeight, valSize, 2, bTime) forgedKeys := chainKeys[divergenceHeight-1].ChangeKeys(3) // we change 3 out of the 5 validators (still 2/5 remain) forgedVals := forgedKeys.ToValidators(2, 0) @@ -46,7 +46,7 @@ func TestLightClientAttackEvidence_Lunatic(t *testing.T) { primaryValidators[height] = witnessValidators[height] continue } - primaryHeaders[height] = forgedKeys.GenSignedHeader(chainID, height, bTime.Add(time.Duration(height)*time.Minute), + primaryHeaders[height] = forgedKeys.GenSignedHeader(t, chainID, height, bTime.Add(time.Duration(height)*time.Minute), nil, forgedVals, forgedVals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(forgedKeys)) primaryValidators[height] = forgedVals } @@ -152,7 +152,7 @@ func TestLightClientAttackEvidence_Equivocation(t *testing.T) { // validators don't change in this network (however we still use a map just for convenience) primaryValidators = make(map[int64]*types.ValidatorSet, testCase.latestHeight) ) - witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(chainID, + witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(t, chainID, testCase.latestHeight+1, valSize, 2, bTime) for height := int64(1); height <= testCase.latestHeight; height++ { if height < testCase.divergenceHeight { @@ -162,7 +162,7 @@ func TestLightClientAttackEvidence_Equivocation(t *testing.T) { } // we don't have a network partition so we will make 4/5 (greater than 2/3) malicious and vote again for // a different block (which we do by adding txs) - primaryHeaders[height] = chainKeys[height].GenSignedHeader(chainID, height, + primaryHeaders[height] = chainKeys[height].GenSignedHeader(t, chainID, height, bTime.Add(time.Duration(height)*time.Minute), []types.Tx{[]byte("abcd")}, witnessValidators[height], witnessValidators[height+1], hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(chainKeys[height])-1) @@ -246,7 +246,7 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { defer cancel() logger := log.NewTestingLogger(t) - witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(chainID, latestHeight, valSize, 2, bTime) + witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(t, chainID, latestHeight, valSize, 2, bTime) for _, unusedHeader := range []int64{3, 5, 6, 8} { delete(witnessHeaders, unusedHeader) } @@ -262,7 +262,7 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { } forgedKeys := chainKeys[latestHeight].ChangeKeys(3) // we change 3 out of the 5 validators (still 2/5 remain) primaryValidators[forgedHeight] = forgedKeys.ToValidators(2, 0) - primaryHeaders[forgedHeight] = forgedKeys.GenSignedHeader( + primaryHeaders[forgedHeight] = forgedKeys.GenSignedHeader(t, chainID, forgedHeight, bTime.Add(time.Duration(latestHeight+1)*time.Minute), // 11 mins @@ -326,7 +326,7 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { // to prove that there was an attack vals := chainKeys[latestHeight].ToValidators(2, 0) newLb := &types.LightBlock{ - SignedHeader: chainKeys[latestHeight].GenSignedHeader( + SignedHeader: chainKeys[latestHeight].GenSignedHeader(t, chainID, proofHeight, bTime.Add(time.Duration(proofHeight+1)*time.Minute), // 12 mins @@ -395,11 +395,11 @@ func TestClientDivergentTraces1(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - headers, vals, _ := genLightBlocksWithKeys(chainID, 1, 5, 2, bTime) + headers, vals, _ := genLightBlocksWithKeys(t, chainID, 1, 5, 2, bTime) mockPrimary := mockNodeFromHeadersAndVals(headers, vals) firstBlock, err := mockPrimary.LightBlock(ctx, 1) require.NoError(t, err) - headers, vals, _ = genLightBlocksWithKeys(chainID, 1, 5, 2, bTime) + headers, vals, _ = genLightBlocksWithKeys(t, chainID, 1, 5, 2, bTime) mockWitness := mockNodeFromHeadersAndVals(headers, vals) logger := log.NewTestingLogger(t) @@ -430,7 +430,7 @@ func TestClientDivergentTraces2(t *testing.T) { defer cancel() logger := log.NewTestingLogger(t) - headers, vals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + headers, vals, _ := genLightBlocksWithKeys(t, chainID, 2, 5, 2, bTime) mockPrimaryNode := mockNodeFromHeadersAndVals(headers, vals) mockDeadNode := &provider_mocks.Provider{} mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrNoResponse) @@ -465,7 +465,7 @@ func TestClientDivergentTraces3(t *testing.T) { logger := log.NewTestingLogger(t) // - primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(t, chainID, 2, 5, 2, bTime) mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) ctx, cancel := context.WithCancel(context.Background()) @@ -474,7 +474,7 @@ func TestClientDivergentTraces3(t *testing.T) { firstBlock, err := mockPrimary.LightBlock(ctx, 1) require.NoError(t, err) - mockHeaders, mockVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + mockHeaders, mockVals, _ := genLightBlocksWithKeys(t, chainID, 2, 5, 2, bTime) mockHeaders[1] = primaryHeaders[1] mockVals[1] = primaryVals[1] mockWitness := mockNodeFromHeadersAndVals(mockHeaders, mockVals) @@ -508,7 +508,7 @@ func TestClientDivergentTraces4(t *testing.T) { logger := log.NewTestingLogger(t) // - primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(t, chainID, 2, 5, 2, bTime) mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) ctx, cancel := context.WithCancel(context.Background()) @@ -517,7 +517,7 @@ func TestClientDivergentTraces4(t *testing.T) { firstBlock, err := mockPrimary.LightBlock(ctx, 1) require.NoError(t, err) - witnessHeaders, witnessVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + witnessHeaders, witnessVals, _ := genLightBlocksWithKeys(t, chainID, 2, 5, 2, bTime) primaryHeaders[2] = witnessHeaders[2] primaryVals[2] = witnessVals[2] mockWitness := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) diff --git a/light/helpers_test.go b/light/helpers_test.go index 1d25f9166..9f6147526 100644 --- a/light/helpers_test.go +++ b/light/helpers_test.go @@ -1,9 +1,11 @@ package light_test import ( + "testing" "time" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/crypto/tmhash" @@ -74,7 +76,9 @@ func (pkz privKeys) ToValidators(init, inc int64) *types.ValidatorSet { } // signHeader properly signs the header with all keys from first to last exclusive. -func (pkz privKeys) signHeader(header *types.Header, valSet *types.ValidatorSet, first, last int) *types.Commit { +func (pkz privKeys) signHeader(t testing.TB, header *types.Header, valSet *types.ValidatorSet, first, last int) *types.Commit { + t.Helper() + commitSigs := make([]types.CommitSig, len(pkz)) for i := 0; i < len(pkz); i++ { commitSigs[i] = types.NewCommitSigAbsent() @@ -87,15 +91,15 @@ func (pkz privKeys) signHeader(header *types.Header, valSet *types.ValidatorSet, // Fill in the votes we want. for i := first; i < last && i < len(pkz); i++ { - vote := makeVote(header, valSet, pkz[i], blockID) + vote := makeVote(t, header, valSet, pkz[i], blockID) commitSigs[vote.ValidatorIndex] = vote.CommitSig() } return types.NewCommit(header.Height, 1, blockID, commitSigs) } -func makeVote(header *types.Header, valset *types.ValidatorSet, - key crypto.PrivKey, blockID types.BlockID) *types.Vote { +func makeVote(t testing.TB, header *types.Header, valset *types.ValidatorSet, key crypto.PrivKey, blockID types.BlockID) *types.Vote { + t.Helper() addr := key.PubKey().Address() idx, _ := valset.GetByAddress(addr) @@ -113,9 +117,7 @@ func makeVote(header *types.Header, valset *types.ValidatorSet, // Sign it signBytes := types.VoteSignBytes(header.ChainID, v) sig, err := key.Sign(signBytes) - if err != nil { - panic(err) - } + require.NoError(t, err) vote.Signature = sig @@ -143,26 +145,30 @@ func genHeader(chainID string, height int64, bTime time.Time, txs types.Txs, } // GenSignedHeader calls genHeader and signHeader and combines them into a SignedHeader. -func (pkz privKeys) GenSignedHeader(chainID string, height int64, bTime time.Time, txs types.Txs, +func (pkz privKeys) GenSignedHeader(t testing.TB, chainID string, height int64, bTime time.Time, txs types.Txs, valset, nextValset *types.ValidatorSet, appHash, consHash, resHash []byte, first, last int) *types.SignedHeader { + t.Helper() + header := genHeader(chainID, height, bTime, txs, valset, nextValset, appHash, consHash, resHash) return &types.SignedHeader{ Header: header, - Commit: pkz.signHeader(header, valset, first, last), + Commit: pkz.signHeader(t, header, valset, first, last), } } // GenSignedHeaderLastBlockID calls genHeader and signHeader and combines them into a SignedHeader. -func (pkz privKeys) GenSignedHeaderLastBlockID(chainID string, height int64, bTime time.Time, txs types.Txs, +func (pkz privKeys) GenSignedHeaderLastBlockID(t testing.TB, chainID string, height int64, bTime time.Time, txs types.Txs, valset, nextValset *types.ValidatorSet, appHash, consHash, resHash []byte, first, last int, lastBlockID types.BlockID) *types.SignedHeader { + t.Helper() + header := genHeader(chainID, height, bTime, txs, valset, nextValset, appHash, consHash, resHash) header.LastBlockID = lastBlockID return &types.SignedHeader{ Header: header, - Commit: pkz.signHeader(header, valset, first, last), + Commit: pkz.signHeader(t, header, valset, first, last), } } @@ -175,14 +181,14 @@ func (pkz privKeys) ChangeKeys(delta int) privKeys { // blocks to height. BlockIntervals are in per minute. // NOTE: Expected to have a large validator set size ~ 100 validators. func genLightBlocksWithKeys( + t testing.TB, chainID string, numBlocks int64, valSize int, valVariation float32, - bTime time.Time) ( - map[int64]*types.SignedHeader, - map[int64]*types.ValidatorSet, - map[int64]privKeys) { + bTime time.Time, +) (map[int64]*types.SignedHeader, map[int64]*types.ValidatorSet, map[int64]privKeys) { + t.Helper() var ( headers = make(map[int64]*types.SignedHeader, numBlocks) @@ -201,7 +207,7 @@ func genLightBlocksWithKeys( keymap[2] = newKeys // genesis header and vals - lastHeader := keys.GenSignedHeader(chainID, 1, bTime.Add(1*time.Minute), nil, + lastHeader := keys.GenSignedHeader(t, chainID, 1, bTime.Add(1*time.Minute), nil, keys.ToValidators(2, 0), newKeys.ToValidators(2, 0), hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) currentHeader := lastHeader @@ -214,7 +220,7 @@ func genLightBlocksWithKeys( valVariationInt = int(totalVariation) totalVariation = -float32(valVariationInt) newKeys = keys.ChangeKeys(valVariationInt) - currentHeader = keys.GenSignedHeaderLastBlockID(chainID, height, bTime.Add(time.Duration(height)*time.Minute), + currentHeader = keys.GenSignedHeaderLastBlockID(t, chainID, height, bTime.Add(time.Duration(height)*time.Minute), nil, keys.ToValidators(2, 0), newKeys.ToValidators(2, 0), hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: lastHeader.Hash()}) diff --git a/light/proxy/routes.go b/light/proxy/routes.go index 4561bf7f9..76cc52f73 100644 --- a/light/proxy/routes.go +++ b/light/proxy/routes.go @@ -16,9 +16,7 @@ type proxyService struct { *lrpc.Client } -func (p proxyService) ABCIQuery(ctx context.Context, path string, data tmbytes.HexBytes, - height int64, prove bool) (*coretypes.ResultABCIQuery, error) { - +func (p proxyService) ABCIQuery(ctx context.Context, path string, data tmbytes.HexBytes, height int64, prove bool) (*coretypes.ResultABCIQuery, error) { return p.ABCIQueryWithOptions(ctx, path, data, rpcclient.ABCIQueryOptions{ Height: height, Prove: prove, diff --git a/light/rpc/client.go b/light/rpc/client.go index 41ed97401..fec6e4723 100644 --- a/light/rpc/client.go +++ b/light/rpc/client.go @@ -565,7 +565,7 @@ func (c *Client) Validators( } skipCount := validateSkipCount(page, perPage) - v := l.ValidatorSet.Validators[skipCount : skipCount+tmmath.MinInt(perPage, totalCount-skipCount)] + v := l.ValidatorSet.Validators[skipCount : skipCount+tmmath.MinInt(int(perPage), totalCount-skipCount)] return &coretypes.ResultValidators{ BlockHeight: l.Height, @@ -672,16 +672,13 @@ const ( maxPerPage = 100 ) -func validatePage(pagePtr *int, perPage, totalCount int) (int, error) { - if perPage < 1 { - panic(fmt.Errorf("%w (%d)", coretypes.ErrZeroOrNegativePerPage, perPage)) - } +func validatePage(pagePtr *int, perPage uint, totalCount int) (int, error) { if pagePtr == nil { // no page parameter return 1, nil } - pages := ((totalCount - 1) / perPage) + 1 + pages := ((totalCount - 1) / int(perPage)) + 1 if pages == 0 { pages = 1 // one page (even if it's empty) } @@ -693,7 +690,7 @@ func validatePage(pagePtr *int, perPage, totalCount int) (int, error) { return page, nil } -func validatePerPage(perPagePtr *int) int { +func validatePerPage(perPagePtr *int) uint { if perPagePtr == nil { // no per_page parameter return defaultPerPage } @@ -704,11 +701,11 @@ func validatePerPage(perPagePtr *int) int { } else if perPage > maxPerPage { return maxPerPage } - return perPage + return uint(perPage) } -func validateSkipCount(page, perPage int) int { - skipCount := (page - 1) * perPage +func validateSkipCount(page int, perPage uint) int { + skipCount := (page - 1) * int(perPage) if skipCount < 0 { return 0 } diff --git a/light/verifier.go b/light/verifier.go index ee4bfb053..f6156c5de 100644 --- a/light/verifier.go +++ b/light/verifier.go @@ -38,9 +38,12 @@ func VerifyNonAdjacent( trustingPeriod time.Duration, now time.Time, maxClockDrift time.Duration, - trustLevel tmmath.Fraction) error { + trustLevel tmmath.Fraction, +) error { - checkRequiredHeaderFields(trustedHeader) + if err := checkRequiredHeaderFields(trustedHeader); err != nil { + return err + } if untrustedHeader.Height == trustedHeader.Height+1 { return errors.New("headers must be non adjacent in height") @@ -106,12 +109,15 @@ func VerifyAdjacent( untrustedVals *types.ValidatorSet, // height=X+1 trustingPeriod time.Duration, now time.Time, - maxClockDrift time.Duration) error { + maxClockDrift time.Duration, +) error { - checkRequiredHeaderFields(trustedHeader) + if err := checkRequiredHeaderFields(trustedHeader); err != nil { + return err + } if len(trustedHeader.NextValidatorsHash) == 0 { - panic("next validators hash in trusted header is empty") + return errors.New("next validators hash in trusted header is empty") } if untrustedHeader.Height != trustedHeader.Height+1 { @@ -268,17 +274,18 @@ func verifyNewHeaderAndVals( return nil } -func checkRequiredHeaderFields(h *types.SignedHeader) { +func checkRequiredHeaderFields(h *types.SignedHeader) error { if h.Height == 0 { - panic("height in trusted header must be set (non zero") + return errors.New("height in trusted header must be set (non zero") } zeroTime := time.Time{} if h.Time == zeroTime { - panic("time in trusted header must be set") + return errors.New("time in trusted header must be set") } if h.ChainID == "" { - panic("chain ID in trusted header must be set") + return errors.New("chain ID in trusted header must be set") } + return nil } diff --git a/light/verifier_test.go b/light/verifier_test.go index 0432c130d..5a2019e21 100644 --- a/light/verifier_test.go +++ b/light/verifier_test.go @@ -28,7 +28,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { // 20, 30, 40, 50 - the first 3 don't have 2/3, the last 3 do! vals = keys.ToValidators(20, 10) bTime, _ = time.Parse(time.RFC3339, "2006-01-02T15:04:05Z") - header = keys.GenSignedHeader(chainID, lastHeight, bTime, nil, vals, vals, + header = keys.GenSignedHeader(t, chainID, lastHeight, bTime, nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) ) @@ -51,7 +51,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // different chainID -> error 1: { - keys.GenSignedHeader("different-chainID", nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, "different-chainID", nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, 3 * time.Hour, @@ -61,7 +61,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // new header's time is before old header's time -> error 2: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(-1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(-1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, 4 * time.Hour, @@ -71,7 +71,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // new header's time is from the future -> error 3: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(3*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(3*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, 3 * time.Hour, @@ -81,7 +81,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // new header's time is from the future, but it's acceptable (< maxClockDrift) -> no error 4: { - keys.GenSignedHeader(chainID, nextHeight, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(2*time.Hour).Add(maxClockDrift).Add(-1*time.Millisecond), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, @@ -92,7 +92,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // 3/3 signed -> no error 5: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, 3 * time.Hour, @@ -102,7 +102,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // 2/3 signed -> no error 6: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 1, len(keys)), vals, 3 * time.Hour, @@ -112,7 +112,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // 1/3 signed -> error 7: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), vals, 3 * time.Hour, @@ -122,7 +122,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // vals does not match with what we have -> error 8: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, keys.ToValidators(10, 1), vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, keys.ToValidators(10, 1), vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), keys.ToValidators(10, 1), 3 * time.Hour, @@ -132,7 +132,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // vals are inconsistent with newHeader -> error 9: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), keys.ToValidators(10, 1), 3 * time.Hour, @@ -142,7 +142,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // old header has expired -> error 10: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), keys.ToValidators(10, 1), 1 * time.Hour, @@ -180,7 +180,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { // 20, 30, 40, 50 - the first 3 don't have 2/3, the last 3 do! vals = keys.ToValidators(20, 10) bTime, _ = time.Parse(time.RFC3339, "2006-01-02T15:04:05Z") - header = keys.GenSignedHeader(chainID, lastHeight, bTime, nil, vals, vals, + header = keys.GenSignedHeader(t, chainID, lastHeight, bTime, nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) // 30, 40, 50 @@ -206,7 +206,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }{ // 3/3 new vals signed, 3/3 old vals present -> no error 0: { - keys.GenSignedHeader(chainID, 3, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, 3, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, 3 * time.Hour, @@ -216,7 +216,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }, // 2/3 new vals signed, 3/3 old vals present -> no error 1: { - keys.GenSignedHeader(chainID, 4, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, 4, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 1, len(keys)), vals, 3 * time.Hour, @@ -226,7 +226,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }, // 1/3 new vals signed, 3/3 old vals present -> error 2: { - keys.GenSignedHeader(chainID, 5, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, 5, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), vals, 3 * time.Hour, @@ -236,7 +236,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }, // 3/3 new vals signed, 2/3 old vals present -> no error 3: { - twoThirds.GenSignedHeader(chainID, 5, bTime.Add(1*time.Hour), nil, twoThirdsVals, twoThirdsVals, + twoThirds.GenSignedHeader(t, chainID, 5, bTime.Add(1*time.Hour), nil, twoThirdsVals, twoThirdsVals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(twoThirds)), twoThirdsVals, 3 * time.Hour, @@ -246,7 +246,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }, // 3/3 new vals signed, 1/3 old vals present -> no error 4: { - oneThird.GenSignedHeader(chainID, 5, bTime.Add(1*time.Hour), nil, oneThirdVals, oneThirdVals, + oneThird.GenSignedHeader(t, chainID, 5, bTime.Add(1*time.Hour), nil, oneThirdVals, oneThirdVals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(oneThird)), oneThirdVals, 3 * time.Hour, @@ -256,7 +256,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }, // 3/3 new vals signed, less than 1/3 old vals present -> error 5: { - lessThanOneThird.GenSignedHeader(chainID, 5, bTime.Add(1*time.Hour), nil, lessThanOneThirdVals, lessThanOneThirdVals, + lessThanOneThird.GenSignedHeader(t, chainID, 5, bTime.Add(1*time.Hour), nil, lessThanOneThirdVals, lessThanOneThirdVals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(lessThanOneThird)), lessThanOneThirdVals, 3 * time.Hour, @@ -296,7 +296,7 @@ func TestVerifyReturnsErrorIfTrustLevelIsInvalid(t *testing.T) { // 20, 30, 40, 50 - the first 3 don't have 2/3, the last 3 do! vals = keys.ToValidators(20, 10) bTime, _ = time.Parse(time.RFC3339, "2006-01-02T15:04:05Z") - header = keys.GenSignedHeader(chainID, lastHeight, bTime, nil, vals, vals, + header = keys.GenSignedHeader(t, chainID, lastHeight, bTime, nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) ) diff --git a/node/node_test.go b/node/node_test.go index 76bbd2966..74385ec33 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -47,7 +47,7 @@ func TestNodeStartStop(t *testing.T) { ctx, bcancel := context.WithCancel(context.Background()) defer bcancel() - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() // create & start node ns, err := newDefaultNode(ctx, cfg, logger) require.NoError(t, err) @@ -98,6 +98,7 @@ func getTestNode(ctx context.Context, t *testing.T, conf *config.Config, logger ns.Wait() } }) + t.Cleanup(leaktest.CheckTimeout(t, time.Second)) return n } @@ -112,7 +113,7 @@ func TestNodeDelayedStart(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() // create & start node n := getTestNode(ctx, t, cfg, logger) @@ -132,7 +133,7 @@ func TestNodeSetAppVersion(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() // create node n := getTestNode(ctx, t, cfg, logger) @@ -156,7 +157,7 @@ func TestNodeSetPrivValTCP(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() cfg, err := config.ResetTestRoot("node_priv_val_tcp_test") require.NoError(t, err) @@ -200,7 +201,7 @@ func TestPrivValidatorListenAddrNoProtocol(t *testing.T) { defer os.RemoveAll(cfg.RootDir) cfg.PrivValidator.ListenAddr = addrNoPrefix - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() n, err := newDefaultNode(ctx, cfg, logger) @@ -224,7 +225,7 @@ func TestNodeSetPrivValIPC(t *testing.T) { defer os.RemoveAll(cfg.RootDir) cfg.PrivValidator.ListenAddr = "unix://" + tmpfile - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() dialer := privval.DialUnixFn(tmpfile) dialerEndpoint := privval.NewSignerDialerEndpoint(logger, dialer) @@ -270,7 +271,7 @@ func TestCreateProposalBlock(t *testing.T) { require.NoError(t, err) defer os.RemoveAll(cfg.RootDir) - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() cc := abciclient.NewLocalCreator(kvstore.NewApplication()) proxyApp := proxy.NewAppConns(cc, logger, proxy.NopMetrics()) @@ -367,7 +368,7 @@ func TestMaxTxsProposalBlockSize(t *testing.T) { defer os.RemoveAll(cfg.RootDir) - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() cc := abciclient.NewLocalCreator(kvstore.NewApplication()) proxyApp := proxy.NewAppConns(cc, logger, proxy.NopMetrics()) @@ -433,7 +434,7 @@ func TestMaxProposalBlockSize(t *testing.T) { require.NoError(t, err) defer os.RemoveAll(cfg.RootDir) - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() cc := abciclient.NewLocalCreator(kvstore.NewApplication()) proxyApp := proxy.NewAppConns(cc, logger, proxy.NopMetrics()) @@ -554,7 +555,7 @@ func TestNodeNewSeedNode(t *testing.T) { nodeKey, err := types.LoadOrGenNodeKey(cfg.NodeKeyFile()) require.NoError(t, err) - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() ns, err := makeSeedNode(ctx, cfg, @@ -588,7 +589,7 @@ func TestNodeSetEventSink(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := log.NewTestingLogger(t) + logger := log.NewNopLogger() setupTest := func(t *testing.T, conf *config.Config) []indexer.EventSink { eventBus := eventbus.NewDefault(logger.With("module", "events")) diff --git a/rpc/jsonrpc/server/http_uri_handler.go b/rpc/jsonrpc/server/http_uri_handler.go index 808da16e2..0728b1f08 100644 --- a/rpc/jsonrpc/server/http_uri_handler.go +++ b/rpc/jsonrpc/server/http_uri_handler.go @@ -1,15 +1,16 @@ package server import ( + "context" "encoding/hex" + "encoding/json" "errors" "fmt" "net/http" "reflect" - "regexp" + "strconv" "strings" - tmjson "github.com/tendermint/tendermint/libs/json" "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/rpc/coretypes" rpctypes "github.com/tendermint/tendermint/rpc/jsonrpc/types" @@ -17,8 +18,6 @@ import ( // HTTP + URI handler -var reInt = regexp.MustCompile(`^-?[0-9]+$`) - // convert from a function name to the http handler func makeHTTPHandler(rpcFunc *RPCFunc, logger log.Logger) func(http.ResponseWriter, *http.Request) { // Always return -1 as there's no ID here. @@ -35,22 +34,21 @@ func makeHTTPHandler(rpcFunc *RPCFunc, logger log.Logger) func(http.ResponseWrit } // All other endpoints - return func(w http.ResponseWriter, r *http.Request) { - ctx := rpctypes.WithCallInfo(r.Context(), &rpctypes.CallInfo{HTTPRequest: r}) - args := []reflect.Value{reflect.ValueOf(ctx)} - - fnArgs, err := httpParamsToArgs(rpcFunc, r) + return func(w http.ResponseWriter, req *http.Request) { + ctx := rpctypes.WithCallInfo(req.Context(), &rpctypes.CallInfo{ + HTTPRequest: req, + }) + args, err := parseURLParams(ctx, rpcFunc, req) if err != nil { - writeHTTPResponse(w, logger, rpctypes.RPCInvalidParamsError( - dummyID, fmt.Errorf("error converting http params to arguments: %w", err))) + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintln(w, err.Error()) return } - args = append(args, fnArgs...) + outs := rpcFunc.f.Call(args) - returns := rpcFunc.f.Call(args) - - logger.Debug("HTTPRestRPC", "method", r.URL.Path, "args", args, "returns", returns) - result, err := unreflectResult(returns) + logger.Debug("HTTPRestRPC", "method", req.URL.Path, "args", args, "returns", outs) + result, err := unreflectResult(outs) switch e := err.(type) { // if no error then return a success response case nil: @@ -74,142 +72,135 @@ func makeHTTPHandler(rpcFunc *RPCFunc, logger log.Logger) func(http.ResponseWrit } } -// Covert an http query to a list of properly typed values. -// To be properly decoded the arg must be a concrete type from tendermint (if its an interface). -func httpParamsToArgs(rpcFunc *RPCFunc, r *http.Request) ([]reflect.Value, error) { - // skip types.Context - const argsOffset = 1 +func parseURLParams(ctx context.Context, rf *RPCFunc, req *http.Request) ([]reflect.Value, error) { + if err := req.ParseForm(); err != nil { + return nil, fmt.Errorf("invalid HTTP request: %w", err) + } + getArg := func(name string) (string, bool) { + if req.Form.Has(name) { + return req.Form.Get(name), true + } + return "", false + } - values := make([]reflect.Value, len(rpcFunc.argNames)) + vals := make([]reflect.Value, len(rf.argNames)+1) + vals[0] = reflect.ValueOf(ctx) + for i, name := range rf.argNames { + atype := rf.args[i+1] - for i, name := range rpcFunc.argNames { - argType := rpcFunc.args[i+argsOffset] - - values[i] = reflect.Zero(argType) // set default for that type - - arg := getParam(r, name) - // log.Notice("param to arg", "argType", argType, "name", name, "arg", arg) - - if arg == "" { + text, ok := getArg(name) + if !ok { + vals[i+1] = reflect.Zero(atype) continue } - v, ok, err := nonJSONStringToArg(argType, arg) + val, err := parseArgValue(atype, text) if err != nil { - return nil, err - } - if ok { - values[i] = v - continue - } - - values[i], err = jsonStringToArg(argType, arg) - if err != nil { - return nil, err + return nil, fmt.Errorf("decoding parameter %q: %w", name, err) } + vals[i+1] = val } - - return values, nil + return vals, nil } -func jsonStringToArg(rt reflect.Type, arg string) (reflect.Value, error) { - rv := reflect.New(rt) - err := tmjson.Unmarshal([]byte(arg), rv.Interface()) - if err != nil { - return rv, err - } - rv = rv.Elem() - return rv, nil -} - -func nonJSONStringToArg(rt reflect.Type, arg string) (reflect.Value, bool, error) { - if rt.Kind() == reflect.Ptr { - rv1, ok, err := nonJSONStringToArg(rt.Elem(), arg) - switch { - case err != nil: - return reflect.Value{}, false, err - case ok: - rv := reflect.New(rt.Elem()) - rv.Elem().Set(rv1) - return rv, true, nil - default: - return reflect.Value{}, false, nil - } +func parseArgValue(atype reflect.Type, text string) (reflect.Value, error) { + // Regardless whether the argument is a pointer type, allocate a pointer so + // we can set the computed value. + var out reflect.Value + isPtr := atype.Kind() == reflect.Ptr + if isPtr { + out = reflect.New(atype.Elem()) } else { - return _nonJSONStringToArg(rt, arg) + out = reflect.New(atype) + } + + baseType := out.Type().Elem() + if isIntType(baseType) { + // Integral type: Require a base-10 digit string. For compatibility with + // existing use allow quotation marks. + v, err := decodeInteger(text) + if err != nil { + return reflect.Value{}, fmt.Errorf("invalid integer: %w", err) + } + out.Elem().Set(reflect.ValueOf(v).Convert(baseType)) + } else if isStringOrBytes(baseType) { + // String or byte slice: Check for quotes, hex encoding. + dec, err := decodeString(text) + if err != nil { + return reflect.Value{}, err + } + out.Elem().Set(reflect.ValueOf(dec).Convert(baseType)) + + } else if baseType.Kind() == reflect.Bool { + b, err := strconv.ParseBool(text) + if err != nil { + return reflect.Value{}, fmt.Errorf("invalid boolean: %w", err) + } + out.Elem().Set(reflect.ValueOf(b)) + + } else { + // We don't know how to represent other types. + return reflect.Value{}, fmt.Errorf("unsupported argument type %v", baseType) + } + + // If the argument wants a pointer, return the value as-is, otherwise + // indirect the pointer back off. + if isPtr { + return out, nil + } + return out.Elem(), nil +} + +var uint64Type = reflect.TypeOf(uint64(0)) + +// isIntType reports whether atype is an integer-shaped type. +func isIntType(atype reflect.Type) bool { + switch atype.Kind() { + case reflect.Float32, reflect.Float64: + return false + default: + return atype.ConvertibleTo(uint64Type) } } -// NOTE: rt.Kind() isn't a pointer. -func _nonJSONStringToArg(rt reflect.Type, arg string) (reflect.Value, bool, error) { - isIntString := reInt.Match([]byte(arg)) - isQuotedString := strings.HasPrefix(arg, `"`) && strings.HasSuffix(arg, `"`) - isHexString := strings.HasPrefix(strings.ToLower(arg), "0x") - - var expectingString, expectingByteSlice, expectingInt bool - switch rt.Kind() { - case reflect.Int, - reflect.Uint, - reflect.Int8, - reflect.Uint8, - reflect.Int16, - reflect.Uint16, - reflect.Int32, - reflect.Uint32, - reflect.Int64, - reflect.Uint64: - expectingInt = true +// isStringOrBytes reports whether atype is a string or []byte. +func isStringOrBytes(atype reflect.Type) bool { + switch atype.Kind() { case reflect.String: - expectingString = true + return true case reflect.Slice: - expectingByteSlice = rt.Elem().Kind() == reflect.Uint8 + return atype.Elem().Kind() == reflect.Uint8 + default: + return false } - - if isIntString && expectingInt { - qarg := `"` + arg + `"` - rv, err := jsonStringToArg(rt, qarg) - if err != nil { - return rv, false, err - } - - return rv, true, nil - } - - if isHexString { - if !expectingString && !expectingByteSlice { - err := fmt.Errorf("got a hex string arg, but expected '%s'", - rt.Kind().String()) - return reflect.ValueOf(nil), false, err - } - - var value []byte - value, err := hex.DecodeString(arg[2:]) - if err != nil { - return reflect.ValueOf(nil), false, err - } - if rt.Kind() == reflect.String { - return reflect.ValueOf(string(value)), true, nil - } - return reflect.ValueOf(value), true, nil - } - - if isQuotedString && expectingByteSlice { - v := reflect.New(reflect.TypeOf("")) - err := tmjson.Unmarshal([]byte(arg), v.Interface()) - if err != nil { - return reflect.ValueOf(nil), false, err - } - v = v.Elem() - return reflect.ValueOf([]byte(v.String())), true, nil - } - - return reflect.ValueOf(nil), false, nil } -func getParam(r *http.Request, param string) string { - s := r.URL.Query().Get(param) - if s == "" { - s = r.FormValue(param) - } - return s +// isQuotedString reports whether s is enclosed in double quotes. +func isQuotedString(s string) bool { + return len(s) >= 2 && strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`) +} + +// decodeInteger decodes s into an int64. If s is "double quoted" the quotes +// are removed; otherwise s must be a base-10 digit string. +func decodeInteger(s string) (int64, error) { + if isQuotedString(s) { + s = s[1 : len(s)-1] + } + return strconv.ParseInt(s, 10, 64) +} + +// decodeString decodes s into a byte slice. If s has an 0x prefix, it is +// treated as a hex-encoded string. If it is "double quoted" it is treated as a +// JSON string value. Otherwise, s is converted to bytes directly. +func decodeString(s string) ([]byte, error) { + if lc := strings.ToLower(s); strings.HasPrefix(lc, "0x") { + return hex.DecodeString(lc[2:]) + } else if isQuotedString(s) { + var dec string + if err := json.Unmarshal([]byte(s), &dec); err != nil { + return nil, fmt.Errorf("invalid quoted string: %w", err) + } + return []byte(dec), nil + } + return []byte(s), nil } diff --git a/rpc/jsonrpc/server/parse_test.go b/rpc/jsonrpc/server/parse_test.go index 6533a5d44..9b222b507 100644 --- a/rpc/jsonrpc/server/parse_test.go +++ b/rpc/jsonrpc/server/parse_test.go @@ -187,8 +187,15 @@ func TestParseURI(t *testing.T) { // can parse numbers quoted, too {[]string{`"7"`, `"flew"`}, 7, "flew", false}, {[]string{`"-10"`, `"bob"`}, -10, "bob", false}, - // cant parse strings uquoted - {[]string{`"-10"`, `bob`}, -10, "bob", true}, + // can parse strings hex-escaped, in either case + {[]string{`-9`, `0x626f62`}, -9, "bob", false}, + {[]string{`-9`, `0X646F7567`}, -9, "doug", false}, + // can parse strings unquoted (as per OpenAPI docs) + {[]string{`0`, `hey you`}, 0, "hey you", false}, + // fail for invalid numbers, strings, hex + {[]string{`"-xx"`, `bob`}, 0, "", true}, // bad number + {[]string{`"95""`, `"bob`}, 0, "", true}, // bad string + {[]string{`15`, `0xa`}, 0, "", true}, // bad hex } for idx, tc := range cases { i := strconv.Itoa(idx) @@ -198,14 +205,14 @@ func TestParseURI(t *testing.T) { tc.raw[0], tc.raw[1]) req, err := http.NewRequest("GET", url, nil) assert.NoError(t, err) - vals, err := httpParamsToArgs(call, req) + vals, err := parseURLParams(context.Background(), call, req) if tc.fail { assert.Error(t, err, i) } else { assert.NoError(t, err, "%s: %+v", i, err) - if assert.Equal(t, 2, len(vals), i) { - assert.Equal(t, tc.height, vals[0].Int(), i) - assert.Equal(t, tc.name, vals[1].String(), i) + if assert.Equal(t, 3, len(vals), i) { + assert.Equal(t, tc.height, vals[1].Int(), i) + assert.Equal(t, tc.name, vals[2].String(), i) } } diff --git a/test/e2e/docker/Dockerfile b/test/e2e/docker/Dockerfile index 260df23f3..4e19fe9f8 100644 --- a/test/e2e/docker/Dockerfile +++ b/test/e2e/docker/Dockerfile @@ -1,7 +1,7 @@ # We need to build in a Linux environment to support C libraries, e.g. RocksDB. # We use Debian instead of Alpine, so that we can use binary database packages # instead of spending time compiling them. -FROM golang:1.16 +FROM golang:1.17 RUN apt-get -qq update -y && apt-get -qq upgrade -y >/dev/null RUN apt-get -qq install -y libleveldb-dev librocksdb-dev >/dev/null diff --git a/test/fuzz/rpc/jsonrpc/server/handler.go b/test/fuzz/rpc/jsonrpc/server/handler.go index e6dc577d7..92b02329f 100644 --- a/test/fuzz/rpc/jsonrpc/server/handler.go +++ b/test/fuzz/rpc/jsonrpc/server/handler.go @@ -2,6 +2,7 @@ package server import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -13,7 +14,9 @@ import ( ) var rpcFuncMap = map[string]*rs.RPCFunc{ - "c": rs.NewRPCFunc(func(s string, i int) (string, int) { return "foo", 200 }, "s", "i"), + "c": rs.NewRPCFunc(func(ctx context.Context, s string, i int) (string, error) { + return "foo", nil + }, "s", "i"), } var mux *http.ServeMux