migration of privval module to gRPC

This commit is contained in:
Marko Baricevic
2020-06-09 17:51:30 +02:00
parent 6961c7e5d1
commit 4d9f573bf3
29 changed files with 383 additions and 1928 deletions

View File

@@ -3,22 +3,23 @@ package main
import (
"flag"
"os"
"time"
"github.com/tendermint/tendermint/crypto/ed25519"
"github.com/tendermint/tendermint/libs/log"
tmnet "github.com/tendermint/tendermint/libs/net"
tmos "github.com/tendermint/tendermint/libs/os"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/tendermint/tendermint/privval"
)
func main() {
var (
addr = flag.String("addr", ":26659", "Address of client to connect to")
addr = flag.String("addr", "tcp://127.0.0.1:26659", "Address of client to connect to")
chainID = flag.String("chain-id", "mychain", "chain id")
privValKeyPath = flag.String("priv-key", "", "priv val key file path")
privValStatePath = flag.String("priv-state", "", "priv val state file path")
withCert = flag.String("cert", "", "absolutepath to server certificate")
withKey = flag.String("key", "", "absolutepath to server key")
logger = log.NewTMLogger(
log.NewSyncWriter(os.Stdout),
@@ -36,21 +37,18 @@ func main() {
pv := privval.LoadFilePV(*privValKeyPath, *privValStatePath)
var dialer privval.SocketDialer
protocol, address := tmnet.ProtocolAndAddress(*addr)
switch protocol {
case "unix":
dialer = privval.DialUnixFn(address)
case "tcp":
connTimeout := 3 * time.Second // TODO
dialer = privval.DialTCPFn(address, connTimeout, ed25519.GenPrivKey())
default:
logger.Error("Unknown protocol", "protocol", protocol)
os.Exit(1)
opts := []grpc.ServerOption{}
if *withCert != "" && *withKey != "" {
creds, err := credentials.NewServerTLSFromFile(*withCert, *withKey)
if err != nil {
logger.Error("Could not load TLS keys:", "err", err)
}
opts = append(opts, grpc.Creds(creds))
} else {
logger.Error("You are using an insecure gRPC connection! Provide a certificate and key to connect securely")
}
sd := privval.NewSignerDialerEndpoint(logger, dialer)
ss := privval.NewSignerServer(sd, *chainID, pv)
ss := privval.NewSignerServer(*addr, *chainID, pv, logger, opts)
err := ss.Start()
if err != nil {

View File

@@ -46,6 +46,8 @@ var (
defaultGenesisJSONPath = filepath.Join(defaultConfigDir, defaultGenesisJSONName)
defaultPrivValKeyPath = filepath.Join(defaultConfigDir, defaultPrivValKeyName)
defaultPrivValStatePath = filepath.Join(defaultDataDir, defaultPrivValStateName)
// if a certificate is not provided the privval connection with a remote signer will be insecure
defaultPrivValClientCertificate = ""
defaultNodeKeyPath = filepath.Join(defaultConfigDir, defaultNodeKeyName)
defaultAddrBookPath = filepath.Join(defaultConfigDir, defaultAddrBookName)
@@ -201,6 +203,10 @@ type BaseConfig struct { //nolint: maligned
// connections from an external PrivValidator process
PrivValidatorListenAddr string `mapstructure:"priv_validator_laddr"`
// Path to client certificate file for secure private validator connection.
// If a remote validator address is provided but no certificate, the connection will be insecure
PrivValidatorClientCertificate string `mapstructure:"priv_validator_client_certificate"`
// A JSON file containing the private key to use for p2p authenticated encryption
NodeKey string `mapstructure:"node_key_file"`
@@ -218,20 +224,21 @@ type BaseConfig struct { //nolint: maligned
// DefaultBaseConfig returns a default base configuration for a Tendermint node
func DefaultBaseConfig() BaseConfig {
return BaseConfig{
Genesis: defaultGenesisJSONPath,
PrivValidatorKey: defaultPrivValKeyPath,
PrivValidatorState: defaultPrivValStatePath,
NodeKey: defaultNodeKeyPath,
Moniker: defaultMoniker,
ProxyApp: "tcp://127.0.0.1:26658",
ABCI: "socket",
LogLevel: DefaultPackageLogLevels(),
LogFormat: LogFormatPlain,
ProfListenAddress: "",
FastSyncMode: true,
FilterPeers: false,
DBBackend: "goleveldb",
DBPath: "data",
Genesis: defaultGenesisJSONPath,
PrivValidatorKey: defaultPrivValKeyPath,
PrivValidatorState: defaultPrivValStatePath,
PrivValidatorClientCertificate: defaultPrivValClientCertificate,
NodeKey: defaultNodeKeyPath,
Moniker: defaultMoniker,
ProxyApp: "tcp://127.0.0.1:26658",
ABCI: "socket",
LogLevel: DefaultPackageLogLevels(),
LogFormat: LogFormatPlain,
ProfListenAddress: "",
FastSyncMode: true,
FilterPeers: false,
DBBackend: "goleveldb",
DBPath: "data",
}
}

View File

@@ -128,6 +128,10 @@ priv_validator_state_file = "{{ js .BaseConfig.PrivValidatorState }}"
# connections from an external PrivValidator process
priv_validator_laddr = "{{ .BaseConfig.PrivValidatorListenAddr }}"
# Path to client certificate file for secure private validator connection.
# If a remote validator address is provided but no certificate, the connection will be insecure
priv_validator_client_certificate = "{{ js .BaseConfig.PrivValidatorClientCertificate }}"
# Path to the JSON file containing the private key to use for node authentication in the p2p protocol
node_key_file = "{{ js .BaseConfig.NodeKey }}"

2
go.mod
View File

@@ -13,12 +13,14 @@ require (
github.com/gogo/protobuf v1.3.1
github.com/golang/protobuf v1.4.0
github.com/gorilla/websocket v1.4.2
github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4
github.com/gtank/merlin v0.1.1
github.com/libp2p/go-buffer-pool v0.0.2
github.com/magiconair/properties v1.8.1
github.com/minio/highwayhash v1.0.0
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.6.0
github.com/prometheus/common v0.9.1
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0
github.com/rs/cors v1.7.0
github.com/snikch/goodman v0.0.0-20171125024755-10e37e294daa

5
go.sum
View File

@@ -28,8 +28,10 @@ github.com/aead/siphash v1.0.1 h1:FwHfE/T45KPKYuuSAKyyvE+oPWcaQ+CUmFW0bPlM+kg=
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4 h1:Hs82Z41s6SdL1CELW+XaDYmOH4hkBN4/N9og/AsOv7E=
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
@@ -199,6 +201,7 @@ github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoA
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4 h1:z53tR0945TRRQO/fLEVPI6SMv7ZflF0TEaTAoU7tOzg=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=
github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY=
@@ -399,6 +402,7 @@ github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
@@ -654,6 +658,7 @@ google.golang.org/protobuf v1.21.0 h1:qdOKuR/EIArgaWNjetjgTzgVTAZ+S/WXVrq9HW9zim
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -8,7 +8,6 @@ import (
"net"
"net/http"
_ "net/http/pprof" // nolint: gosec // securely exposed on separate, optional port
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
@@ -662,7 +661,7 @@ func NewNode(config *cfg.Config,
// external signing process.
if config.PrivValidatorListenAddr != "" {
// FIXME: we should start services inside OnStart
privValidator, err = createAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, logger)
privValidator, err = createAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, config.PrivValidatorClientCertificate, logger)
if err != nil {
return nil, fmt.Errorf("error with private validator socket client: %w", err)
}
@@ -1312,14 +1311,13 @@ func saveGenesisDoc(db dbm.DB, genDoc *types.GenesisDoc) {
func createAndStartPrivValidatorSocketClient(
listenAddr string,
cert string,
logger log.Logger,
) (types.PrivValidator, error) {
pve, err := privval.NewSignerListener(listenAddr, logger)
if err != nil {
return nil, fmt.Errorf("failed to start private validator: %w", err)
}
pvsc, err := privval.NewSignerClient(pve)
dialOptions := ConstructDialOptions(cert)
pvsc, err := privval.NewSignerClient(listenAddr, dialOptions, logger)
if err != nil {
return nil, fmt.Errorf("failed to start private validator: %w", err)
}
@@ -1334,28 +1332,6 @@ func createAndStartPrivValidatorSocketClient(
retries = 50 // 50 * 100ms = 5s total
timeout = 100 * time.Millisecond
)
pvscWithRetries := privval.NewRetrySignerClient(pvsc, retries, timeout)
return pvscWithRetries, nil
}
// splitAndTrimEmpty slices s into all subslices separated by sep and returns a
// slice of the string s with all leading and trailing Unicode code points
// contained in cutset removed. If sep is empty, SplitAndTrim splits after each
// UTF-8 sequence. First part is equivalent to strings.SplitN with a count of
// -1. also filter out empty strings, only return non-empty strings.
func splitAndTrimEmpty(s, sep, cutset string) []string {
if s == "" {
return []string{}
}
spl := strings.Split(s, sep)
nonEmptyStrings := make([]string, 0, len(spl))
for i := 0; i < len(spl); i++ {
element := strings.Trim(spl[i], cutset)
if element != "" {
nonEmptyStrings = append(nonEmptyStrings, element)
}
}
return nonEmptyStrings
return pvsc, nil
}

View File

@@ -73,25 +73,6 @@ func TestNodeStartStop(t *testing.T) {
}
}
func TestSplitAndTrimEmpty(t *testing.T) {
testCases := []struct {
s string
sep string
cutset string
expected []string
}{
{"a,b,c", ",", " ", []string{"a", "b", "c"}},
{" a , b , c ", ",", " ", []string{"a", "b", "c"}},
{" a, b, c ", ",", " ", []string{"a", "b", "c"}},
{" a, ", ",", " ", []string{"a"}},
{" ", ",", " ", []string{}},
}
for _, tc := range testCases {
assert.Equal(t, tc.expected, splitAndTrimEmpty(tc.s, tc.sep, tc.cutset), "%s", tc.s)
}
}
func TestNodeDelayedStart(t *testing.T) {
config := cfg.ResetTestRoot("node_delayed_start_test")
defer os.RemoveAll(config.RootDir)

78
node/utils.go Normal file
View File

@@ -0,0 +1,78 @@
package node
import (
"strings"
"time"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"github.com/prometheus/common/log"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// splitAndTrimEmpty slices s into all subslices separated by sep and returns a
// slice of the string s with all leading and trailing Unicode code points
// contained in cutset removed. If sep is empty, SplitAndTrim splits after each
// UTF-8 sequence. First part is equivalent to strings.SplitN with a count of
// -1. also filter out empty strings, only return non-empty strings.
func splitAndTrimEmpty(s, sep, cutset string) []string {
if s == "" {
return []string{}
}
spl := strings.Split(s, sep)
nonEmptyStrings := make([]string, 0, len(spl))
for i := 0; i < len(spl); i++ {
element := strings.Trim(spl[i], cutset)
if element != "" {
nonEmptyStrings = append(nonEmptyStrings, element)
}
}
return nonEmptyStrings
}
// ConstructDialOptions constructs a list of grpc dial options
func ConstructDialOptions(
withCert string,
extraOpts ...grpc.DialOption,
) []grpc.DialOption {
var transportSecurity grpc.DialOption
if withCert != "" {
creds, err := credentials.NewClientTLSFromFile(withCert, "")
if err != nil {
log.Errorf("Could not get valid credentials: %v", err)
return nil
}
transportSecurity = grpc.WithTransportCredentials(creds)
} else {
transportSecurity = grpc.WithInsecure()
log.Warn("You are using an insecure gRPC connection! Please provide a certificate and key to use a secure connection.")
}
const (
retries = 50 // 50 * 100ms = 5s total
timeout = 100 * time.Millisecond
maxCallRecvMsgSize = 10 << 20 // Default 10Mb
)
opts := []grpc_retry.CallOption{
grpc_retry.WithBackoff(grpc_retry.BackoffExponential(timeout)),
}
dialOpts := []grpc.DialOption{
transportSecurity,
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(maxCallRecvMsgSize),
grpc_retry.WithMax(retries),
),
grpc.WithUnaryInterceptor(
grpc_retry.UnaryClientInterceptor(opts...),
),
}
for _, opt := range extraOpts {
dialOpts = append(dialOpts, opt)
}
return dialOpts
}

26
node/utils_test.go Normal file
View File

@@ -0,0 +1,26 @@
package node
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSplitAndTrimEmpty(t *testing.T) {
testCases := []struct {
s string
sep string
cutset string
expected []string
}{
{"a,b,c", ",", " ", []string{"a", "b", "c"}},
{" a , b , c ", ",", " ", []string{"a", "b", "c"}},
{" a, b, c ", ",", " ", []string{"a", "b", "c"}},
{" a, ", ",", " ", []string{"a"}},
{" ", ",", " ", []string{}},
}
for _, tc := range testCases {
assert.Equal(t, tc.expected, splitAndTrimEmpty(tc.s, tc.sep, tc.cutset), "%s", tc.s)
}
}

111
privval/client.go Normal file
View File

@@ -0,0 +1,111 @@
package privval
import (
"context"
"fmt"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/status"
"github.com/tendermint/tendermint/crypto"
cryptoenc "github.com/tendermint/tendermint/crypto/encoding"
"github.com/tendermint/tendermint/libs/log"
privvalproto "github.com/tendermint/tendermint/proto/privval"
tmproto "github.com/tendermint/tendermint/proto/types"
"github.com/tendermint/tendermint/types"
)
// SignerClient implements PrivValidator.
// Handles remote validator connections that provide signing services
type SignerClient struct {
ctx context.Context
privValidator privvalproto.PrivValidatorAPIClient
conn *grpc.ClientConn
logger log.Logger
}
var _ types.PrivValidator = (*SignerClient)(nil)
// NewSignerClient returns an instance of SignerClient.
// it will start the endpoint (if not already started)
func NewSignerClient(target string,
opts []grpc.DialOption, log log.Logger) (*SignerClient, error) {
if target == "" {
return nil, fmt.Errorf("target connection parameter missing. endpoint %s", target)
}
ctx := context.Background()
conn, err := grpc.DialContext(ctx, target, opts...)
if err != nil {
log.Error("unable to connect to client.", "target", target, "err", err)
}
sc := &SignerClient{
ctx: ctx,
privValidator: privvalproto.NewPrivValidatorAPIClient(conn), // Create the Private Validator Client
logger: log,
}
return sc, nil
}
// Close closes the underlying connection
func (sc *SignerClient) Close() error {
sc.logger.Info("Stopping service")
if sc.conn != nil {
return sc.conn.Close()
}
return nil
}
//--------------------------------------------------------
// Implement PrivValidator
// GetPubKey retrieves a public key from a remote signer
// returns an error if client is not able to provide the key
func (sc *SignerClient) GetPubKey() (crypto.PubKey, error) {
resp, err := sc.privValidator.GetPubKey(sc.ctx, &privvalproto.PubKeyRequest{})
if err != nil {
errStatus, _ := status.FromError(err)
sc.logger.Error("SignerClient::GetPubKey", "err", errStatus.Message())
return nil, fmt.Errorf("send GetPubKey request: %w", errStatus.Err())
}
pk, err := cryptoenc.PubKeyFromProto(*resp.PubKey)
if err != nil {
return nil, err
}
return pk, nil
}
// SignVote requests a remote signer to sign a vote
func (sc *SignerClient) SignVote(chainID string, vote *tmproto.Vote) error {
resp, err := sc.privValidator.SignVote(sc.ctx, &privvalproto.SignVoteRequest{ChainId: chainID, Vote: vote})
if err != nil {
errStatus, _ := status.FromError(err)
sc.logger.Error("Client SignVote", "err", errStatus.Message())
return fmt.Errorf("send SignVote request: %w", errStatus.Err())
}
*vote = *resp.Vote
return nil
}
// SignProposal requests a remote signer to sign a proposal
func (sc *SignerClient) SignProposal(chainID string, proposal *tmproto.Proposal) error {
resp, err := sc.privValidator.SignProposal(
sc.ctx, &privvalproto.SignProposalRequest{ChainId: chainID, Proposal: proposal})
if err != nil {
errStatus, _ := status.FromError(err)
sc.logger.Error("SignerClient::SignProposal", "err", errStatus.Message())
return fmt.Errorf("send SignProposal request: %w", errStatus.Err())
}
*proposal = *resp.Proposal
return nil
}

View File

@@ -1,14 +0,0 @@
package privval
import (
amino "github.com/tendermint/go-amino"
cryptoamino "github.com/tendermint/tendermint/crypto/encoding/amino"
)
var cdc = amino.NewCodec()
func init() {
cryptoamino.RegisterAmino(cdc)
RegisterRemoteSignerMsg(cdc)
}

View File

@@ -1,65 +0,0 @@
package privval
import (
amino "github.com/tendermint/go-amino"
"github.com/tendermint/tendermint/crypto"
"github.com/tendermint/tendermint/types"
)
// SignerMessage is sent between Signer Clients and Servers.
type SignerMessage interface{}
func RegisterRemoteSignerMsg(cdc *amino.Codec) {
cdc.RegisterInterface((*SignerMessage)(nil), nil)
cdc.RegisterConcrete(&PubKeyRequest{}, "tendermint/remotesigner/PubKeyRequest", nil)
cdc.RegisterConcrete(&PubKeyResponse{}, "tendermint/remotesigner/PubKeyResponse", nil)
cdc.RegisterConcrete(&SignVoteRequest{}, "tendermint/remotesigner/SignVoteRequest", nil)
cdc.RegisterConcrete(&SignedVoteResponse{}, "tendermint/remotesigner/SignedVoteResponse", nil)
cdc.RegisterConcrete(&SignProposalRequest{}, "tendermint/remotesigner/SignProposalRequest", nil)
cdc.RegisterConcrete(&SignedProposalResponse{}, "tendermint/remotesigner/SignedProposalResponse", nil)
cdc.RegisterConcrete(&PingRequest{}, "tendermint/remotesigner/PingRequest", nil)
cdc.RegisterConcrete(&PingResponse{}, "tendermint/remotesigner/PingResponse", nil)
}
// TODO: Add ChainIDRequest
// PubKeyRequest requests the consensus public key from the remote signer.
type PubKeyRequest struct{}
// PubKeyResponse is a response message containing the public key.
type PubKeyResponse struct {
PubKey crypto.PubKey
Error *RemoteSignerError
}
// SignVoteRequest is a request to sign a vote
type SignVoteRequest struct {
Vote *types.Vote
}
// SignedVoteResponse is a response containing a signed vote or an error
type SignedVoteResponse struct {
Vote *types.Vote
Error *RemoteSignerError
}
// SignProposalRequest is a request to sign a proposal
type SignProposalRequest struct {
Proposal *types.Proposal
}
// SignedProposalResponse is response containing a signed proposal or an error
type SignedProposalResponse struct {
Proposal *types.Proposal
Error *RemoteSignerError
}
// PingRequest is a request to confirm that the connection is alive.
type PingRequest struct {
}
// PingResponse is a response to confirm that the connection is alive.
type PingResponse struct {
}

View File

@@ -1,83 +0,0 @@
package privval
import (
"fmt"
"time"
"github.com/tendermint/tendermint/crypto"
"github.com/tendermint/tendermint/types"
)
// RetrySignerClient wraps SignerClient adding retry for each operation (except
// Ping) w/ a timeout.
type RetrySignerClient struct {
next *SignerClient
retries int
timeout time.Duration
}
// NewRetrySignerClient returns RetrySignerClient. If +retries+ is 0, the
// client will be retrying each operation indefinitely.
func NewRetrySignerClient(sc *SignerClient, retries int, timeout time.Duration) *RetrySignerClient {
return &RetrySignerClient{sc, retries, timeout}
}
var _ types.PrivValidator = (*RetrySignerClient)(nil)
func (sc *RetrySignerClient) Close() error {
return sc.next.Close()
}
func (sc *RetrySignerClient) IsConnected() bool {
return sc.next.IsConnected()
}
func (sc *RetrySignerClient) WaitForConnection(maxWait time.Duration) error {
return sc.next.WaitForConnection(maxWait)
}
//--------------------------------------------------------
// Implement PrivValidator
func (sc *RetrySignerClient) Ping() error {
return sc.next.Ping()
}
func (sc *RetrySignerClient) GetPubKey() (crypto.PubKey, error) {
var (
pk crypto.PubKey
err error
)
for i := 0; i < sc.retries || sc.retries == 0; i++ {
pk, err = sc.next.GetPubKey()
if err == nil {
return pk, nil
}
time.Sleep(sc.timeout)
}
return nil, fmt.Errorf("exhausted all attempts to get pubkey: %w", err)
}
func (sc *RetrySignerClient) SignVote(chainID string, vote *types.Vote) error {
var err error
for i := 0; i < sc.retries || sc.retries == 0; i++ {
err = sc.next.SignVote(chainID, vote)
if err == nil {
return nil
}
time.Sleep(sc.timeout)
}
return fmt.Errorf("exhausted all attempts to sign vote: %w", err)
}
func (sc *RetrySignerClient) SignProposal(chainID string, proposal *types.Proposal) error {
var err error
for i := 0; i < sc.retries || sc.retries == 0; i++ {
err = sc.next.SignProposal(chainID, proposal)
if err == nil {
return nil
}
time.Sleep(sc.timeout)
}
return fmt.Errorf("exhausted all attempts to sign proposal: %w", err)
}

114
privval/server.go Normal file
View File

@@ -0,0 +1,114 @@
package privval
import (
context "context"
"net"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/tendermint/tendermint/crypto"
cryptoenc "github.com/tendermint/tendermint/crypto/encoding"
"github.com/tendermint/tendermint/libs/log"
tmnet "github.com/tendermint/tendermint/libs/net"
"github.com/tendermint/tendermint/libs/service"
privvalproto "github.com/tendermint/tendermint/proto/privval"
"github.com/tendermint/tendermint/types"
)
type SignerServer struct {
service.BaseService
Logger log.Logger
target string
ChainID string
PrivVal types.PrivValidator
Opts []grpc.ServerOption
Srv *grpc.Server
}
func NewSignerServer(target string, chainID string, privVal types.PrivValidator, log log.Logger, opts []grpc.ServerOption) *SignerServer {
return &SignerServer{
Logger: log,
target: target,
ChainID: chainID,
Opts: opts,
PrivVal: privVal,
}
}
// OnStart implements service.Service.
func (ss *SignerServer) OnStart() error {
protocol, address := tmnet.ProtocolAndAddress(ss.target)
lis, err := net.Listen(protocol, address)
if err != nil {
ss.Logger.Error("failed to listen: ", "err", err)
}
s := grpc.NewServer(ss.Opts...)
ss.Srv = s
privvalproto.RegisterPrivValidatorAPIServer(ss.Srv, &SignerServer{})
if err := ss.Srv.Serve(lis); err != nil {
ss.Logger.Error("failed to serve:", "err", err)
}
return nil
}
// OnStop implements service.Service.
func (ss *SignerServer) OnStop() {
ss.Logger.Debug("SignerServer: OnStop calling Close")
ss.Srv.GracefulStop()
}
var _ privvalproto.PrivValidatorAPIServer = (*SignerServer)(nil)
// PubKey receives a request for the pubkey
// returns the pubkey on success and error on failure
func (ss *SignerServer) GetPubKey(ctx context.Context, req *privvalproto.PubKeyRequest) (
*privvalproto.PubKeyResponse, error) {
var pubKey crypto.PubKey
pubKey, err := ss.PrivVal.GetPubKey()
if err != nil {
return nil, status.Errorf(codes.NotFound, "error getting pubkey: %v", err)
}
pk, err := cryptoenc.PubKeyToProto(pubKey)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "error transistioning pubkey to proto: %v", err)
}
return &privvalproto.PubKeyResponse{PubKey: &pk}, nil
}
// SignVote receives a vote sign requests, attempts to sign it
// returns SignedVoteResponse on success and error on failure
func (ss *SignerServer) SignVote(ctx context.Context, req *privvalproto.SignVoteRequest) (
*privvalproto.SignedVoteResponse, error) {
vote := req.Vote
err := ss.PrivVal.SignVote(req.ChainId, vote)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "error signing vote: %v", err)
}
return &privvalproto.SignedVoteResponse{Vote: vote}, nil
}
// SignProposal receives a proposal sign requests, attempts to sign it
// returns SignedProposalResponse on success and error on failure
func (ss *SignerServer) SignProposal(ctx context.Context, req *privvalproto.SignProposalRequest) (
*privvalproto.SignedProposalResponse, error) {
proposal := req.Proposal
err := ss.PrivVal.SignProposal(req.ChainId, proposal)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "error signing proposal: %v", err)
}
return &privvalproto.SignedProposalResponse{Proposal: proposal}, nil
}

1
privval/server_test.go Normal file
View File

@@ -0,0 +1 @@
package privval

View File

@@ -1,131 +0,0 @@
package privval
import (
"fmt"
"time"
"github.com/tendermint/tendermint/crypto"
"github.com/tendermint/tendermint/types"
)
// SignerClient implements PrivValidator.
// Handles remote validator connections that provide signing services
type SignerClient struct {
endpoint *SignerListenerEndpoint
}
var _ types.PrivValidator = (*SignerClient)(nil)
// NewSignerClient returns an instance of SignerClient.
// it will start the endpoint (if not already started)
func NewSignerClient(endpoint *SignerListenerEndpoint) (*SignerClient, error) {
if !endpoint.IsRunning() {
if err := endpoint.Start(); err != nil {
return nil, fmt.Errorf("failed to start listener endpoint: %w", err)
}
}
return &SignerClient{endpoint: endpoint}, nil
}
// Close closes the underlying connection
func (sc *SignerClient) Close() error {
return sc.endpoint.Close()
}
// IsConnected indicates with the signer is connected to a remote signing service
func (sc *SignerClient) IsConnected() bool {
return sc.endpoint.IsConnected()
}
// WaitForConnection waits maxWait for a connection or returns a timeout error
func (sc *SignerClient) WaitForConnection(maxWait time.Duration) error {
return sc.endpoint.WaitForConnection(maxWait)
}
//--------------------------------------------------------
// Implement PrivValidator
// Ping sends a ping request to the remote signer
func (sc *SignerClient) Ping() error {
response, err := sc.endpoint.SendRequest(&PingRequest{})
if err != nil {
sc.endpoint.Logger.Error("SignerClient::Ping", "err", err)
return nil
}
_, ok := response.(*PingResponse)
if !ok {
sc.endpoint.Logger.Error("SignerClient::Ping", "err", "response != PingResponse")
return err
}
return nil
}
// GetPubKey retrieves a public key from a remote signer
// returns an error if client is not able to provide the key
func (sc *SignerClient) GetPubKey() (crypto.PubKey, error) {
response, err := sc.endpoint.SendRequest(&PubKeyRequest{})
if err != nil {
sc.endpoint.Logger.Error("SignerClient::GetPubKey", "err", err)
return nil, fmt.Errorf("send: %w", err)
}
pubKeyResp, ok := response.(*PubKeyResponse)
if !ok {
sc.endpoint.Logger.Error("SignerClient::GetPubKey", "err", "response != PubKeyResponse")
return nil, fmt.Errorf("unexpected response type %T", response)
}
if pubKeyResp.Error != nil {
sc.endpoint.Logger.Error("failed to get private validator's public key", "err", pubKeyResp.Error)
return nil, fmt.Errorf("remote error: %w", pubKeyResp.Error)
}
return pubKeyResp.PubKey, nil
}
// SignVote requests a remote signer to sign a vote
func (sc *SignerClient) SignVote(chainID string, vote *types.Vote) error {
response, err := sc.endpoint.SendRequest(&SignVoteRequest{Vote: vote})
if err != nil {
sc.endpoint.Logger.Error("SignerClient::SignVote", "err", err)
return err
}
resp, ok := response.(*SignedVoteResponse)
if !ok {
sc.endpoint.Logger.Error("SignerClient::GetPubKey", "err", "response != SignedVoteResponse")
return ErrUnexpectedResponse
}
if resp.Error != nil {
return resp.Error
}
*vote = *resp.Vote
return nil
}
// SignProposal requests a remote signer to sign a proposal
func (sc *SignerClient) SignProposal(chainID string, proposal *types.Proposal) error {
response, err := sc.endpoint.SendRequest(&SignProposalRequest{Proposal: proposal})
if err != nil {
sc.endpoint.Logger.Error("SignerClient::SignProposal", "err", err)
return err
}
resp, ok := response.(*SignedProposalResponse)
if !ok {
sc.endpoint.Logger.Error("SignerClient::SignProposal", "err", "response != SignedProposalResponse")
return ErrUnexpectedResponse
}
if resp.Error != nil {
return resp.Error
}
*proposal = *resp.Proposal
return nil
}

View File

@@ -1,263 +0,0 @@
package privval
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tmrand "github.com/tendermint/tendermint/libs/rand"
tmproto "github.com/tendermint/tendermint/proto/types"
"github.com/tendermint/tendermint/types"
)
type signerTestCase struct {
chainID string
mockPV types.PrivValidator
signerClient *SignerClient
signerServer *SignerServer
}
func getSignerTestCases(t *testing.T) []signerTestCase {
testCases := make([]signerTestCase, 0)
// Get test cases for each possible dialer (DialTCP / DialUnix / etc)
for _, dtc := range getDialerTestCases(t) {
chainID := tmrand.Str(12)
mockPV := types.NewMockPV()
// get a pair of signer listener, signer dialer endpoints
sl, sd := getMockEndpoints(t, dtc.addr, dtc.dialer)
sc, err := NewSignerClient(sl)
require.NoError(t, err)
ss := NewSignerServer(sd, chainID, mockPV)
err = ss.Start()
require.NoError(t, err)
tc := signerTestCase{
chainID: chainID,
mockPV: mockPV,
signerClient: sc,
signerServer: ss,
}
testCases = append(testCases, tc)
}
return testCases
}
func TestSignerClose(t *testing.T) {
for _, tc := range getSignerTestCases(t) {
err := tc.signerClient.Close()
assert.NoError(t, err)
err = tc.signerServer.Stop()
assert.NoError(t, err)
}
}
func TestSignerPing(t *testing.T) {
for _, tc := range getSignerTestCases(t) {
defer tc.signerServer.Stop()
defer tc.signerClient.Close()
err := tc.signerClient.Ping()
assert.NoError(t, err)
}
}
func TestSignerGetPubKey(t *testing.T) {
for _, tc := range getSignerTestCases(t) {
defer tc.signerServer.Stop()
defer tc.signerClient.Close()
pubKey, err := tc.signerClient.GetPubKey()
require.NoError(t, err)
expectedPubKey, err := tc.mockPV.GetPubKey()
require.NoError(t, err)
assert.Equal(t, expectedPubKey, pubKey)
pubKey, err = tc.signerClient.GetPubKey()
require.NoError(t, err)
expectedpk, err := tc.mockPV.GetPubKey()
require.NoError(t, err)
expectedAddr := expectedpk.Address()
assert.Equal(t, expectedAddr, pubKey.Address())
}
}
func TestSignerProposal(t *testing.T) {
for _, tc := range getSignerTestCases(t) {
ts := time.Now()
want := &types.Proposal{Timestamp: ts}
have := &types.Proposal{Timestamp: ts}
defer tc.signerServer.Stop()
defer tc.signerClient.Close()
require.NoError(t, tc.mockPV.SignProposal(tc.chainID, want))
require.NoError(t, tc.signerClient.SignProposal(tc.chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
}
func TestSignerVote(t *testing.T) {
for _, tc := range getSignerTestCases(t) {
ts := time.Now()
want := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
have := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
defer tc.signerServer.Stop()
defer tc.signerClient.Close()
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want))
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
}
func TestSignerVoteResetDeadline(t *testing.T) {
for _, tc := range getSignerTestCases(t) {
ts := time.Now()
want := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
have := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
defer tc.signerServer.Stop()
defer tc.signerClient.Close()
time.Sleep(testTimeoutReadWrite2o3)
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want))
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have))
assert.Equal(t, want.Signature, have.Signature)
// TODO(jleni): Clarify what is actually being tested
// This would exceed the deadline if it was not extended by the previous message
time.Sleep(testTimeoutReadWrite2o3)
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want))
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
}
func TestSignerVoteKeepAlive(t *testing.T) {
for _, tc := range getSignerTestCases(t) {
ts := time.Now()
want := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
have := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
defer tc.signerServer.Stop()
defer tc.signerClient.Close()
// Check that even if the client does not request a
// signature for a long time. The service is still available
// in this particular case, we use the dialer logger to ensure that
// test messages are properly interleaved in the test logs
tc.signerServer.Logger.Debug("TEST: Forced Wait -------------------------------------------------")
time.Sleep(testTimeoutReadWrite * 3)
tc.signerServer.Logger.Debug("TEST: Forced Wait DONE---------------------------------------------")
require.NoError(t, tc.mockPV.SignVote(tc.chainID, want))
require.NoError(t, tc.signerClient.SignVote(tc.chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
}
func TestSignerSignProposalErrors(t *testing.T) {
for _, tc := range getSignerTestCases(t) {
// Replace service with a mock that always fails
tc.signerServer.privVal = types.NewErroringMockPV()
tc.mockPV = types.NewErroringMockPV()
defer tc.signerServer.Stop()
defer tc.signerClient.Close()
ts := time.Now()
proposal := &types.Proposal{Timestamp: ts}
err := tc.signerClient.SignProposal(tc.chainID, proposal)
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
err = tc.mockPV.SignProposal(tc.chainID, proposal)
require.Error(t, err)
err = tc.signerClient.SignProposal(tc.chainID, proposal)
require.Error(t, err)
}
}
func TestSignerSignVoteErrors(t *testing.T) {
for _, tc := range getSignerTestCases(t) {
ts := time.Now()
vote := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
// Replace signer service privval with one that always fails
tc.signerServer.privVal = types.NewErroringMockPV()
tc.mockPV = types.NewErroringMockPV()
defer tc.signerServer.Stop()
defer tc.signerClient.Close()
err := tc.signerClient.SignVote(tc.chainID, vote)
require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error())
err = tc.mockPV.SignVote(tc.chainID, vote)
require.Error(t, err)
err = tc.signerClient.SignVote(tc.chainID, vote)
require.Error(t, err)
}
}
func brokenHandler(privVal types.PrivValidator, request SignerMessage, chainID string) (SignerMessage, error) {
var res SignerMessage
var err error
switch r := request.(type) {
// This is broken and will answer most requests with a pubkey response
case *PubKeyRequest:
res = &PubKeyResponse{nil, nil}
case *SignVoteRequest:
res = &PubKeyResponse{nil, nil}
case *SignProposalRequest:
res = &PubKeyResponse{nil, nil}
case *PingRequest:
err, res = nil, &PingResponse{}
default:
err = fmt.Errorf("unknown msg: %v", r)
}
return res, err
}
func TestSignerUnexpectedResponse(t *testing.T) {
for _, tc := range getSignerTestCases(t) {
tc.signerServer.privVal = types.NewMockPV()
tc.mockPV = types.NewMockPV()
tc.signerServer.SetRequestHandler(brokenHandler)
defer tc.signerServer.Stop()
defer tc.signerClient.Close()
ts := time.Now()
want := &types.Vote{Timestamp: ts, Type: tmproto.PrecommitType}
e := tc.signerClient.SignVote(tc.chainID, want)
assert.EqualError(t, e, "received unexpected response")
}
}

View File

@@ -1,89 +0,0 @@
package privval
import (
"time"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/libs/service"
)
const (
defaultMaxDialRetries = 10
defaultRetryWaitMilliseconds = 100
)
// SignerServiceEndpointOption sets an optional parameter on the SignerDialerEndpoint.
type SignerServiceEndpointOption func(*SignerDialerEndpoint)
// SignerDialerEndpointTimeoutReadWrite sets the read and write timeout for connections
// from external signing processes.
func SignerDialerEndpointTimeoutReadWrite(timeout time.Duration) SignerServiceEndpointOption {
return func(ss *SignerDialerEndpoint) { ss.timeoutReadWrite = timeout }
}
// SignerDialerEndpointConnRetries sets the amount of attempted retries to acceptNewConnection.
func SignerDialerEndpointConnRetries(retries int) SignerServiceEndpointOption {
return func(ss *SignerDialerEndpoint) { ss.maxConnRetries = retries }
}
// SignerDialerEndpointRetryWaitInterval sets the retry wait interval to a custom value
func SignerDialerEndpointRetryWaitInterval(interval time.Duration) SignerServiceEndpointOption {
return func(ss *SignerDialerEndpoint) { ss.retryWait = interval }
}
// SignerDialerEndpoint dials using its dialer and responds to any
// signature requests using its privVal.
type SignerDialerEndpoint struct {
signerEndpoint
dialer SocketDialer
retryWait time.Duration
maxConnRetries int
}
// NewSignerDialerEndpoint returns a SignerDialerEndpoint that will dial using the given
// dialer and respond to any signature requests over the connection
// using the given privVal.
func NewSignerDialerEndpoint(
logger log.Logger,
dialer SocketDialer,
) *SignerDialerEndpoint {
sd := &SignerDialerEndpoint{
dialer: dialer,
retryWait: defaultRetryWaitMilliseconds * time.Millisecond,
maxConnRetries: defaultMaxDialRetries,
}
sd.BaseService = *service.NewBaseService(logger, "SignerDialerEndpoint", sd)
sd.signerEndpoint.timeoutReadWrite = defaultTimeoutReadWriteSeconds * time.Second
return sd
}
func (sd *SignerDialerEndpoint) ensureConnection() error {
if sd.IsConnected() {
return nil
}
retries := 0
for retries < sd.maxConnRetries {
conn, err := sd.dialer()
if err != nil {
retries++
sd.Logger.Debug("SignerDialer: Reconnection failed", "retries", retries, "max", sd.maxConnRetries, "err", err)
// Wait between retries
time.Sleep(sd.retryWait)
} else {
sd.SetConnection(conn)
sd.Logger.Debug("SignerDialer: Connection Ready")
return nil
}
}
sd.Logger.Debug("SignerDialer: Max retries exceeded", "retries", retries, "max", sd.maxConnRetries)
return ErrNoConnection
}

View File

@@ -1,152 +0,0 @@
package privval
import (
"fmt"
"net"
"sync"
"time"
"github.com/tendermint/tendermint/libs/service"
)
const (
defaultTimeoutReadWriteSeconds = 3
)
type signerEndpoint struct {
service.BaseService
connMtx sync.Mutex
conn net.Conn
timeoutReadWrite time.Duration
}
// Close closes the underlying net.Conn.
func (se *signerEndpoint) Close() error {
se.DropConnection()
return nil
}
// IsConnected indicates if there is an active connection
func (se *signerEndpoint) IsConnected() bool {
se.connMtx.Lock()
defer se.connMtx.Unlock()
return se.isConnected()
}
// TryGetConnection retrieves a connection if it is already available
func (se *signerEndpoint) GetAvailableConnection(connectionAvailableCh chan net.Conn) bool {
se.connMtx.Lock()
defer se.connMtx.Unlock()
// Is there a connection ready?
select {
case se.conn = <-connectionAvailableCh:
return true
default:
}
return false
}
// TryGetConnection retrieves a connection if it is already available
func (se *signerEndpoint) WaitConnection(connectionAvailableCh chan net.Conn, maxWait time.Duration) error {
se.connMtx.Lock()
defer se.connMtx.Unlock()
select {
case se.conn = <-connectionAvailableCh:
case <-time.After(maxWait):
return ErrConnectionTimeout
}
return nil
}
// SetConnection replaces the current connection object
func (se *signerEndpoint) SetConnection(newConnection net.Conn) {
se.connMtx.Lock()
defer se.connMtx.Unlock()
se.conn = newConnection
}
// IsConnected indicates if there is an active connection
func (se *signerEndpoint) DropConnection() {
se.connMtx.Lock()
defer se.connMtx.Unlock()
se.dropConnection()
}
// ReadMessage reads a message from the endpoint
func (se *signerEndpoint) ReadMessage() (msg SignerMessage, err error) {
se.connMtx.Lock()
defer se.connMtx.Unlock()
if !se.isConnected() {
return nil, fmt.Errorf("endpoint is not connected")
}
// Reset read deadline
deadline := time.Now().Add(se.timeoutReadWrite)
err = se.conn.SetReadDeadline(deadline)
if err != nil {
return
}
const maxRemoteSignerMsgSize = 1024 * 10
_, err = cdc.UnmarshalBinaryLengthPrefixedReader(se.conn, &msg, maxRemoteSignerMsgSize)
if _, ok := err.(timeoutError); ok {
if err != nil {
err = fmt.Errorf("%v: %w", err, ErrReadTimeout)
} else {
err = fmt.Errorf("empty error: %w", ErrReadTimeout)
}
se.Logger.Debug("Dropping [read]", "obj", se)
se.dropConnection()
}
return
}
// WriteMessage writes a message from the endpoint
func (se *signerEndpoint) WriteMessage(msg SignerMessage) (err error) {
se.connMtx.Lock()
defer se.connMtx.Unlock()
if !se.isConnected() {
return fmt.Errorf("endpoint is not connected: %w", ErrNoConnection)
}
// Reset read deadline
deadline := time.Now().Add(se.timeoutReadWrite)
err = se.conn.SetWriteDeadline(deadline)
if err != nil {
return
}
_, err = cdc.MarshalBinaryLengthPrefixedWriter(se.conn, msg)
if _, ok := err.(timeoutError); ok {
if err != nil {
err = fmt.Errorf("%v: %w", err, ErrWriteTimeout)
} else {
err = fmt.Errorf("empty error: %w", ErrWriteTimeout)
}
se.dropConnection()
}
return
}
func (se *signerEndpoint) isConnected() bool {
return se.conn != nil
}
func (se *signerEndpoint) dropConnection() {
if se.conn != nil {
if err := se.conn.Close(); err != nil {
se.Logger.Error("signerEndpoint::dropConnection", "err", err)
}
se.conn = nil
}
}

View File

@@ -1,198 +0,0 @@
package privval
import (
"fmt"
"net"
"sync"
"time"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/libs/service"
)
// SignerValidatorEndpointOption sets an optional parameter on the SocketVal.
type SignerValidatorEndpointOption func(*SignerListenerEndpoint)
// SignerListenerEndpoint listens for an external process to dial in
// and keeps the connection alive by dropping and reconnecting
type SignerListenerEndpoint struct {
signerEndpoint
listener net.Listener
connectRequestCh chan struct{}
connectionAvailableCh chan net.Conn
timeoutAccept time.Duration
pingTimer *time.Ticker
instanceMtx sync.Mutex // Ensures instance public methods access, i.e. SendRequest
}
// NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint.
func NewSignerListenerEndpoint(
logger log.Logger,
listener net.Listener,
) *SignerListenerEndpoint {
sc := &SignerListenerEndpoint{
listener: listener,
timeoutAccept: defaultTimeoutAcceptSeconds * time.Second,
}
sc.BaseService = *service.NewBaseService(logger, "SignerListenerEndpoint", sc)
sc.signerEndpoint.timeoutReadWrite = defaultTimeoutReadWriteSeconds * time.Second
return sc
}
// OnStart implements service.Service.
func (sl *SignerListenerEndpoint) OnStart() error {
sl.connectRequestCh = make(chan struct{})
sl.connectionAvailableCh = make(chan net.Conn)
sl.pingTimer = time.NewTicker(defaultPingPeriodMilliseconds * time.Millisecond)
go sl.serviceLoop()
go sl.pingLoop()
sl.connectRequestCh <- struct{}{}
return nil
}
// OnStop implements service.Service
func (sl *SignerListenerEndpoint) OnStop() {
sl.instanceMtx.Lock()
defer sl.instanceMtx.Unlock()
_ = sl.Close()
// Stop listening
if sl.listener != nil {
if err := sl.listener.Close(); err != nil {
sl.Logger.Error("Closing Listener", "err", err)
sl.listener = nil
}
}
sl.pingTimer.Stop()
}
// WaitForConnection waits maxWait for a connection or returns a timeout error
func (sl *SignerListenerEndpoint) WaitForConnection(maxWait time.Duration) error {
sl.instanceMtx.Lock()
defer sl.instanceMtx.Unlock()
return sl.ensureConnection(maxWait)
}
// SendRequest ensures there is a connection, sends a request and waits for a response
func (sl *SignerListenerEndpoint) SendRequest(request SignerMessage) (SignerMessage, error) {
sl.instanceMtx.Lock()
defer sl.instanceMtx.Unlock()
err := sl.ensureConnection(sl.timeoutAccept)
if err != nil {
return nil, err
}
err = sl.WriteMessage(request)
if err != nil {
return nil, err
}
res, err := sl.ReadMessage()
if err != nil {
return nil, err
}
return res, nil
}
func (sl *SignerListenerEndpoint) ensureConnection(maxWait time.Duration) error {
if sl.IsConnected() {
return nil
}
// Is there a connection ready? then use it
if sl.GetAvailableConnection(sl.connectionAvailableCh) {
return nil
}
// block until connected or timeout
sl.triggerConnect()
err := sl.WaitConnection(sl.connectionAvailableCh, maxWait)
if err != nil {
return err
}
return nil
}
func (sl *SignerListenerEndpoint) acceptNewConnection() (net.Conn, error) {
if !sl.IsRunning() || sl.listener == nil {
return nil, fmt.Errorf("endpoint is closing")
}
// wait for a new conn
sl.Logger.Info("SignerListener: Listening for new connection")
conn, err := sl.listener.Accept()
if err != nil {
return nil, err
}
return conn, nil
}
func (sl *SignerListenerEndpoint) triggerConnect() {
select {
case sl.connectRequestCh <- struct{}{}:
default:
}
}
func (sl *SignerListenerEndpoint) triggerReconnect() {
sl.DropConnection()
sl.triggerConnect()
}
func (sl *SignerListenerEndpoint) serviceLoop() {
for {
select {
case <-sl.connectRequestCh:
{
conn, err := sl.acceptNewConnection()
if err == nil {
sl.Logger.Info("SignerListener: Connected")
// We have a good connection, wait for someone that needs one otherwise cancellation
select {
case sl.connectionAvailableCh <- conn:
case <-sl.Quit():
return
}
}
select {
case sl.connectRequestCh <- struct{}{}:
default:
}
}
case <-sl.Quit():
return
}
}
}
func (sl *SignerListenerEndpoint) pingLoop() {
for {
select {
case <-sl.pingTimer.C:
{
_, err := sl.SendRequest(&PingRequest{})
if err != nil {
sl.Logger.Error("SignerListener: Ping timeout")
sl.triggerReconnect()
}
}
case <-sl.Quit():
return
}
}
}

View File

@@ -1,199 +0,0 @@
package privval
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/crypto/ed25519"
"github.com/tendermint/tendermint/libs/log"
tmnet "github.com/tendermint/tendermint/libs/net"
tmrand "github.com/tendermint/tendermint/libs/rand"
"github.com/tendermint/tendermint/types"
)
var (
testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second
testTimeoutReadWrite = 100 * time.Millisecond
testTimeoutReadWrite2o3 = 60 * time.Millisecond // 2/3 of the other one
)
type dialerTestCase struct {
addr string
dialer SocketDialer
}
// TestSignerRemoteRetryTCPOnly will test connection retry attempts over TCP. We
// don't need this for Unix sockets because the OS instantly knows the state of
// both ends of the socket connection. This basically causes the
// SignerDialerEndpoint.dialer() call inside SignerDialerEndpoint.acceptNewConnection() to return
// successfully immediately, putting an instant stop to any retry attempts.
func TestSignerRemoteRetryTCPOnly(t *testing.T) {
var (
attemptCh = make(chan int)
retries = 10
)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
// Continuously Accept connection and close {attempts} times
go func(ln net.Listener, attemptCh chan<- int) {
attempts := 0
for {
conn, err := ln.Accept()
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
attempts++
if attempts == retries {
attemptCh <- attempts
break
}
}
}(ln, attemptCh)
dialerEndpoint := NewSignerDialerEndpoint(
log.TestingLogger(),
DialTCPFn(ln.Addr().String(), testTimeoutReadWrite, ed25519.GenPrivKey()),
)
SignerDialerEndpointTimeoutReadWrite(time.Millisecond)(dialerEndpoint)
SignerDialerEndpointConnRetries(retries)(dialerEndpoint)
chainID := tmrand.Str(12)
mockPV := types.NewMockPV()
signerServer := NewSignerServer(dialerEndpoint, chainID, mockPV)
err = signerServer.Start()
require.NoError(t, err)
defer signerServer.Stop()
select {
case attempts := <-attemptCh:
assert.Equal(t, retries, attempts)
case <-time.After(1500 * time.Millisecond):
t.Error("expected remote to observe connection attempts")
}
}
func TestRetryConnToRemoteSigner(t *testing.T) {
for _, tc := range getDialerTestCases(t) {
var (
logger = log.TestingLogger()
chainID = tmrand.Str(12)
mockPV = types.NewMockPV()
endpointIsOpenCh = make(chan struct{})
thisConnTimeout = testTimeoutReadWrite
listenerEndpoint = newSignerListenerEndpoint(logger, tc.addr, thisConnTimeout)
)
dialerEndpoint := NewSignerDialerEndpoint(
logger,
tc.dialer,
)
SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint)
SignerDialerEndpointConnRetries(10)(dialerEndpoint)
signerServer := NewSignerServer(dialerEndpoint, chainID, mockPV)
startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh)
defer listenerEndpoint.Stop()
require.NoError(t, signerServer.Start())
assert.True(t, signerServer.IsRunning())
<-endpointIsOpenCh
signerServer.Stop()
dialerEndpoint2 := NewSignerDialerEndpoint(
logger,
tc.dialer,
)
signerServer2 := NewSignerServer(dialerEndpoint2, chainID, mockPV)
// let some pings pass
require.NoError(t, signerServer2.Start())
assert.True(t, signerServer2.IsRunning())
defer signerServer2.Stop()
// give the client some time to re-establish the conn to the remote signer
// should see sth like this in the logs:
//
// E[10016-01-10|17:12:46.128] Ping err="remote signer timed out"
// I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal
time.Sleep(testTimeoutReadWrite * 2)
}
}
///////////////////////////////////
func newSignerListenerEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerListenerEndpoint {
proto, address := tmnet.ProtocolAndAddress(addr)
ln, err := net.Listen(proto, address)
logger.Info("SignerListener: Listening", "proto", proto, "address", address)
if err != nil {
panic(err)
}
var listener net.Listener
if proto == "unix" {
unixLn := NewUnixListener(ln)
UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn)
UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn)
listener = unixLn
} else {
tcpLn := NewTCPListener(ln, ed25519.GenPrivKey())
TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn)
TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn)
listener = tcpLn
}
return NewSignerListenerEndpoint(logger, listener)
}
func startListenerEndpointAsync(t *testing.T, sle *SignerListenerEndpoint, endpointIsOpenCh chan struct{}) {
go func(sle *SignerListenerEndpoint) {
require.NoError(t, sle.Start())
assert.True(t, sle.IsRunning())
close(endpointIsOpenCh)
}(sle)
}
func getMockEndpoints(
t *testing.T,
addr string,
socketDialer SocketDialer,
) (*SignerListenerEndpoint, *SignerDialerEndpoint) {
var (
logger = log.TestingLogger()
endpointIsOpenCh = make(chan struct{})
dialerEndpoint = NewSignerDialerEndpoint(
logger,
socketDialer,
)
listenerEndpoint = newSignerListenerEndpoint(logger, addr, testTimeoutReadWrite)
)
SignerDialerEndpointTimeoutReadWrite(testTimeoutReadWrite)(dialerEndpoint)
SignerDialerEndpointConnRetries(1e6)(dialerEndpoint)
startListenerEndpointAsync(t, listenerEndpoint, endpointIsOpenCh)
require.NoError(t, dialerEndpoint.Start())
assert.True(t, dialerEndpoint.IsRunning())
<-endpointIsOpenCh
return listenerEndpoint, dialerEndpoint
}

