Extract priv_validator into first class package

This is a maintenance change to move the private validator package out
of the types and to a top-level location. There is no good reason to
keep it under the types and it will more clearly coommunicate where
additions related to the privval belong. It leaves the interface and the
mock in types for now as it would introduce circular dependency between
privval and types, this should be resolved eventually.

* mv priv_validator to privval pkg
* use consistent `privval` as import

Follow-up to #1255
This commit is contained in:
Alexander Simmerl
2018-06-01 19:17:37 +02:00
parent 1c643701f5
commit bf370d36c2
19 changed files with 33 additions and 33 deletions

View File

@@ -1,345 +0,0 @@
package privval
import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"sync"
"time"
"github.com/tendermint/go-crypto"
"github.com/tendermint/tendermint/types"
cmn "github.com/tendermint/tmlibs/common"
)
// TODO: type ?
const (
stepNone int8 = 0 // Used to distinguish the initial state
stepPropose int8 = 1
stepPrevote int8 = 2
stepPrecommit int8 = 3
)
func voteToStep(vote *types.Vote) int8 {
switch vote.Type {
case types.VoteTypePrevote:
return stepPrevote
case types.VoteTypePrecommit:
return stepPrecommit
default:
cmn.PanicSanity("Unknown vote type")
return 0
}
}
// FilePV implements PrivValidator using data persisted to disk
// to prevent double signing.
// NOTE: the directory containing the pv.filePath must already exist.
type FilePV struct {
Address types.Address `json:"address"`
PubKey crypto.PubKey `json:"pub_key"`
LastHeight int64 `json:"last_height"`
LastRound int `json:"last_round"`
LastStep int8 `json:"last_step"`
LastSignature crypto.Signature `json:"last_signature,omitempty"` // so we dont lose signatures XXX Why would we lose signatures?
LastSignBytes cmn.HexBytes `json:"last_signbytes,omitempty"` // so we dont lose signatures XXX Why would we lose signatures?
PrivKey crypto.PrivKey `json:"priv_key"`
// For persistence.
// Overloaded for testing.
filePath string
mtx sync.Mutex
}
// GetAddress returns the address of the validator.
// Implements PrivValidator.
func (pv *FilePV) GetAddress() types.Address {
return pv.Address
}
// GetPubKey returns the public key of the validator.
// Implements PrivValidator.
func (pv *FilePV) GetPubKey() crypto.PubKey {
return pv.PubKey
}
// GenFilePV generates a new validator with randomly generated private key
// and sets the filePath, but does not call Save().
func GenFilePV(filePath string) *FilePV {
privKey := crypto.GenPrivKeyEd25519()
return &FilePV{
Address: privKey.PubKey().Address(),
PubKey: privKey.PubKey(),
PrivKey: privKey,
LastStep: stepNone,
filePath: filePath,
}
}
// LoadFilePV loads a FilePV from the filePath. The FilePV handles double
// signing prevention by persisting data to the filePath. If the filePath does
// not exist, the FilePV must be created manually and saved.
func LoadFilePV(filePath string) *FilePV {
pvJSONBytes, err := ioutil.ReadFile(filePath)
if err != nil {
cmn.Exit(err.Error())
}
pv := &FilePV{}
err = cdc.UnmarshalJSON(pvJSONBytes, &pv)
if err != nil {
cmn.Exit(cmn.Fmt("Error reading PrivValidator from %v: %v\n", filePath, err))
}
pv.filePath = filePath
return pv
}
// LoadOrGenFilePV loads a FilePV from the given filePath
// or else generates a new one and saves it to the filePath.
func LoadOrGenFilePV(filePath string) *FilePV {
var pv *FilePV
if cmn.FileExists(filePath) {
pv = LoadFilePV(filePath)
} else {
pv = GenFilePV(filePath)
pv.Save()
}
return pv
}
// Save persists the FilePV to disk.
func (pv *FilePV) Save() {
pv.mtx.Lock()
defer pv.mtx.Unlock()
pv.save()
}
func (pv *FilePV) save() {
outFile := pv.filePath
if outFile == "" {
panic("Cannot save PrivValidator: filePath not set")
}
jsonBytes, err := cdc.MarshalJSONIndent(pv, "", " ")
if err != nil {
panic(err)
}
err = cmn.WriteFileAtomic(outFile, jsonBytes, 0600)
if err != nil {
panic(err)
}
}
// Reset resets all fields in the FilePV.
// NOTE: Unsafe!
func (pv *FilePV) Reset() {
var sig crypto.Signature
pv.LastHeight = 0
pv.LastRound = 0
pv.LastStep = 0
pv.LastSignature = sig
pv.LastSignBytes = nil
pv.Save()
}
// SignVote signs a canonical representation of the vote, along with the
// chainID. Implements PrivValidator.
func (pv *FilePV) SignVote(chainID string, vote *types.Vote) error {
pv.mtx.Lock()
defer pv.mtx.Unlock()
if err := pv.signVote(chainID, vote); err != nil {
return errors.New(cmn.Fmt("Error signing vote: %v", err))
}
return nil
}
// SignProposal signs a canonical representation of the proposal, along with
// the chainID. Implements PrivValidator.
func (pv *FilePV) SignProposal(chainID string, proposal *types.Proposal) error {
pv.mtx.Lock()
defer pv.mtx.Unlock()
if err := pv.signProposal(chainID, proposal); err != nil {
return fmt.Errorf("Error signing proposal: %v", err)
}
return nil
}
// returns error if HRS regression or no LastSignBytes. returns true if HRS is unchanged
func (pv *FilePV) checkHRS(height int64, round int, step int8) (bool, error) {
if pv.LastHeight > height {
return false, errors.New("Height regression")
}
if pv.LastHeight == height {
if pv.LastRound > round {
return false, errors.New("Round regression")
}
if pv.LastRound == round {
if pv.LastStep > step {
return false, errors.New("Step regression")
} else if pv.LastStep == step {
if pv.LastSignBytes != nil {
if pv.LastSignature == nil {
panic("pv: LastSignature is nil but LastSignBytes is not!")
}
return true, nil
}
return false, errors.New("No LastSignature found")
}
}
}
return false, nil
}
// signVote checks if the vote is good to sign and sets the vote signature.
// It may need to set the timestamp as well if the vote is otherwise the same as
// a previously signed vote (ie. we crashed after signing but before the vote hit the WAL).
func (pv *FilePV) signVote(chainID string, vote *types.Vote) error {
height, round, step := vote.Height, vote.Round, voteToStep(vote)
signBytes := vote.SignBytes(chainID)
sameHRS, err := pv.checkHRS(height, round, step)
if err != nil {
return err
}
// We might crash before writing to the wal,
// causing us to try to re-sign for the same HRS.
// If signbytes are the same, use the last signature.
// If they only differ by timestamp, use last timestamp and signature
// Otherwise, return error
if sameHRS {
if bytes.Equal(signBytes, pv.LastSignBytes) {
vote.Signature = pv.LastSignature
} else if timestamp, ok := checkVotesOnlyDifferByTimestamp(pv.LastSignBytes, signBytes); ok {
vote.Timestamp = timestamp
vote.Signature = pv.LastSignature
} else {
err = fmt.Errorf("Conflicting data")
}
return err
}
// It passed the checks. Sign the vote
sig := pv.PrivKey.Sign(signBytes)
pv.saveSigned(height, round, step, signBytes, sig)
vote.Signature = sig
return nil
}
// signProposal checks if the proposal is good to sign and sets the proposal signature.
// It may need to set the timestamp as well if the proposal is otherwise the same as
// a previously signed proposal ie. we crashed after signing but before the proposal hit the WAL).
func (pv *FilePV) signProposal(chainID string, proposal *types.Proposal) error {
height, round, step := proposal.Height, proposal.Round, stepPropose
signBytes := proposal.SignBytes(chainID)
sameHRS, err := pv.checkHRS(height, round, step)
if err != nil {
return err
}
// We might crash before writing to the wal,
// causing us to try to re-sign for the same HRS.
// If signbytes are the same, use the last signature.
// If they only differ by timestamp, use last timestamp and signature
// Otherwise, return error
if sameHRS {
if bytes.Equal(signBytes, pv.LastSignBytes) {
proposal.Signature = pv.LastSignature
} else if timestamp, ok := checkProposalsOnlyDifferByTimestamp(pv.LastSignBytes, signBytes); ok {
proposal.Timestamp = timestamp
proposal.Signature = pv.LastSignature
} else {
err = fmt.Errorf("Conflicting data")
}
return err
}
// It passed the checks. Sign the proposal
sig := pv.PrivKey.Sign(signBytes)
pv.saveSigned(height, round, step, signBytes, sig)
proposal.Signature = sig
return nil
}
// Persist height/round/step and signature
func (pv *FilePV) saveSigned(height int64, round int, step int8,
signBytes []byte, sig crypto.Signature) {
pv.LastHeight = height
pv.LastRound = round
pv.LastStep = step
pv.LastSignature = sig
pv.LastSignBytes = signBytes
pv.save()
}
// SignHeartbeat signs a canonical representation of the heartbeat, along with the chainID.
// Implements PrivValidator.
func (pv *FilePV) SignHeartbeat(chainID string, heartbeat *types.Heartbeat) error {
pv.mtx.Lock()
defer pv.mtx.Unlock()
heartbeat.Signature = pv.PrivKey.Sign(heartbeat.SignBytes(chainID))
return nil
}
// String returns a string representation of the FilePV.
func (pv *FilePV) String() string {
return fmt.Sprintf("PrivValidator{%v LH:%v, LR:%v, LS:%v}", pv.GetAddress(), pv.LastHeight, pv.LastRound, pv.LastStep)
}
//-------------------------------------
// returns the timestamp from the lastSignBytes.
// returns true if the only difference in the votes is their timestamp.
func checkVotesOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool) {
var lastVote, newVote types.CanonicalJSONVote
if err := cdc.UnmarshalJSON(lastSignBytes, &lastVote); err != nil {
panic(fmt.Sprintf("LastSignBytes cannot be unmarshalled into vote: %v", err))
}
if err := cdc.UnmarshalJSON(newSignBytes, &newVote); err != nil {
panic(fmt.Sprintf("signBytes cannot be unmarshalled into vote: %v", err))
}
lastTime, err := time.Parse(types.TimeFormat, lastVote.Timestamp)
if err != nil {
panic(err)
}
// set the times to the same value and check equality
now := types.CanonicalTime(time.Now())
lastVote.Timestamp = now
newVote.Timestamp = now
lastVoteBytes, _ := cdc.MarshalJSON(lastVote)
newVoteBytes, _ := cdc.MarshalJSON(newVote)
return lastTime, bytes.Equal(newVoteBytes, lastVoteBytes)
}
// returns the timestamp from the lastSignBytes.
// returns true if the only difference in the proposals is their timestamp
func checkProposalsOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool) {
var lastProposal, newProposal types.CanonicalJSONProposal
if err := cdc.UnmarshalJSON(lastSignBytes, &lastProposal); err != nil {
panic(fmt.Sprintf("LastSignBytes cannot be unmarshalled into proposal: %v", err))
}
if err := cdc.UnmarshalJSON(newSignBytes, &newProposal); err != nil {
panic(fmt.Sprintf("signBytes cannot be unmarshalled into proposal: %v", err))
}
lastTime, err := time.Parse(types.TimeFormat, lastProposal.Timestamp)
if err != nil {
panic(err)
}
// set the times to the same value and check equality
now := types.CanonicalTime(time.Now())
lastProposal.Timestamp = now
newProposal.Timestamp = now
lastProposalBytes, _ := cdc.MarshalJSON(lastProposal)
newProposalBytes, _ := cdc.MarshalJSON(newProposal)
return lastTime, bytes.Equal(newProposalBytes, lastProposalBytes)
}

