cmd: add integration test and fix bug in rollback command (#7315)

This commit is contained in:
Callum Waters
2021-12-02 12:17:16 +01:00
committed by GitHub
parent 5f57d84dd3
commit bca2080c01
13 changed files with 199 additions and 45 deletions

View File

@@ -80,7 +80,7 @@ func DefaultConfig(dir string) *Config {
// NewApplication creates the application.
func NewApplication(cfg *Config) (*Application, error) {
state, err := NewState(filepath.Join(cfg.Dir, "state.json"), cfg.PersistInterval)
state, err := NewState(cfg.Dir, cfg.PersistInterval)
if err != nil {
return nil, err
}
@@ -267,6 +267,10 @@ func (app *Application) ApplySnapshotChunk(req abci.RequestApplySnapshotChunk) a
return abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}
}
func (app *Application) Rollback() error {
return app.state.Rollback()
}
// validatorUpdates generates a validator set update.
func (app *Application) validatorUpdates(height uint64) (abci.ValidatorUpdates, error) {
updates := app.cfg.ValidatorUpdates[fmt.Sprintf("%v", height)]

View File

@@ -7,10 +7,14 @@ import (
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"sync"
)
const stateFileName = "app_state.json"
const prevStateFileName = "prev_app_state.json"
// State is the application state.
type State struct {
sync.RWMutex
@@ -19,16 +23,19 @@ type State struct {
Hash []byte
// private fields aren't marshaled to disk.
file string
currentFile string
// app saves current and previous state for rollback functionality
previousFile string
persistInterval uint64
initialHeight uint64
}
// NewState creates a new state.
func NewState(file string, persistInterval uint64) (*State, error) {
func NewState(dir string, persistInterval uint64) (*State, error) {
state := &State{
Values: make(map[string]string),
file: file,
currentFile: filepath.Join(dir, stateFileName),
previousFile: filepath.Join(dir, prevStateFileName),
persistInterval: persistInterval,
}
state.Hash = hashItems(state.Values)
@@ -44,13 +51,22 @@ func NewState(file string, persistInterval uint64) (*State, error) {
// load loads state from disk. It does not take out a lock, since it is called
// during construction.
func (s *State) load() error {
bz, err := os.ReadFile(s.file)
bz, err := os.ReadFile(s.currentFile)
if err != nil {
return fmt.Errorf("failed to read state from %q: %w", s.file, err)
// if the current state doesn't exist then we try recover from the previous state
if errors.Is(err, os.ErrNotExist) {
bz, err = os.ReadFile(s.previousFile)
if err != nil {
return fmt.Errorf("failed to read both current and previous state (%q): %w",
s.previousFile, err)
}
} else {
return fmt.Errorf("failed to read state from %q: %w", s.currentFile, err)
}
}
err = json.Unmarshal(bz, s)
if err != nil {
return fmt.Errorf("invalid state data in %q: %w", s.file, err)
return fmt.Errorf("invalid state data in %q: %w", s.currentFile, err)
}
return nil
}
@@ -64,12 +80,19 @@ func (s *State) save() error {
}
// We write the state to a separate file and move it to the destination, to
// make it atomic.
newFile := fmt.Sprintf("%v.new", s.file)
newFile := fmt.Sprintf("%v.new", s.currentFile)
err = os.WriteFile(newFile, bz, 0644)
if err != nil {
return fmt.Errorf("failed to write state to %q: %w", s.file, err)
return fmt.Errorf("failed to write state to %q: %w", s.currentFile, err)
}
return os.Rename(newFile, s.file)
// We take the current state and move it to the previous state, replacing it
if _, err := os.Stat(s.currentFile); err == nil {
if err := os.Rename(s.currentFile, s.previousFile); err != nil {
return fmt.Errorf("failed to replace previous state: %w", err)
}
}
// Finally, we take the new state and replace the current state.
return os.Rename(newFile, s.currentFile)
}
// Export exports key/value pairs as JSON, used for state sync snapshots.
@@ -135,6 +158,18 @@ func (s *State) Commit() (uint64, []byte, error) {
return s.Height, s.Hash, nil
}
func (s *State) Rollback() error {
bz, err := os.ReadFile(s.previousFile)
if err != nil {
return fmt.Errorf("failed to read state from %q: %w", s.previousFile, err)
}
err = json.Unmarshal(bz, s)
if err != nil {
return fmt.Errorf("invalid state data in %q: %w", s.previousFile, err)
}
return nil
}
// hashItems hashes a set of key/value items.
func hashItems(items map[string]string) []byte {
keys := make([]string, 0, len(items))