View File

@@ -1,52 +0,0 @@
package privval
import (
"fmt"
"github.com/tendermint/tendermint/crypto"
"github.com/tendermint/tendermint/types"
)
func DefaultValidationRequestHandler(
privVal types.PrivValidator,
req SignerMessage,
chainID string,
) (SignerMessage, error) {
var res SignerMessage
var err error
switch r := req.(type) {
case *PubKeyRequest:
var pubKey crypto.PubKey
pubKey, err = privVal.GetPubKey()
if err != nil {
res = &PubKeyResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &PubKeyResponse{pubKey, nil}
}
case *SignVoteRequest:
err = privVal.SignVote(chainID, r.Vote)
if err != nil {
res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedVoteResponse{r.Vote, nil}
}
case *SignProposalRequest:
err = privVal.SignProposal(chainID, r.Proposal)
if err != nil {
res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedProposalResponse{r.Proposal, nil}
}
case *PingRequest:
err, res = nil, &PingResponse{}
default:
err = fmt.Errorf("unknown msg: %v", r)
}
return res, err
}

View File

@@ -1,107 +0,0 @@
package privval
import (
"io"
"sync"
"github.com/tendermint/tendermint/libs/service"
"github.com/tendermint/tendermint/types"
)
// ValidationRequestHandlerFunc handles different remoteSigner requests
type ValidationRequestHandlerFunc func(
privVal types.PrivValidator,
requestMessage SignerMessage,
chainID string) (SignerMessage, error)
type SignerServer struct {
service.BaseService
endpoint *SignerDialerEndpoint
chainID string
privVal types.PrivValidator
handlerMtx sync.Mutex
validationRequestHandler ValidationRequestHandlerFunc
}
func NewSignerServer(endpoint *SignerDialerEndpoint, chainID string, privVal types.PrivValidator) *SignerServer {
ss := &SignerServer{
endpoint: endpoint,
chainID: chainID,
privVal: privVal,
validationRequestHandler: DefaultValidationRequestHandler,
}
ss.BaseService = *service.NewBaseService(endpoint.Logger, "SignerServer", ss)
return ss
}
// OnStart implements service.Service.
func (ss *SignerServer) OnStart() error {
go ss.serviceLoop()
return nil
}
// OnStop implements service.Service.
func (ss *SignerServer) OnStop() {
ss.endpoint.Logger.Debug("SignerServer: OnStop calling Close")
_ = ss.endpoint.Close()
}
// SetRequestHandler override the default function that is used to service requests
func (ss *SignerServer) SetRequestHandler(validationRequestHandler ValidationRequestHandlerFunc) {
ss.handlerMtx.Lock()
defer ss.handlerMtx.Unlock()
ss.validationRequestHandler = validationRequestHandler
}
func (ss *SignerServer) servicePendingRequest() {
if !ss.IsRunning() {
return // Ignore error from closing.
}
req, err := ss.endpoint.ReadMessage()
if err != nil {
if err != io.EOF {
ss.Logger.Error("SignerServer: HandleMessage", "err", err)
}
return
}
var res SignerMessage
{
// limit the scope of the lock
ss.handlerMtx.Lock()
defer ss.handlerMtx.Unlock()
res, err = ss.validationRequestHandler(ss.privVal, req, ss.chainID)
if err != nil {
// only log the error; we'll reply with an error in res
ss.Logger.Error("SignerServer: handleMessage", "err", err)
}
}
if res != nil {
err = ss.endpoint.WriteMessage(res)
if err != nil {
ss.Logger.Error("SignerServer: writeMessage", "err", err)
}
}
}
func (ss *SignerServer) serviceLoop() {
for {
select {
default:
err := ss.endpoint.ensureConnection()
if err != nil {
return
}
ss.servicePendingRequest()
case <-ss.Quit():
return
}
}
}