View File

@@ -1,251 +0,0 @@
package privval
import (
"encoding/base64"
"fmt"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tendermint/go-crypto"
"github.com/tendermint/tendermint/types"
cmn "github.com/tendermint/tmlibs/common"
)
func TestGenLoadValidator(t *testing.T) {
assert := assert.New(t)
_, tempFilePath := cmn.Tempfile("priv_validator_")
privVal := GenFilePV(tempFilePath)
height := int64(100)
privVal.LastHeight = height
privVal.Save()
addr := privVal.GetAddress()
privVal = LoadFilePV(tempFilePath)
assert.Equal(addr, privVal.GetAddress(), "expected privval addr to be the same")
assert.Equal(height, privVal.LastHeight, "expected privval.LastHeight to have been saved")
}
func TestLoadOrGenValidator(t *testing.T) {
assert := assert.New(t)
_, tempFilePath := cmn.Tempfile("priv_validator_")
if err := os.Remove(tempFilePath); err != nil {
t.Error(err)
}
privVal := LoadOrGenFilePV(tempFilePath)
addr := privVal.GetAddress()
privVal = LoadOrGenFilePV(tempFilePath)
assert.Equal(addr, privVal.GetAddress(), "expected privval addr to be the same")
}
func TestUnmarshalValidator(t *testing.T) {
assert, require := assert.New(t), require.New(t)
// create some fixed values
privKey := crypto.GenPrivKeyEd25519()
pubKey := privKey.PubKey()
addr := pubKey.Address()
pubArray := [32]byte(pubKey.(crypto.PubKeyEd25519))
pubBytes := pubArray[:]
privArray := [64]byte(privKey)
privBytes := privArray[:]
pubB64 := base64.StdEncoding.EncodeToString(pubBytes)
privB64 := base64.StdEncoding.EncodeToString(privBytes)
serialized := fmt.Sprintf(`{
"address": "%s",
"pub_key": {
"type": "AC26791624DE60",
"value": "%s"
},
"last_height": 0,
"last_round": 0,
"last_step": 0,
"priv_key": {
"type": "954568A3288910",
"value": "%s"
}
}`, addr, pubB64, privB64)
val := FilePV{}
err := cdc.UnmarshalJSON([]byte(serialized), &val)
require.Nil(err, "%+v", err)
// make sure the values match
assert.EqualValues(addr, val.GetAddress())
assert.EqualValues(pubKey, val.GetPubKey())
assert.EqualValues(privKey, val.PrivKey)
// export it and make sure it is the same
out, err := cdc.MarshalJSON(val)
require.Nil(err, "%+v", err)
assert.JSONEq(serialized, string(out))
}
func TestSignVote(t *testing.T) {
assert := assert.New(t)
_, tempFilePath := cmn.Tempfile("priv_validator_")
privVal := GenFilePV(tempFilePath)
block1 := types.BlockID{[]byte{1, 2, 3}, types.PartSetHeader{}}
block2 := types.BlockID{[]byte{3, 2, 1}, types.PartSetHeader{}}
height, round := int64(10), 1
voteType := types.VoteTypePrevote
// sign a vote for first time
vote := newVote(privVal.Address, 0, height, round, voteType, block1)
err := privVal.SignVote("mychainid", vote)
assert.NoError(err, "expected no error signing vote")
// try to sign the same vote again; should be fine
err = privVal.SignVote("mychainid", vote)
assert.NoError(err, "expected no error on signing same vote")
// now try some bad votes
cases := []*types.Vote{
newVote(privVal.Address, 0, height, round-1, voteType, block1), // round regression
newVote(privVal.Address, 0, height-1, round, voteType, block1), // height regression
newVote(privVal.Address, 0, height-2, round+4, voteType, block1), // height regression and different round
newVote(privVal.Address, 0, height, round, voteType, block2), // different block
}
for _, c := range cases {
err = privVal.SignVote("mychainid", c)
assert.Error(err, "expected error on signing conflicting vote")
}
// try signing a vote with a different time stamp
sig := vote.Signature
vote.Timestamp = vote.Timestamp.Add(time.Duration(1000))
err = privVal.SignVote("mychainid", vote)
assert.NoError(err)
assert.Equal(sig, vote.Signature)
}
func TestSignProposal(t *testing.T) {
assert := assert.New(t)
_, tempFilePath := cmn.Tempfile("priv_validator_")
privVal := GenFilePV(tempFilePath)
block1 := types.PartSetHeader{5, []byte{1, 2, 3}}
block2 := types.PartSetHeader{10, []byte{3, 2, 1}}
height, round := int64(10), 1
// sign a proposal for first time
proposal := newProposal(height, round, block1)
err := privVal.SignProposal("mychainid", proposal)
assert.NoError(err, "expected no error signing proposal")
// try to sign the same proposal again; should be fine
err = privVal.SignProposal("mychainid", proposal)
assert.NoError(err, "expected no error on signing same proposal")
// now try some bad Proposals
cases := []*types.Proposal{
newProposal(height, round-1, block1), // round regression
newProposal(height-1, round, block1), // height regression
newProposal(height-2, round+4, block1), // height regression and different round
newProposal(height, round, block2), // different block
}
for _, c := range cases {
err = privVal.SignProposal("mychainid", c)
assert.Error(err, "expected error on signing conflicting proposal")
}
// try signing a proposal with a different time stamp
sig := proposal.Signature
proposal.Timestamp = proposal.Timestamp.Add(time.Duration(1000))
err = privVal.SignProposal("mychainid", proposal)
assert.NoError(err)
assert.Equal(sig, proposal.Signature)
}
func TestDifferByTimestamp(t *testing.T) {
_, tempFilePath := cmn.Tempfile("priv_validator_")
privVal := GenFilePV(tempFilePath)
block1 := types.PartSetHeader{5, []byte{1, 2, 3}}
height, round := int64(10), 1
chainID := "mychainid"
// test proposal
{
proposal := newProposal(height, round, block1)
err := privVal.SignProposal(chainID, proposal)
assert.NoError(t, err, "expected no error signing proposal")
signBytes := proposal.SignBytes(chainID)
sig := proposal.Signature
timeStamp := clipToMS(proposal.Timestamp)
// manipulate the timestamp. should get changed back
proposal.Timestamp = proposal.Timestamp.Add(time.Millisecond)
var emptySig crypto.Signature
proposal.Signature = emptySig
err = privVal.SignProposal("mychainid", proposal)
assert.NoError(t, err, "expected no error on signing same proposal")
assert.Equal(t, timeStamp, proposal.Timestamp)
assert.Equal(t, signBytes, proposal.SignBytes(chainID))
assert.Equal(t, sig, proposal.Signature)
}
// test vote
{
voteType := types.VoteTypePrevote
blockID := types.BlockID{[]byte{1, 2, 3}, types.PartSetHeader{}}
vote := newVote(privVal.Address, 0, height, round, voteType, blockID)
err := privVal.SignVote("mychainid", vote)
assert.NoError(t, err, "expected no error signing vote")
signBytes := vote.SignBytes(chainID)
sig := vote.Signature
timeStamp := clipToMS(vote.Timestamp)
// manipulate the timestamp. should get changed back
vote.Timestamp = vote.Timestamp.Add(time.Millisecond)
var emptySig crypto.Signature
vote.Signature = emptySig
err = privVal.SignVote("mychainid", vote)
assert.NoError(t, err, "expected no error on signing same vote")
assert.Equal(t, timeStamp, vote.Timestamp)
assert.Equal(t, signBytes, vote.SignBytes(chainID))
assert.Equal(t, sig, vote.Signature)
}
}
func newVote(addr types.Address, idx int, height int64, round int, typ byte, blockID types.BlockID) *types.Vote {
return &types.Vote{
ValidatorAddress: addr,
ValidatorIndex: idx,
Height: height,
Round: round,
Type: typ,
Timestamp: time.Now().UTC(),
BlockID: blockID,
}
}
func newProposal(height int64, round int, partsHeader types.PartSetHeader) *types.Proposal {
return &types.Proposal{
Height: height,
Round: round,
BlockPartsHeader: partsHeader,
Timestamp: time.Now().UTC(),
}
}
func clipToMS(t time.Time) time.Time {
nano := t.UnixNano()
million := int64(1000000)
nano = (nano / million) * million
return time.Unix(0, nano).UTC()
}

