mirror of
https://github.com/tendermint/tendermint.git
synced 2026-01-05 04:55:18 +00:00
privval: remove panics in privval implementation (#7475)
This commit is contained in:
@@ -65,7 +65,9 @@ func initFilesWithConfig(ctx context.Context, config *cfg.Config) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pv.Save()
|
||||
if err := pv.Save(); err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Info("Generated private validator", "keyFile", privValKeyFile,
|
||||
"stateFile", privValStateFile)
|
||||
}
|
||||
|
||||
@@ -68,7 +68,9 @@ func resetFilePV(privValKeyFile, privValStateFile string, logger log.Logger) err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pv.Reset()
|
||||
if err := pv.Reset(); err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Info("Reset private validator file to genesis state", "keyFile", privValKeyFile,
|
||||
"stateFile", privValStateFile)
|
||||
} else {
|
||||
@@ -76,7 +78,9 @@ func resetFilePV(privValKeyFile, privValStateFile string, logger log.Logger) err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pv.Save()
|
||||
if err := pv.Save(); err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Info("Generated private validator file", "keyFile", privValKeyFile,
|
||||
"stateFile", privValStateFile)
|
||||
}
|
||||
|
||||
@@ -487,15 +487,13 @@ func newStateWithConfigAndBlockStore(
|
||||
return cs
|
||||
}
|
||||
|
||||
func loadPrivValidator(cfg *config.Config) *privval.FilePV {
|
||||
func loadPrivValidator(t *testing.T, cfg *config.Config) *privval.FilePV {
|
||||
privValidatorKeyFile := cfg.PrivValidator.KeyFile()
|
||||
ensureDir(filepath.Dir(privValidatorKeyFile), 0700)
|
||||
privValidatorStateFile := cfg.PrivValidator.StateFile()
|
||||
privValidator, err := privval.LoadOrGenFilePV(privValidatorKeyFile, privValidatorStateFile)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
privValidator.Reset()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, privValidator.Reset())
|
||||
return privValidator
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ func startNewStateAndWaitForBlock(ctx context.Context, t *testing.T, consensusRe
|
||||
logger := log.TestingLogger()
|
||||
state, err := sm.MakeGenesisStateFromFile(consensusReplayConfig.GenesisFile())
|
||||
require.NoError(t, err)
|
||||
privValidator := loadPrivValidator(consensusReplayConfig)
|
||||
privValidator := loadPrivValidator(t, consensusReplayConfig)
|
||||
blockStore := store.NewBlockStore(dbm.NewMemDB())
|
||||
cs := newStateWithConfigAndBlockStore(
|
||||
ctx,
|
||||
@@ -165,7 +165,7 @@ LOOP:
|
||||
blockStore := store.NewBlockStore(blockDB)
|
||||
state, err := sm.MakeGenesisStateFromFile(consensusReplayConfig.GenesisFile())
|
||||
require.NoError(t, err)
|
||||
privValidator := loadPrivValidator(consensusReplayConfig)
|
||||
privValidator := loadPrivValidator(t, consensusReplayConfig)
|
||||
cs := newStateWithConfigAndBlockStore(
|
||||
rctx,
|
||||
logger,
|
||||
|
||||
131
privval/file.go
131
privval/file.go
@@ -32,14 +32,14 @@ const (
|
||||
)
|
||||
|
||||
// A vote is either stepPrevote or stepPrecommit.
|
||||
func voteToStep(vote *tmproto.Vote) int8 {
|
||||
func voteToStep(vote *tmproto.Vote) (int8, error) {
|
||||
switch vote.Type {
|
||||
case tmproto.PrevoteType:
|
||||
return stepPrevote
|
||||
return stepPrevote, nil
|
||||
case tmproto.PrecommitType:
|
||||
return stepPrecommit
|
||||
return stepPrecommit, nil
|
||||
default:
|
||||
panic(fmt.Sprintf("Unknown vote type: %v", vote.Type))
|
||||
return 0, fmt.Errorf("unknown vote type: %v", vote.Type)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,21 +55,17 @@ type FilePVKey struct {
|
||||
}
|
||||
|
||||
// Save persists the FilePVKey to its filePath.
|
||||
func (pvKey FilePVKey) Save() {
|
||||
func (pvKey FilePVKey) Save() error {
|
||||
outFile := pvKey.filePath
|
||||
if outFile == "" {
|
||||
panic("cannot save PrivValidator key: filePath not set")
|
||||
return errors.New("cannot save PrivValidator key: filePath not set")
|
||||
}
|
||||
|
||||
jsonBytes, err := tmjson.MarshalIndent(pvKey, "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return err
|
||||
}
|
||||
err = tempfile.WriteFileAtomic(outFile, jsonBytes, 0600)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return tempfile.WriteFileAtomic(outFile, jsonBytes, 0600)
|
||||
}
|
||||
|
||||
//-------------------------------------------------------------------------------
|
||||
@@ -127,19 +123,16 @@ func (lss *FilePVLastSignState) CheckHRS(height int64, round int32, step int8) (
|
||||
}
|
||||
|
||||
// Save persists the FilePvLastSignState to its filePath.
|
||||
func (lss *FilePVLastSignState) Save() {
|
||||
func (lss *FilePVLastSignState) Save() error {
|
||||
outFile := lss.filePath
|
||||
if outFile == "" {
|
||||
panic("cannot save FilePVLastSignState: filePath not set")
|
||||
return errors.New("cannot save FilePVLastSignState: filePath not set")
|
||||
}
|
||||
jsonBytes, err := tmjson.MarshalIndent(lss, "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = tempfile.WriteFileAtomic(outFile, jsonBytes, 0600)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return err
|
||||
}
|
||||
return tempfile.WriteFileAtomic(outFile, jsonBytes, 0600)
|
||||
}
|
||||
|
||||
//-------------------------------------------------------------------------------
|
||||
@@ -239,17 +232,23 @@ func loadFilePV(keyFilePath, stateFilePath string, loadState bool) (*FilePV, err
|
||||
// LoadOrGenFilePV loads a FilePV from the given filePaths
|
||||
// or else generates a new one and saves it to the filePaths.
|
||||
func LoadOrGenFilePV(keyFilePath, stateFilePath string) (*FilePV, error) {
|
||||
var (
|
||||
pv *FilePV
|
||||
err error
|
||||
)
|
||||
if tmos.FileExists(keyFilePath) {
|
||||
pv, err = LoadFilePV(keyFilePath, stateFilePath)
|
||||
} else {
|
||||
pv, err = GenFilePV(keyFilePath, stateFilePath, "")
|
||||
pv.Save()
|
||||
pv, err := LoadFilePV(keyFilePath, stateFilePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pv, nil
|
||||
}
|
||||
return pv, err
|
||||
pv, err := GenFilePV(keyFilePath, stateFilePath, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := pv.Save(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pv, nil
|
||||
}
|
||||
|
||||
// GetAddress returns the address of the validator.
|
||||
@@ -283,21 +282,23 @@ func (pv *FilePV) SignProposal(ctx context.Context, chainID string, proposal *tm
|
||||
}
|
||||
|
||||
// Save persists the FilePV to disk.
|
||||
func (pv *FilePV) Save() {
|
||||
pv.Key.Save()
|
||||
pv.LastSignState.Save()
|
||||
func (pv *FilePV) Save() error {
|
||||
if err := pv.Key.Save(); err != nil {
|
||||
return err
|
||||
}
|
||||
return pv.LastSignState.Save()
|
||||
}
|
||||
|
||||
// Reset resets all fields in the FilePV.
|
||||
// NOTE: Unsafe!
|
||||
func (pv *FilePV) Reset() {
|
||||
func (pv *FilePV) Reset() error {
|
||||
var sig []byte
|
||||
pv.LastSignState.Height = 0
|
||||
pv.LastSignState.Round = 0
|
||||
pv.LastSignState.Step = 0
|
||||
pv.LastSignState.Signature = sig
|
||||
pv.LastSignState.SignBytes = nil
|
||||
pv.Save()
|
||||
return pv.Save()
|
||||
}
|
||||
|
||||
// String returns a string representation of the FilePV.
|
||||
@@ -317,8 +318,13 @@ func (pv *FilePV) String() string {
|
||||
// It may need to set the timestamp as well if the vote is otherwise the same as
|
||||
// a previously signed vote (ie. we crashed after signing but before the vote hit the WAL).
|
||||
func (pv *FilePV) signVote(chainID string, vote *tmproto.Vote) error {
|
||||
height, round, step := vote.Height, vote.Round, voteToStep(vote)
|
||||
step, err := voteToStep(vote)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
height := vote.Height
|
||||
round := vote.Round
|
||||
lss := pv.LastSignState
|
||||
|
||||
sameHRS, err := lss.CheckHRS(height, round, step)
|
||||
@@ -336,13 +342,19 @@ func (pv *FilePV) signVote(chainID string, vote *tmproto.Vote) error {
|
||||
if sameHRS {
|
||||
if bytes.Equal(signBytes, lss.SignBytes) {
|
||||
vote.Signature = lss.Signature
|
||||
} else if timestamp, ok := checkVotesOnlyDifferByTimestamp(lss.SignBytes, signBytes); ok {
|
||||
} else {
|
||||
timestamp, ok, err := checkVotesOnlyDifferByTimestamp(lss.SignBytes, signBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return errors.New("conflicting data")
|
||||
}
|
||||
|
||||
vote.Timestamp = timestamp
|
||||
vote.Signature = lss.Signature
|
||||
} else {
|
||||
err = fmt.Errorf("conflicting data")
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// It passed the checks. Sign the vote
|
||||
@@ -350,7 +362,9 @@ func (pv *FilePV) signVote(chainID string, vote *tmproto.Vote) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pv.saveSigned(height, round, step, signBytes, sig)
|
||||
if err := pv.saveSigned(height, round, step, signBytes, sig); err != nil {
|
||||
return err
|
||||
}
|
||||
vote.Signature = sig
|
||||
return nil
|
||||
}
|
||||
@@ -378,13 +392,18 @@ func (pv *FilePV) signProposal(chainID string, proposal *tmproto.Proposal) error
|
||||
if sameHRS {
|
||||
if bytes.Equal(signBytes, lss.SignBytes) {
|
||||
proposal.Signature = lss.Signature
|
||||
} else if timestamp, ok := checkProposalsOnlyDifferByTimestamp(lss.SignBytes, signBytes); ok {
|
||||
} else {
|
||||
timestamp, ok, err := checkProposalsOnlyDifferByTimestamp(lss.SignBytes, signBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return errors.New("conflicting data")
|
||||
}
|
||||
proposal.Timestamp = timestamp
|
||||
proposal.Signature = lss.Signature
|
||||
} else {
|
||||
err = fmt.Errorf("conflicting data")
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// It passed the checks. Sign the proposal
|
||||
@@ -392,34 +411,34 @@ func (pv *FilePV) signProposal(chainID string, proposal *tmproto.Proposal) error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pv.saveSigned(height, round, step, signBytes, sig)
|
||||
if err := pv.saveSigned(height, round, step, signBytes, sig); err != nil {
|
||||
return err
|
||||
}
|
||||
proposal.Signature = sig
|
||||
return nil
|
||||
}
|
||||
|
||||
// Persist height/round/step and signature
|
||||
func (pv *FilePV) saveSigned(height int64, round int32, step int8,
|
||||
signBytes []byte, sig []byte) {
|
||||
|
||||
func (pv *FilePV) saveSigned(height int64, round int32, step int8, signBytes []byte, sig []byte) error {
|
||||
pv.LastSignState.Height = height
|
||||
pv.LastSignState.Round = round
|
||||
pv.LastSignState.Step = step
|
||||
pv.LastSignState.Signature = sig
|
||||
pv.LastSignState.SignBytes = signBytes
|
||||
pv.LastSignState.Save()
|
||||
return pv.LastSignState.Save()
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------------------
|
||||
|
||||
// returns the timestamp from the lastSignBytes.
|
||||
// returns true if the only difference in the votes is their timestamp.
|
||||
func checkVotesOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool) {
|
||||
func checkVotesOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool, error) {
|
||||
var lastVote, newVote tmproto.CanonicalVote
|
||||
if err := protoio.UnmarshalDelimited(lastSignBytes, &lastVote); err != nil {
|
||||
panic(fmt.Sprintf("LastSignBytes cannot be unmarshalled into vote: %v", err))
|
||||
return time.Time{}, false, fmt.Errorf("LastSignBytes cannot be unmarshalled into vote: %v", err)
|
||||
}
|
||||
if err := protoio.UnmarshalDelimited(newSignBytes, &newVote); err != nil {
|
||||
panic(fmt.Sprintf("signBytes cannot be unmarshalled into vote: %v", err))
|
||||
return time.Time{}, false, fmt.Errorf("signBytes cannot be unmarshalled into vote: %v", err)
|
||||
}
|
||||
|
||||
lastTime := lastVote.Timestamp
|
||||
@@ -428,18 +447,18 @@ func checkVotesOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.T
|
||||
lastVote.Timestamp = now
|
||||
newVote.Timestamp = now
|
||||
|
||||
return lastTime, proto.Equal(&newVote, &lastVote)
|
||||
return lastTime, proto.Equal(&newVote, &lastVote), nil
|
||||
}
|
||||
|
||||
// returns the timestamp from the lastSignBytes.
|
||||
// returns true if the only difference in the proposals is their timestamp
|
||||
func checkProposalsOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool) {
|
||||
func checkProposalsOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool, error) {
|
||||
var lastProposal, newProposal tmproto.CanonicalProposal
|
||||
if err := protoio.UnmarshalDelimited(lastSignBytes, &lastProposal); err != nil {
|
||||
panic(fmt.Sprintf("LastSignBytes cannot be unmarshalled into proposal: %v", err))
|
||||
return time.Time{}, false, fmt.Errorf("LastSignBytes cannot be unmarshalled into proposal: %v", err)
|
||||
}
|
||||
if err := protoio.UnmarshalDelimited(newSignBytes, &newProposal); err != nil {
|
||||
panic(fmt.Sprintf("signBytes cannot be unmarshalled into proposal: %v", err))
|
||||
return time.Time{}, false, fmt.Errorf("signBytes cannot be unmarshalled into proposal: %v", err)
|
||||
}
|
||||
|
||||
lastTime := lastProposal.Timestamp
|
||||
@@ -448,5 +467,5 @@ func checkProposalsOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (ti
|
||||
lastProposal.Timestamp = now
|
||||
newProposal.Timestamp = now
|
||||
|
||||
return lastTime, proto.Equal(&newProposal, &lastProposal)
|
||||
return lastTime, proto.Equal(&newProposal, &lastProposal), nil
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func TestGenLoadValidator(t *testing.T) {
|
||||
|
||||
height := int64(100)
|
||||
privVal.LastSignState.Height = height
|
||||
privVal.Save()
|
||||
require.NoError(t, privVal.Save())
|
||||
addr := privVal.GetAddress()
|
||||
|
||||
privVal, err = LoadFilePV(tempKeyFile.Name(), tempStateFile.Name())
|
||||
@@ -68,7 +68,7 @@ func TestResetValidator(t *testing.T) {
|
||||
assert.NotEqual(t, privVal.LastSignState, emptyState)
|
||||
|
||||
// priv val after AcceptNewConnection is same as empty
|
||||
privVal.Reset()
|
||||
require.NoError(t, privVal.Reset())
|
||||
assert.Equal(t, privVal.LastSignState, emptyState)
|
||||
}
|
||||
|
||||
@@ -267,6 +267,9 @@ func TestSignProposal(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDifferByTimestamp(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tempKeyFile, err := os.CreateTemp("", "priv_validator_key_")
|
||||
require.Nil(t, err)
|
||||
tempStateFile, err := os.CreateTemp("", "priv_validator_state_")
|
||||
@@ -283,8 +286,8 @@ func TestDifferByTimestamp(t *testing.T) {
|
||||
{
|
||||
proposal := newProposal(height, round, block1)
|
||||
pb := proposal.ToProto()
|
||||
err := privVal.SignProposal(context.Background(), chainID, pb)
|
||||
assert.NoError(t, err, "expected no error signing proposal")
|
||||
err := privVal.SignProposal(ctx, chainID, pb)
|
||||
require.NoError(t, err, "expected no error signing proposal")
|
||||
signBytes := types.ProposalSignBytes(chainID, pb)
|
||||
|
||||
sig := proposal.Signature
|
||||
@@ -294,8 +297,8 @@ func TestDifferByTimestamp(t *testing.T) {
|
||||
pb.Timestamp = pb.Timestamp.Add(time.Millisecond)
|
||||
var emptySig []byte
|
||||
proposal.Signature = emptySig
|
||||
err = privVal.SignProposal(context.Background(), "mychainid", pb)
|
||||
assert.NoError(t, err, "expected no error on signing same proposal")
|
||||
err = privVal.SignProposal(ctx, "mychainid", pb)
|
||||
require.NoError(t, err, "expected no error on signing same proposal")
|
||||
|
||||
assert.Equal(t, timeStamp, pb.Timestamp)
|
||||
assert.Equal(t, signBytes, types.ProposalSignBytes(chainID, pb))
|
||||
@@ -308,8 +311,8 @@ func TestDifferByTimestamp(t *testing.T) {
|
||||
blockID := types.BlockID{Hash: randbytes, PartSetHeader: types.PartSetHeader{}}
|
||||
vote := newVote(privVal.Key.Address, 0, height, round, voteType, blockID)
|
||||
v := vote.ToProto()
|
||||
err := privVal.SignVote(context.Background(), "mychainid", v)
|
||||
assert.NoError(t, err, "expected no error signing vote")
|
||||
err := privVal.SignVote(ctx, "mychainid", v)
|
||||
require.NoError(t, err, "expected no error signing vote")
|
||||
|
||||
signBytes := types.VoteSignBytes(chainID, v)
|
||||
sig := v.Signature
|
||||
@@ -319,8 +322,8 @@ func TestDifferByTimestamp(t *testing.T) {
|
||||
v.Timestamp = v.Timestamp.Add(time.Millisecond)
|
||||
var emptySig []byte
|
||||
v.Signature = emptySig
|
||||
err = privVal.SignVote(context.Background(), "mychainid", v)
|
||||
assert.NoError(t, err, "expected no error on signing same vote")
|
||||
err = privVal.SignVote(ctx, "mychainid", v)
|
||||
require.NoError(t, err, "expected no error on signing same vote")
|
||||
|
||||
assert.Equal(t, timeStamp, v.Timestamp)
|
||||
assert.Equal(t, signBytes, types.VoteSignBytes(chainID, v))
|
||||
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
|
||||
const chainID = "chain-id"
|
||||
|
||||
func dialer(pv types.PrivValidator, logger log.Logger) (*grpc.Server, func(context.Context, string) (net.Conn, error)) {
|
||||
func dialer(t *testing.T, pv types.PrivValidator, logger log.Logger) (*grpc.Server, func(context.Context, string) (net.Conn, error)) {
|
||||
listener := bufconn.Listen(1024 * 1024)
|
||||
|
||||
server := grpc.NewServer()
|
||||
@@ -33,11 +33,7 @@ func dialer(pv types.PrivValidator, logger log.Logger) (*grpc.Server, func(conte
|
||||
|
||||
privvalproto.RegisterPrivValidatorAPIServer(server, s)
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
go func() { require.NoError(t, server.Serve(listener)) }()
|
||||
|
||||
return server, func(context.Context, string) (net.Conn, error) {
|
||||
return listener.Dial()
|
||||
@@ -46,44 +42,43 @@ func dialer(pv types.PrivValidator, logger log.Logger) (*grpc.Server, func(conte
|
||||
|
||||
func TestSignerClient_GetPubKey(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
mockPV := types.NewMockPV()
|
||||
logger := log.TestingLogger()
|
||||
srv, dialer := dialer(mockPV, logger)
|
||||
srv, dialer := dialer(t, mockPV, logger)
|
||||
defer srv.Stop()
|
||||
|
||||
conn, err := grpc.DialContext(ctx, "",
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(dialer),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client, err := tmgrpc.NewSignerClient(conn, chainID, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := client.GetPubKey(context.Background())
|
||||
pk, err := client.GetPubKey(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, mockPV.PrivKey.PubKey(), pk)
|
||||
}
|
||||
|
||||
func TestSignerClient_SignVote(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ctx := context.Background()
|
||||
mockPV := types.NewMockPV()
|
||||
logger := log.TestingLogger()
|
||||
srv, dialer := dialer(mockPV, logger)
|
||||
srv, dialer := dialer(t, mockPV, logger)
|
||||
defer srv.Stop()
|
||||
|
||||
conn, err := grpc.DialContext(ctx, "",
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(dialer),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client, err := tmgrpc.NewSignerClient(conn, chainID, logger)
|
||||
@@ -115,31 +110,30 @@ func TestSignerClient_SignVote(t *testing.T) {
|
||||
|
||||
pbHave := have.ToProto()
|
||||
|
||||
err = client.SignVote(context.Background(), chainID, pbHave)
|
||||
err = client.SignVote(ctx, chainID, pbHave)
|
||||
require.NoError(t, err)
|
||||
|
||||
pbWant := want.ToProto()
|
||||
|
||||
require.NoError(t, mockPV.SignVote(context.Background(), chainID, pbWant))
|
||||
require.NoError(t, mockPV.SignVote(ctx, chainID, pbWant))
|
||||
|
||||
assert.Equal(t, pbWant.Signature, pbHave.Signature)
|
||||
}
|
||||
|
||||
func TestSignerClient_SignProposal(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ctx := context.Background()
|
||||
mockPV := types.NewMockPV()
|
||||
logger := log.TestingLogger()
|
||||
srv, dialer := dialer(mockPV, logger)
|
||||
srv, dialer := dialer(t, mockPV, logger)
|
||||
defer srv.Stop()
|
||||
|
||||
conn, err := grpc.DialContext(ctx, "",
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(dialer),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client, err := tmgrpc.NewSignerClient(conn, chainID, logger)
|
||||
@@ -167,12 +161,12 @@ func TestSignerClient_SignProposal(t *testing.T) {
|
||||
|
||||
pbHave := have.ToProto()
|
||||
|
||||
err = client.SignProposal(context.Background(), chainID, pbHave)
|
||||
err = client.SignProposal(ctx, chainID, pbHave)
|
||||
require.NoError(t, err)
|
||||
|
||||
pbWant := want.ToProto()
|
||||
|
||||
require.NoError(t, mockPV.SignProposal(context.Background(), chainID, pbWant))
|
||||
require.NoError(t, mockPV.SignProposal(ctx, chainID, pbWant))
|
||||
|
||||
assert.Equal(t, pbWant.Signature, pbHave.Signature)
|
||||
}
|
||||
|
||||
@@ -99,7 +99,10 @@ func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKey) (*
|
||||
)
|
||||
|
||||
// Generate ephemeral keys for perfect forward secrecy.
|
||||
locEphPub, locEphPriv := genEphKeys()
|
||||
locEphPub, locEphPriv, err := genEphKeys()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write local ephemeral pubkey and receive one too.
|
||||
// NOTE: every 32-byte string is accepted as a Curve25519 public key (see
|
||||
@@ -132,7 +135,10 @@ func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKey) (*
|
||||
// Generate the secret used for receiving, sending, challenge via HKDF-SHA2
|
||||
// on the transcript state (which itself also uses HKDF-SHA2 to derive a key
|
||||
// from the dhSecret).
|
||||
recvSecret, sendSecret := deriveSecrets(dhSecret, locIsLeast)
|
||||
recvSecret, sendSecret, err := deriveSecrets(dhSecret, locIsLeast)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
const challengeSize = 32
|
||||
var challenge [challengeSize]byte
|
||||
@@ -214,7 +220,10 @@ func (sc *SecretConnection) Write(data []byte) (n int, err error) {
|
||||
|
||||
// encrypt the frame
|
||||
sc.sendAead.Seal(sealedFrame[:0], sc.sendNonce[:], frame, nil)
|
||||
incrNonce(sc.sendNonce)
|
||||
if err := incrNonce(sc.sendNonce); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// end encryption
|
||||
|
||||
_, err = sc.conn.Write(sealedFrame)
|
||||
@@ -258,7 +267,9 @@ func (sc *SecretConnection) Read(data []byte) (n int, err error) {
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("failed to decrypt SecretConnection: %w", err)
|
||||
}
|
||||
incrNonce(sc.recvNonce)
|
||||
if err = incrNonce(sc.recvNonce); err != nil {
|
||||
return
|
||||
}
|
||||
// end decryption
|
||||
|
||||
// copy checkLength worth into data,
|
||||
@@ -288,14 +299,13 @@ func (sc *SecretConnection) SetWriteDeadline(t time.Time) error {
|
||||
return sc.conn.(net.Conn).SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func genEphKeys() (ephPub, ephPriv *[32]byte) {
|
||||
var err error
|
||||
func genEphKeys() (ephPub, ephPriv *[32]byte, err error) {
|
||||
// TODO: Probably not a problem but ask Tony: different from the rust implementation (uses x25519-dalek),
|
||||
// we do not "clamp" the private key scalar:
|
||||
// see: https://github.com/dalek-cryptography/x25519-dalek/blob/34676d336049df2bba763cc076a75e47ae1f170f/src/x25519.rs#L56-L74
|
||||
ephPub, ephPriv, err = box.GenerateKey(crand.Reader)
|
||||
if err != nil {
|
||||
panic("Could not generate ephemeral key-pair")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -339,14 +349,14 @@ func shareEphPubKey(conn io.ReadWriter, locEphPub *[32]byte) (remEphPub *[32]byt
|
||||
func deriveSecrets(
|
||||
dhSecret *[32]byte,
|
||||
locIsLeast bool,
|
||||
) (recvSecret, sendSecret *[aeadKeySize]byte) {
|
||||
) (recvSecret, sendSecret *[aeadKeySize]byte, err error) {
|
||||
hash := sha256.New
|
||||
hkdf := hkdf.New(hash, dhSecret[:], nil, secretConnKeyAndChallengeGen)
|
||||
// get enough data for 2 aead keys, and a 32 byte challenge
|
||||
res := new([2*aeadKeySize + 32]byte)
|
||||
_, err := io.ReadFull(hkdf, res[:])
|
||||
_, err = io.ReadFull(hkdf, res[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
recvSecret = new([aeadKeySize]byte)
|
||||
@@ -454,13 +464,14 @@ func shareAuthSignature(sc io.ReadWriter, pubKey crypto.PubKey, signature []byte
|
||||
// Due to chacha20poly1305 expecting a 12 byte nonce we do not use the first four
|
||||
// bytes. We only increment a 64 bit unsigned int in the remaining 8 bytes
|
||||
// (little-endian in nonce[4:]).
|
||||
func incrNonce(nonce *[aeadNonceSize]byte) {
|
||||
func incrNonce(nonce *[aeadNonceSize]byte) error {
|
||||
counter := binary.LittleEndian.Uint64(nonce[4:])
|
||||
if counter == math.MaxUint64 {
|
||||
// Terminates the session and makes sure the nonce would not re-used.
|
||||
// See https://github.com/tendermint/tendermint/issues/3531
|
||||
panic("can't increase nonce without overflow")
|
||||
return errors.New("can't increase nonce without overflow")
|
||||
}
|
||||
counter++
|
||||
binary.LittleEndian.PutUint64(nonce[4:], counter)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -98,7 +98,7 @@ func TestRetryConnToRemoteSigner(t *testing.T) {
|
||||
mockPV = types.NewMockPV()
|
||||
endpointIsOpenCh = make(chan struct{})
|
||||
thisConnTimeout = testTimeoutReadWrite
|
||||
listenerEndpoint = newSignerListenerEndpoint(logger, tc.addr, thisConnTimeout)
|
||||
listenerEndpoint = newSignerListenerEndpoint(t, logger, tc.addr, thisConnTimeout)
|
||||
)
|
||||
|
||||
dialerEndpoint := NewSignerDialerEndpoint(
|
||||
@@ -138,14 +138,12 @@ func TestRetryConnToRemoteSigner(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint {
|
||||
func newSignerListenerEndpoint(t *testing.T, logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint {
|
||||
proto, address := tmnet.ProtocolAndAddress(addr)
|
||||
|
||||
ln, err := net.Listen(proto, address)
|
||||
logger.Info("SignerListener: Listening", "proto", proto, "address", address)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
var listener net.Listener
|
||||
|
||||
@@ -199,7 +197,7 @@ func getMockEndpoints(
|
||||
socketDialer,
|
||||
)
|
||||
|
||||
listenerEndpoint = newSignerListenerEndpoint(logger, addr, testTimeoutReadWrite)
|
||||
listenerEndpoint = newSignerListenerEndpoint(t, logger, addr, testTimeoutReadWrite)
|
||||
)
|
||||
|
||||
SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint)
|
||||
|
||||
@@ -9,10 +9,20 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/tendermint/tendermint/crypto/ed25519"
|
||||
tmnet "github.com/tendermint/tendermint/libs/net"
|
||||
)
|
||||
|
||||
// getFreeLocalhostAddrPort returns a free localhost:port address
|
||||
func getFreeLocalhostAddrPort(t *testing.T) string {
|
||||
t.Helper()
|
||||
port, err := tmnet.GetFreePort()
|
||||
require.NoError(t, err)
|
||||
|
||||
return fmt.Sprintf("127.0.0.1:%d", port)
|
||||
}
|
||||
|
||||
func getDialerTestCases(t *testing.T) []dialerTestCase {
|
||||
tcpAddr := GetFreeLocalhostAddrPort()
|
||||
tcpAddr := getFreeLocalhostAddrPort(t)
|
||||
unixFilePath, err := testUnixAddr()
|
||||
require.NoError(t, err)
|
||||
unixAddr := fmt.Sprintf("unix://%s", unixFilePath)
|
||||
@@ -31,7 +41,7 @@ func getDialerTestCases(t *testing.T) []dialerTestCase {
|
||||
|
||||
func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) {
|
||||
// Generate a networking timeout
|
||||
tcpAddr := GetFreeLocalhostAddrPort()
|
||||
tcpAddr := getFreeLocalhostAddrPort(t)
|
||||
dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey())
|
||||
_, err := dialer()
|
||||
assert.Error(t, err)
|
||||
@@ -39,7 +49,7 @@ func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIsConnTimeoutForWrappedConnTimeouts(t *testing.T) {
|
||||
tcpAddr := GetFreeLocalhostAddrPort()
|
||||
tcpAddr := getFreeLocalhostAddrPort(t)
|
||||
dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey())
|
||||
_, err := dialer()
|
||||
assert.Error(t, err)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tendermint/tendermint/crypto/ed25519"
|
||||
)
|
||||
|
||||
@@ -107,9 +108,7 @@ func TestListenerTimeoutReadWrite(t *testing.T) {
|
||||
for _, tc := range listenerTestCases(t, timeoutAccept, timeoutReadWrite) {
|
||||
go func(dialer SocketDialer) {
|
||||
_, err := dialer()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
}(tc.dialer)
|
||||
|
||||
c, err := tc.listener.Accept()
|
||||
|
||||
@@ -51,12 +51,3 @@ func NewSignerListener(listenAddr string, logger log.Logger) (*SignerListenerEnd
|
||||
|
||||
return pve, nil
|
||||
}
|
||||
|
||||
// GetFreeLocalhostAddrPort returns a free localhost:port address
|
||||
func GetFreeLocalhostAddrPort() string {
|
||||
port, err := tmnet.GetFreePort()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return fmt.Sprintf("127.0.0.1:%d", port)
|
||||
}
|
||||
|
||||
@@ -111,17 +111,23 @@ func Setup(testnet *e2e.Testnet) error {
|
||||
return err
|
||||
}
|
||||
|
||||
(privval.NewFilePV(node.PrivvalKey,
|
||||
err = (privval.NewFilePV(node.PrivvalKey,
|
||||
filepath.Join(nodeDir, PrivvalKeyFile),
|
||||
filepath.Join(nodeDir, PrivvalStateFile),
|
||||
)).Save()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set up a dummy validator. Tendermint requires a file PV even when not used, so we
|
||||
// give it a dummy such that it will fail if it actually tries to use it.
|
||||
(privval.NewFilePV(ed25519.GenPrivKey(),
|
||||
err = (privval.NewFilePV(ed25519.GenPrivKey(),
|
||||
filepath.Join(nodeDir, PrivvalDummyKeyFile),
|
||||
filepath.Join(nodeDir, PrivvalDummyStateFile),
|
||||
)).Save()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user