View File

@@ -1,43 +0,0 @@
package privval
import (
"errors"
"net"
"time"
"github.com/tendermint/tendermint/crypto"
tmnet "github.com/tendermint/tendermint/libs/net"
p2pconn "github.com/tendermint/tendermint/p2p/conn"
)
// Socket errors.
var (
ErrDialRetryMax = errors.New("dialed maximum retries")
)
// SocketDialer dials a remote address and returns a net.Conn or an error.
type SocketDialer func() (net.Conn, error)
// DialTCPFn dials the given tcp addr, using the given timeoutReadWrite and
// privKey for the authenticated encryption handshake.
func DialTCPFn(addr string, timeoutReadWrite time.Duration, privKey crypto.PrivKey) SocketDialer {
return func() (net.Conn, error) {
conn, err := tmnet.Connect(addr)
if err == nil {
deadline := time.Now().Add(timeoutReadWrite)
err = conn.SetDeadline(deadline)
}
if err == nil {
conn, err = p2pconn.MakeSecretConnection(conn, privKey)
}
return conn, err
}
}
// DialUnixFn dials the given unix socket.
func DialUnixFn(addr string) SocketDialer {
return func() (net.Conn, error) {
unixAddr := &net.UnixAddr{Name: addr, Net: "unix"}
return net.DialUnix("unix", nil, unixAddr)
}
}