View File

@@ -1,538 +0,0 @@
package privval
import (
"errors"
"fmt"
"io"
"net"
"time"
"github.com/tendermint/go-amino"
"github.com/tendermint/go-crypto"
cmn "github.com/tendermint/tmlibs/common"
"github.com/tendermint/tmlibs/log"
p2pconn "github.com/tendermint/tendermint/p2p/conn"
"github.com/tendermint/tendermint/types"
)
const (
defaultAcceptDeadlineSeconds = 3
defaultConnDeadlineSeconds = 3
defaultConnHeartBeatSeconds = 30
defaultConnWaitSeconds = 60
defaultDialRetries = 10
)
// Socket errors.
var (
ErrDialRetryMax = errors.New("dialed maximum retries")
ErrConnWaitTimeout = errors.New("waited for remote signer for too long")
ErrConnTimeout = errors.New("remote signer timed out")
)
var (
acceptDeadline = time.Second + defaultAcceptDeadlineSeconds
connDeadline = time.Second * defaultConnDeadlineSeconds
connHeartbeat = time.Second * defaultConnHeartBeatSeconds
)
// SocketPVOption sets an optional parameter on the SocketPV.
type SocketPVOption func(*SocketPV)
// SocketPVAcceptDeadline sets the deadline for the SocketPV listener.
// A zero time value disables the deadline.
func SocketPVAcceptDeadline(deadline time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.acceptDeadline = deadline }
}
// SocketPVConnDeadline sets the read and write deadline for connections
// from external signing processes.
func SocketPVConnDeadline(deadline time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.connDeadline = deadline }
}
// SocketPVHeartbeat sets the period on which to check the liveness of the
// connected Signer connections.
func SocketPVHeartbeat(period time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.connHeartbeat = period }
}
// SocketPVConnWait sets the timeout duration before connection of external
// signing processes are considered to be unsuccessful.
func SocketPVConnWait(timeout time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.connWaitTimeout = timeout }
}
// SocketPV implements PrivValidator, it uses a socket to request signatures
// from an external process.
type SocketPV struct {
cmn.BaseService
addr string
acceptDeadline time.Duration
connDeadline time.Duration
connHeartbeat time.Duration
connWaitTimeout time.Duration
privKey crypto.PrivKeyEd25519
conn net.Conn
listener net.Listener
}
// Check that SocketPV implements PrivValidator.
var _ types.PrivValidator = (*SocketPV)(nil)
// NewSocketPV returns an instance of SocketPV.
func NewSocketPV(
logger log.Logger,
socketAddr string,
privKey crypto.PrivKeyEd25519,
) *SocketPV {
sc := &SocketPV{
addr: socketAddr,
acceptDeadline: acceptDeadline,
connDeadline: connDeadline,
connHeartbeat: connHeartbeat,
connWaitTimeout: time.Second * defaultConnWaitSeconds,
privKey: privKey,
}
sc.BaseService = *cmn.NewBaseService(logger, "SocketPV", sc)
return sc
}
// GetAddress implements PrivValidator.
func (sc *SocketPV) GetAddress() types.Address {
addr, err := sc.getAddress()
if err != nil {
panic(err)
}
return addr
}
// Address is an alias for PubKey().Address().
func (sc *SocketPV) getAddress() (cmn.HexBytes, error) {
p, err := sc.getPubKey()
if err != nil {
return nil, err
}
return p.Address(), nil
}
// GetPubKey implements PrivValidator.
func (sc *SocketPV) GetPubKey() crypto.PubKey {
pubKey, err := sc.getPubKey()
if err != nil {
panic(err)
}
return pubKey
}
func (sc *SocketPV) getPubKey() (crypto.PubKey, error) {
err := writeMsg(sc.conn, &PubKeyMsg{})
if err != nil {
return nil, err
}
res, err := readMsg(sc.conn)
if err != nil {
return nil, err
}
return res.(*PubKeyMsg).PubKey, nil
}
// SignVote implements PrivValidator.
func (sc *SocketPV) SignVote(chainID string, vote *types.Vote) error {
err := writeMsg(sc.conn, &SignVoteMsg{Vote: vote})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
*vote = *res.(*SignVoteMsg).Vote
return nil
}
// SignProposal implements PrivValidator.
func (sc *SocketPV) SignProposal(
chainID string,
proposal *types.Proposal,
) error {
err := writeMsg(sc.conn, &SignProposalMsg{Proposal: proposal})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
*proposal = *res.(*SignProposalMsg).Proposal
return nil
}
// SignHeartbeat implements PrivValidator.
func (sc *SocketPV) SignHeartbeat(
chainID string,
heartbeat *types.Heartbeat,
) error {
err := writeMsg(sc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
*heartbeat = *res.(*SignHeartbeatMsg).Heartbeat
return nil
}
// OnStart implements cmn.Service.
func (sc *SocketPV) OnStart() error {
if err := sc.listen(); err != nil {
err = cmn.ErrorWrap(err, "failed to listen")
sc.Logger.Error(
"OnStart",
"err", err,
)
return err
}
conn, err := sc.waitConnection()
if err != nil {
err = cmn.ErrorWrap(err, "failed to accept connection")
sc.Logger.Error(
"OnStart",
"err", err,
)
return err
}
sc.conn = conn
return nil
}
// OnStop implements cmn.Service.
func (sc *SocketPV) OnStop() {
if sc.conn != nil {
if err := sc.conn.Close(); err != nil {
err = cmn.ErrorWrap(err, "failed to close connection")
sc.Logger.Error(
"OnStop",
"err", err,
)
}
}
if sc.listener != nil {
if err := sc.listener.Close(); err != nil {
err = cmn.ErrorWrap(err, "failed to close listener")
sc.Logger.Error(
"OnStop",
"err", err,
)
}
}
}
func (sc *SocketPV) acceptConnection() (net.Conn, error) {
conn, err := sc.listener.Accept()
if err != nil {
if !sc.IsRunning() {
return nil, nil // Ignore error from listener closing.
}
return nil, err
}
conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey)
if err != nil {
return nil, err
}
return conn, nil
}
func (sc *SocketPV) listen() error {
ln, err := net.Listen(cmn.ProtocolAndAddress(sc.addr))
if err != nil {
return err
}
sc.listener = newTCPTimeoutListener(
ln,
sc.acceptDeadline,
sc.connDeadline,
sc.connHeartbeat,
)
return nil
}
// waitConnection uses the configured wait timeout to error if no external
// process connects in the time period.
func (sc *SocketPV) waitConnection() (net.Conn, error) {
var (
connc = make(chan net.Conn, 1)
errc = make(chan error, 1)
)
go func(connc chan<- net.Conn, errc chan<- error) {
conn, err := sc.acceptConnection()
if err != nil {
errc <- err
return
}
connc <- conn
}(connc, errc)
select {
case conn := <-connc:
return conn, nil
case err := <-errc:
if _, ok := err.(timeoutError); ok {
return nil, cmn.ErrorWrap(ErrConnWaitTimeout, err.Error())
}
return nil, err
case <-time.After(sc.connWaitTimeout):
return nil, ErrConnWaitTimeout
}
}
//---------------------------------------------------------
// RemoteSignerOption sets an optional parameter on the RemoteSigner.
type RemoteSignerOption func(*RemoteSigner)
// RemoteSignerConnDeadline sets the read and write deadline for connections
// from external signing processes.
func RemoteSignerConnDeadline(deadline time.Duration) RemoteSignerOption {
return func(ss *RemoteSigner) { ss.connDeadline = deadline }
}
// RemoteSignerConnRetries sets the amount of attempted retries to connect.
func RemoteSignerConnRetries(retries int) RemoteSignerOption {
return func(ss *RemoteSigner) { ss.connRetries = retries }
}
// RemoteSigner implements PrivValidator by dialing to a socket.
type RemoteSigner struct {
cmn.BaseService
addr string
chainID string
connDeadline time.Duration
connRetries int
privKey crypto.PrivKeyEd25519
privVal types.PrivValidator
conn net.Conn
}
// NewRemoteSigner returns an instance of RemoteSigner.
func NewRemoteSigner(
logger log.Logger,
chainID, socketAddr string,
privVal types.PrivValidator,
privKey crypto.PrivKeyEd25519,
) *RemoteSigner {
rs := &RemoteSigner{
addr: socketAddr,
chainID: chainID,
connDeadline: time.Second * defaultConnDeadlineSeconds,
connRetries: defaultDialRetries,
privKey: privKey,
privVal: privVal,
}
rs.BaseService = *cmn.NewBaseService(logger, "RemoteSigner", rs)
return rs
}
// OnStart implements cmn.Service.
func (rs *RemoteSigner) OnStart() error {
conn, err := rs.connect()
if err != nil {
err = cmn.ErrorWrap(err, "connect")
rs.Logger.Error("OnStart", "err", err)
return err
}
go rs.handleConnection(conn)
return nil
}
// OnStop implements cmn.Service.
func (rs *RemoteSigner) OnStop() {
if rs.conn == nil {
return
}
if err := rs.conn.Close(); err != nil {
rs.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed"))
}
}
func (rs *RemoteSigner) connect() (net.Conn, error) {
for retries := rs.connRetries; retries > 0; retries-- {
// Don't sleep if it is the first retry.
if retries != rs.connRetries {
time.Sleep(rs.connDeadline)
}
conn, err := cmn.Connect(rs.addr)
if err != nil {
err = cmn.ErrorWrap(err, "connection failed")
rs.Logger.Error(
"connect",
"addr", rs.addr,
"err", err,
)
continue
}
if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil {
err = cmn.ErrorWrap(err, "setting connection timeout failed")
rs.Logger.Error(
"connect",
"err", err,
)
continue
}
conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey)
if err != nil {
err = cmn.ErrorWrap(err, "encrypting connection failed")
rs.Logger.Error(
"connect",
"err", err,
)
continue
}
return conn, nil
}
return nil, ErrDialRetryMax
}
func (rs *RemoteSigner) handleConnection(conn net.Conn) {
for {
if !rs.IsRunning() {
return // Ignore error from listener closing.
}
req, err := readMsg(conn)
if err != nil {
if err != io.EOF {
rs.Logger.Error("handleConnection", "err", err)
}
return
}
var res SocketPVMsg
switch r := req.(type) {
case *PubKeyMsg:
var p crypto.PubKey
p = rs.privVal.GetPubKey()
res = &PubKeyMsg{p}
case *SignVoteMsg:
err = rs.privVal.SignVote(rs.chainID, r.Vote)
res = &SignVoteMsg{r.Vote}
case *SignProposalMsg:
err = rs.privVal.SignProposal(rs.chainID, r.Proposal)
res = &SignProposalMsg{r.Proposal}
case *SignHeartbeatMsg:
err = rs.privVal.SignHeartbeat(rs.chainID, r.Heartbeat)
res = &SignHeartbeatMsg{r.Heartbeat}
default:
err = fmt.Errorf("unknown msg: %v", r)
}
if err != nil {
rs.Logger.Error("handleConnection", "err", err)
return
}
err = writeMsg(conn, res)
if err != nil {
rs.Logger.Error("handleConnection", "err", err)
return
}
}
}
//---------------------------------------------------------
// SocketPVMsg is sent between RemoteSigner and SocketPV.
type SocketPVMsg interface{}
func RegisterSocketPVMsg(cdc *amino.Codec) {
cdc.RegisterInterface((*SocketPVMsg)(nil), nil)
cdc.RegisterConcrete(&PubKeyMsg{}, "tendermint/socketpv/PubKeyMsg", nil)
cdc.RegisterConcrete(&SignVoteMsg{}, "tendermint/socketpv/SignVoteMsg", nil)
cdc.RegisterConcrete(&SignProposalMsg{}, "tendermint/socketpv/SignProposalMsg", nil)
cdc.RegisterConcrete(&SignHeartbeatMsg{}, "tendermint/socketpv/SignHeartbeatMsg", nil)
}
// PubKeyMsg is a PrivValidatorSocket message containing the public key.
type PubKeyMsg struct {
PubKey crypto.PubKey
}
// SignVoteMsg is a PrivValidatorSocket message containing a vote.
type SignVoteMsg struct {
Vote *types.Vote
}
// SignProposalMsg is a PrivValidatorSocket message containing a Proposal.
type SignProposalMsg struct {
Proposal *types.Proposal
}
// SignHeartbeatMsg is a PrivValidatorSocket message containing a Heartbeat.
type SignHeartbeatMsg struct {
Heartbeat *types.Heartbeat
}
func readMsg(r io.Reader) (msg SocketPVMsg, err error) {
const maxSocketPVMsgSize = 1024 * 10
_, err = cdc.UnmarshalBinaryReader(r, &msg, maxSocketPVMsgSize)
if _, ok := err.(timeoutError); ok {
err = cmn.ErrorWrap(ErrConnTimeout, err.Error())
}
return
}
func writeMsg(w io.Writer, msg interface{}) (err error) {
_, err = cdc.MarshalBinaryWriter(w, msg)
if _, ok := err.(timeoutError); ok {
err = cmn.ErrorWrap(ErrConnTimeout, err.Error())
}
return
}