View File

@@ -1,48 +0,0 @@
package privval
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/crypto/ed25519"
)
func getDialerTestCases(t *testing.T) []dialerTestCase {
tcpAddr := GetFreeLocalhostAddrPort()
unixFilePath, err := testUnixAddr()
require.NoError(t, err)
unixAddr := fmt.Sprintf("unix://%s", unixFilePath)
return []dialerTestCase{
{
addr: tcpAddr,
dialer: DialTCPFn(tcpAddr, testTimeoutReadWrite, ed25519.GenPrivKey()),
},
{
addr: unixAddr,
dialer: DialUnixFn(unixFilePath),
},
}
}
func TestIsConnTimeoutForFundamentalTimeouts(t *testing.T) {
// Generate a networking timeout
tcpAddr := GetFreeLocalhostAddrPort()
dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey())
_, err := dialer()
assert.Error(t, err)
assert.True(t, IsConnTimeout(err))
}
func TestIsConnTimeoutForWrappedConnTimeouts(t *testing.T) {
tcpAddr := GetFreeLocalhostAddrPort()
dialer := DialTCPFn(tcpAddr, time.Millisecond, ed25519.GenPrivKey())
_, err := dialer()
assert.Error(t, err)
err = fmt.Errorf("%v: %w", err, ErrConnectionTimeout)
assert.True(t, IsConnTimeout(err))
}