View File

@@ -1,66 +0,0 @@
package privval
import (
"net"
"time"
)
// timeoutError can be used to check if an error returned from the netp package
// was due to a timeout.
type timeoutError interface {
Timeout() bool
}
// tcpTimeoutListener implements net.Listener.
var _ net.Listener = (*tcpTimeoutListener)(nil)
// tcpTimeoutListener wraps a *net.TCPListener to standardise protocol timeouts
// and potentially other tuning parameters.
type tcpTimeoutListener struct {
*net.TCPListener
acceptDeadline time.Duration
connDeadline time.Duration
period time.Duration
}
// newTCPTimeoutListener returns an instance of tcpTimeoutListener.
func newTCPTimeoutListener(
ln net.Listener,
acceptDeadline, connDeadline time.Duration,
period time.Duration,
) tcpTimeoutListener {
return tcpTimeoutListener{
TCPListener: ln.(*net.TCPListener),
acceptDeadline: acceptDeadline,
connDeadline: connDeadline,
period: period,
}
}
// Accept implements net.Listener.
func (ln tcpTimeoutListener) Accept() (net.Conn, error) {
err := ln.SetDeadline(time.Now().Add(ln.acceptDeadline))
if err != nil {
return nil, err
}
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
if err := tc.SetDeadline(time.Now().Add(ln.connDeadline)); err != nil {
return nil, err
}
if err := tc.SetKeepAlive(true); err != nil {
return nil, err
}
if err := tc.SetKeepAlivePeriod(ln.period); err != nil {
return nil, err
}
return tc, nil
}

View File

@@ -1,64 +0,0 @@
package privval
import (
"net"
"testing"
"time"
)
func TestTCPTimeoutListenerAcceptDeadline(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
ln = newTCPTimeoutListener(ln, time.Millisecond, time.Second, time.Second)
_, err = ln.Accept()
opErr, ok := err.(*net.OpError)
if !ok {
t.Fatalf("have %v, want *net.OpError", err)
}
if have, want := opErr.Op, "accept"; have != want {
t.Errorf("have %v, want %v", have, want)
}
}
func TestTCPTimeoutListenerConnDeadline(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
ln = newTCPTimeoutListener(ln, time.Second, time.Millisecond, time.Second)
donec := make(chan struct{})
go func(ln net.Listener) {
defer close(donec)
c, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
time.Sleep(2 * time.Millisecond)
_, err = c.Write([]byte("foo"))
opErr, ok := err.(*net.OpError)
if !ok {
t.Fatalf("have %v, want *net.OpError", err)
}
if have, want := opErr.Op, "write"; have != want {
t.Errorf("have %v, want %v", have, want)
}
}(ln)
_, err = net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
<-donec
}

View File

@@ -1,282 +0,0 @@
package privval
import (
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tendermint/go-crypto"
cmn "github.com/tendermint/tmlibs/common"
"github.com/tendermint/tmlibs/log"
p2pconn "github.com/tendermint/tendermint/p2p/conn"
"github.com/tendermint/tendermint/types"
)
func TestSocketPVAddress(t *testing.T) {
var (
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
)
defer sc.Stop()
defer rs.Stop()
serverAddr := rs.privVal.GetAddress()
clientAddr, err := sc.getAddress()
require.NoError(t, err)
assert.Equal(t, serverAddr, clientAddr)
// TODO(xla): Remove when PrivValidator2 replaced PrivValidator.
assert.Equal(t, serverAddr, sc.GetAddress())
}
func TestSocketPVPubKey(t *testing.T) {
var (
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
)
defer sc.Stop()
defer rs.Stop()
clientKey, err := sc.getPubKey()
require.NoError(t, err)
privKey := rs.privVal.GetPubKey()
assert.Equal(t, privKey, clientKey)
// TODO(xla): Remove when PrivValidator2 replaced PrivValidator.
assert.Equal(t, privKey, sc.GetPubKey())
}
func TestSocketPVProposal(t *testing.T) {
var (
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
ts = time.Now()
privProposal = &types.Proposal{Timestamp: ts}
clientProposal = &types.Proposal{Timestamp: ts}
)
defer sc.Stop()
defer rs.Stop()
require.NoError(t, rs.privVal.SignProposal(chainID, privProposal))
require.NoError(t, sc.SignProposal(chainID, clientProposal))
assert.Equal(t, privProposal.Signature, clientProposal.Signature)
}
func TestSocketPVVote(t *testing.T) {
var (
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
ts = time.Now()
vType = types.VoteTypePrecommit
want = &types.Vote{Timestamp: ts, Type: vType}
have = &types.Vote{Timestamp: ts, Type: vType}
)
defer sc.Stop()
defer rs.Stop()
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
func TestSocketPVHeartbeat(t *testing.T) {
var (
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
want = &types.Heartbeat{}
have = &types.Heartbeat{}
)
defer sc.Stop()
defer rs.Stop()
require.NoError(t, rs.privVal.SignHeartbeat(chainID, want))
require.NoError(t, sc.SignHeartbeat(chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
func TestSocketPVAcceptDeadline(t *testing.T) {
var (
sc = NewSocketPV(
log.TestingLogger(),
"127.0.0.1:0",
crypto.GenPrivKeyEd25519(),
)
)
defer sc.Stop()
SocketPVAcceptDeadline(time.Millisecond)(sc)
assert.Equal(t, sc.Start().(cmn.Error).Cause(), ErrConnWaitTimeout)
}
func TestSocketPVDeadline(t *testing.T) {
var (
addr = testFreeAddr(t)
listenc = make(chan struct{})
sc = NewSocketPV(
log.TestingLogger(),
addr,
crypto.GenPrivKeyEd25519(),
)
)
SocketPVConnDeadline(100 * time.Millisecond)(sc)
SocketPVConnWait(500 * time.Millisecond)(sc)
go func(sc *SocketPV) {
defer close(listenc)
require.NoError(t, sc.Start())
assert.True(t, sc.IsRunning())
}(sc)
for {
conn, err := cmn.Connect(addr)
if err != nil {
continue
}
_, err = p2pconn.MakeSecretConnection(
conn,
crypto.GenPrivKeyEd25519(),
)
if err == nil {
break
}
}
<-listenc
// Sleep to guarantee deadline has been hit.
time.Sleep(20 * time.Microsecond)
_, err := sc.getPubKey()
assert.Equal(t, err.(cmn.Error).Cause(), ErrConnTimeout)
}
func TestSocketPVWait(t *testing.T) {
sc := NewSocketPV(
log.TestingLogger(),
"127.0.0.1:0",
crypto.GenPrivKeyEd25519(),
)
defer sc.Stop()
SocketPVConnWait(time.Millisecond)(sc)
assert.Equal(t, sc.Start().(cmn.Error).Cause(), ErrConnWaitTimeout)
}
func TestRemoteSignerRetry(t *testing.T) {
var (
attemptc = make(chan int)
retries = 2
)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
go func(ln net.Listener, attemptc chan<- int) {
attempts := 0
for {
conn, err := ln.Accept()
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
attempts++
if attempts == retries {
attemptc <- attempts
break
}
}
}(ln, attemptc)
rs := NewRemoteSigner(
log.TestingLogger(),
cmn.RandStr(12),
ln.Addr().String(),
types.NewMockPV(),
crypto.GenPrivKeyEd25519(),
)
defer rs.Stop()
RemoteSignerConnDeadline(time.Millisecond)(rs)
RemoteSignerConnRetries(retries)(rs)
assert.Equal(t, rs.Start().(cmn.Error).Cause(), ErrDialRetryMax)
select {
case attempts := <-attemptc:
assert.Equal(t, retries, attempts)
case <-time.After(100 * time.Millisecond):
t.Error("expected remote to observe connection attempts")
}
}
func testSetupSocketPair(
t *testing.T,
chainID string,
) (*SocketPV, *RemoteSigner) {
var (
addr = testFreeAddr(t)
logger = log.TestingLogger()
privVal = types.NewMockPV()
readyc = make(chan struct{})
rs = NewRemoteSigner(
logger,
chainID,
addr,
privVal,
crypto.GenPrivKeyEd25519(),
)
sc = NewSocketPV(
logger,
addr,
crypto.GenPrivKeyEd25519(),
)
)
go func(sc *SocketPV) {
require.NoError(t, sc.Start())
assert.True(t, sc.IsRunning())
readyc <- struct{}{}
}(sc)
RemoteSignerConnDeadline(time.Millisecond)(rs)
RemoteSignerConnRetries(1e6)(rs)
require.NoError(t, rs.Start())
assert.True(t, rs.IsRunning())
<-readyc
return sc, rs
}
// testFreeAddr claims a free port so we don't block on listener being ready.
func testFreeAddr(t *testing.T) string {
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port)
}

View File

@@ -1,13 +0,0 @@
package privval
import (
"github.com/tendermint/go-amino"
"github.com/tendermint/go-crypto"
)
var cdc = amino.NewCodec()
func init() {
crypto.RegisterAmino(cdc)
RegisterSocketPVMsg(cdc)
}