View File

@@ -1,191 +0,0 @@
package privval
import (
"net"
"time"
"github.com/tendermint/tendermint/crypto/ed25519"
p2pconn "github.com/tendermint/tendermint/p2p/conn"
)
const (
defaultTimeoutAcceptSeconds = 3
defaultPingPeriodMilliseconds = 100
)
// timeoutError can be used to check if an error returned from the netp package
// was due to a timeout.
type timeoutError interface {
Timeout() bool
}
//------------------------------------------------------------------
// TCP Listener
// TCPListenerOption sets an optional parameter on the tcpListener.
type TCPListenerOption func(*TCPListener)
// TCPListenerTimeoutAccept sets the timeout for the listener.
// A zero time value disables the timeout.
func TCPListenerTimeoutAccept(timeout time.Duration) TCPListenerOption {
return func(tl *TCPListener) { tl.timeoutAccept = timeout }
}
// TCPListenerTimeoutReadWrite sets the read and write timeout for connections
// from external signing processes.
func TCPListenerTimeoutReadWrite(timeout time.Duration) TCPListenerOption {
return func(tl *TCPListener) { tl.timeoutReadWrite = timeout }
}
// tcpListener implements net.Listener.
var _ net.Listener = (*TCPListener)(nil)
// TCPListener wraps a *net.TCPListener to standardise protocol timeouts
// and potentially other tuning parameters. It also returns encrypted connections.
type TCPListener struct {
*net.TCPListener
secretConnKey ed25519.PrivKey
timeoutAccept time.Duration
timeoutReadWrite time.Duration
}
// NewTCPListener returns a listener that accepts authenticated encrypted connections
// using the given secretConnKey and the default timeout values.
func NewTCPListener(ln net.Listener, secretConnKey ed25519.PrivKey) *TCPListener {
return &TCPListener{
TCPListener: ln.(*net.TCPListener),
secretConnKey: secretConnKey,
timeoutAccept: time.Second * defaultTimeoutAcceptSeconds,
timeoutReadWrite: time.Second * defaultTimeoutReadWriteSeconds,
}
}
// Accept implements net.Listener.
func (ln *TCPListener) Accept() (net.Conn, error) {
deadline := time.Now().Add(ln.timeoutAccept)
err := ln.SetDeadline(deadline)
if err != nil {
return nil, err
}
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
// Wrap the conn in our timeout and encryption wrappers
timeoutConn := newTimeoutConn(tc, ln.timeoutReadWrite)
secretConn, err := p2pconn.MakeSecretConnection(timeoutConn, ln.secretConnKey)
if err != nil {
return nil, err
}
return secretConn, nil
}
//------------------------------------------------------------------
// Unix Listener
// unixListener implements net.Listener.
var _ net.Listener = (*UnixListener)(nil)
type UnixListenerOption func(*UnixListener)
// UnixListenerTimeoutAccept sets the timeout for the listener.
// A zero time value disables the timeout.
func UnixListenerTimeoutAccept(timeout time.Duration) UnixListenerOption {
return func(ul *UnixListener) { ul.timeoutAccept = timeout }
}
// UnixListenerTimeoutReadWrite sets the read and write timeout for connections
// from external signing processes.
func UnixListenerTimeoutReadWrite(timeout time.Duration) UnixListenerOption {
return func(ul *UnixListener) { ul.timeoutReadWrite = timeout }
}
// UnixListener wraps a *net.UnixListener to standardise protocol timeouts
// and potentially other tuning parameters. It returns unencrypted connections.
type UnixListener struct {
*net.UnixListener
timeoutAccept time.Duration
timeoutReadWrite time.Duration
}
// NewUnixListener returns a listener that accepts unencrypted connections
// using the default timeout values.
func NewUnixListener(ln net.Listener) *UnixListener {
return &UnixListener{
UnixListener: ln.(*net.UnixListener),
timeoutAccept: time.Second * defaultTimeoutAcceptSeconds,
timeoutReadWrite: time.Second * defaultTimeoutReadWriteSeconds,
}
}
// Accept implements net.Listener.
func (ln *UnixListener) Accept() (net.Conn, error) {
deadline := time.Now().Add(ln.timeoutAccept)
err := ln.SetDeadline(deadline)
if err != nil {
return nil, err
}
tc, err := ln.AcceptUnix()
if err != nil {
return nil, err
}
// Wrap the conn in our timeout wrapper
conn := newTimeoutConn(tc, ln.timeoutReadWrite)
// TODO: wrap in something that authenticates
// with a MAC - https://github.com/tendermint/tendermint/issues/3099
return conn, nil
}
//------------------------------------------------------------------
// Connection
// timeoutConn implements net.Conn.
var _ net.Conn = (*timeoutConn)(nil)
// timeoutConn wraps a net.Conn to standardise protocol timeouts / deadline resets.
type timeoutConn struct {
net.Conn
timeout time.Duration
}
// newTimeoutConn returns an instance of timeoutConn.
func newTimeoutConn(conn net.Conn, timeout time.Duration) *timeoutConn {
return &timeoutConn{
conn,
timeout,
}
}
// Read implements net.Conn.
func (c timeoutConn) Read(b []byte) (n int, err error) {
// Reset deadline
deadline := time.Now().Add(c.timeout)
err = c.Conn.SetReadDeadline(deadline)
if err != nil {
return
}
return c.Conn.Read(b)
}
// Write implements net.Conn.
func (c timeoutConn) Write(b []byte) (n int, err error) {
// Reset deadline
deadline := time.Now().Add(c.timeout)
err = c.Conn.SetWriteDeadline(deadline)
if err != nil {
return
}
return c.Conn.Write(b)
}

View File

@@ -1,137 +0,0 @@
package privval
import (
"io/ioutil"
"net"
"os"
"testing"
"time"
"github.com/tendermint/tendermint/crypto/ed25519"
)
//-------------------------------------------
// helper funcs
func newPrivKey() ed25519.PrivKey {
return ed25519.GenPrivKey()
}
//-------------------------------------------
// tests
type listenerTestCase struct {
description string // For test reporting purposes.
listener net.Listener
dialer SocketDialer
}
// testUnixAddr will attempt to obtain a platform-independent temporary file
// name for a Unix socket
func testUnixAddr() (string, error) {
f, err := ioutil.TempFile("", "tendermint-privval-test-*")
if err != nil {
return "", err
}
addr := f.Name()
f.Close()
os.Remove(addr)
return addr, nil
}
func tcpListenerTestCase(t *testing.T, timeoutAccept, timeoutReadWrite time.Duration) listenerTestCase {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
tcpLn := NewTCPListener(ln, newPrivKey())
TCPListenerTimeoutAccept(timeoutAccept)(tcpLn)
TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn)
return listenerTestCase{
description: "TCP",
listener: tcpLn,
dialer: DialTCPFn(ln.Addr().String(), testTimeoutReadWrite, newPrivKey()),
}
}
func unixListenerTestCase(t *testing.T, timeoutAccept, timeoutReadWrite time.Duration) listenerTestCase {
addr, err := testUnixAddr()
if err != nil {
t.Fatal(err)
}
ln, err := net.Listen("unix", addr)
if err != nil {
t.Fatal(err)
}
unixLn := NewUnixListener(ln)
UnixListenerTimeoutAccept(timeoutAccept)(unixLn)
UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn)
return listenerTestCase{
description: "Unix",
listener: unixLn,
dialer: DialUnixFn(addr),
}
}
func listenerTestCases(t *testing.T, timeoutAccept, timeoutReadWrite time.Duration) []listenerTestCase {
return []listenerTestCase{
tcpListenerTestCase(t, timeoutAccept, timeoutReadWrite),
unixListenerTestCase(t, timeoutAccept, timeoutReadWrite),
}
}
func TestListenerTimeoutAccept(t *testing.T) {
for _, tc := range listenerTestCases(t, time.Millisecond, time.Second) {
_, err := tc.listener.Accept()
opErr, ok := err.(*net.OpError)
if !ok {
t.Fatalf("for %s listener, have %v, want *net.OpError", tc.description, err)
}
if have, want := opErr.Op, "accept"; have != want {
t.Errorf("for %s listener, have %v, want %v", tc.description, have, want)
}
}
}
func TestListenerTimeoutReadWrite(t *testing.T) {
const (
// This needs to be long enough s.t. the Accept will definitely succeed:
timeoutAccept = time.Second
// This can be really short but in the TCP case, the accept can
// also trigger a timeoutReadWrite. Hence, we need to give it some time.
// Note: this controls how long this test actually runs.
timeoutReadWrite = 10 * time.Millisecond
)
for _, tc := range listenerTestCases(t, timeoutAccept, timeoutReadWrite) {
go func(dialer SocketDialer) {
_, err := dialer()
if err != nil {
panic(err)
}
}(tc.dialer)
c, err := tc.listener.Accept()
if err != nil {
t.Fatal(err)
}
// this will timeout because we don't write anything:
msg := make([]byte, 200)
_, err = c.Read(msg)
opErr, ok := err.(*net.OpError)
if !ok {
t.Fatalf("for %s listener, have %v, want *net.OpError", tc.description, err)
}
if have, want := opErr.Op, "read"; have != want {
t.Errorf("for %s listener, have %v, want %v", tc.description, have, want)
}
if !opErr.Timeout() {
t.Errorf("for %s listener, got unexpected error: have %v, want Timeout error", tc.description, opErr)
}
}
}

View File

@@ -1,62 +0,0 @@
package privval
import (
"errors"
"fmt"
"net"
"github.com/tendermint/tendermint/crypto/ed25519"
"github.com/tendermint/tendermint/libs/log"
tmnet "github.com/tendermint/tendermint/libs/net"
)
// IsConnTimeout returns a boolean indicating whether the error is known to
// report that a connection timeout occurred. This detects both fundamental
// network timeouts, as well as ErrConnTimeout errors.
func IsConnTimeout(err error) bool {
_, ok := errors.Unwrap(err).(timeoutError)
switch {
case errors.As(err, &EndpointTimeoutError{}):
return true
case ok:
return true
default:
return false
}
}
// NewSignerListener creates a new SignerListenerEndpoint using the corresponding listen address
func NewSignerListener(listenAddr string, logger log.Logger) (*SignerListenerEndpoint, error) {
var listener net.Listener
protocol, address := tmnet.ProtocolAndAddress(listenAddr)
ln, err := net.Listen(protocol, address)
if err != nil {
return nil, err
}
switch protocol {
case "unix":
listener = NewUnixListener(ln)
case "tcp":
// TODO: persist this key so external signer can actually authenticate us
listener = NewTCPListener(ln, ed25519.GenPrivKey())
default:
return nil, fmt.Errorf(
"wrong listen address: expected either 'tcp' or 'unix' protocols, got %s",
protocol,
)
}
pve := NewSignerListenerEndpoint(logger.With("module", "privval"), listener)
return pve, nil
}
// GetFreeLocalhostAddrPort returns a free localhost:port address
func GetFreeLocalhostAddrPort() string {
port, err := tmnet.GetFreePort()
if err != nil {
panic(err)
}
return fmt.Sprintf("127.0.0.1:%d", port)
}

View File

@@ -1,14 +0,0 @@
package privval
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsConnTimeoutForNonTimeoutErrors(t *testing.T) {
assert.False(t, IsConnTimeout(fmt.Errorf("max retries exceeded: %w", ErrDialRetryMax)))
assert.False(t, IsConnTimeout(errors.New("completely irrelevant error")))
}