From 9b9022f8dfc1df141d6024073cb95556b8c6ec8b Mon Sep 17 00:00:00 2001 From: Alexander Simmerl Date: Fri, 16 Mar 2018 13:32:17 +0100 Subject: [PATCH] privVal: Improve SocketClient network code (#1315) Follow-up to feedback from #1286, this change simplifies the connection handling in the SocketClient and makes the communication via TCP more robust. It introduces the tcpTimeoutListener to encapsulate accept and i/o timeout handling as well as connection keep-alive, this type could likely be upgraded to handle more fine-grained tuning of the tcp stack (linger, nodelay, etc.) according to the properties we desire. The same methods should be applied to the RemoteSigner which will be overhauled when the priv_val_server is fleshed out. * require private key * simplify connect logic * break out conn upgrades to tcpTimeoutListener * extend test coverage and simplify component setup --- Gopkg.lock | 3 +- Gopkg.toml | 4 - cmd/priv_val_server/main.go | 3 +- node/node.go | 2 +- types/priv_validator/socket.go | 224 +++++++++++----------- types/priv_validator/socket_tcp.go | 66 +++++++ types/priv_validator/socket_tcp_test.go | 64 +++++++ types/priv_validator/socket_test.go | 242 +++++++++++++++--------- 8 files changed, 400 insertions(+), 208 deletions(-) create mode 100644 types/priv_validator/socket_tcp.go create mode 100644 types/priv_validator/socket_tcp_test.go diff --git a/Gopkg.lock b/Gopkg.lock index 10739e8eb..91e0b41e2 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -305,7 +305,6 @@ "idna", "internal/timeseries", "lex/httplex", - "netutil", "trace" ] revision = "cbe0f9307d0156177f9dd5dc85da1a31abc5f2fb" @@ -373,6 +372,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "fe167dd9055ba9a4016e7bdad88da263372bca7ebdcebf5c81c609f396e605a3" + inputs-digest = "ed9db0be72a900f4812675f683db20eff9d64ef4511dc00ad29a810da65909c2" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index b963fe13c..61406ad66 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -90,10 +90,6 @@ name = "google.golang.org/grpc" version = "1.7.3" -[[constraint]] - branch = "master" - name = "golang.org/x/net" - [prune] go-tests = true unused-packages = true diff --git a/cmd/priv_val_server/main.go b/cmd/priv_val_server/main.go index 0d18f8ed2..9f3ec73ca 100644 --- a/cmd/priv_val_server/main.go +++ b/cmd/priv_val_server/main.go @@ -4,6 +4,7 @@ import ( "flag" "os" + crypto "github.com/tendermint/go-crypto" cmn "github.com/tendermint/tmlibs/common" "github.com/tendermint/tmlibs/log" @@ -36,7 +37,7 @@ func main() { *chainID, *addr, privVal, - nil, + crypto.GenPrivKeyEd25519(), ) err := rs.Start() if err != nil { diff --git a/node/node.go b/node/node.go index 83ac50ec6..dffdb83e8 100644 --- a/node/node.go +++ b/node/node.go @@ -183,7 +183,7 @@ func NewNode(config *cfg.Config, pvsc = priv_val.NewSocketClient( logger.With("module", "priv_val"), config.PrivValidatorListenAddr, - &privKey, + privKey, ) ) diff --git a/types/priv_validator/socket.go b/types/priv_validator/socket.go index 05bc77710..26cab72b9 100644 --- a/types/priv_validator/socket.go +++ b/types/priv_validator/socket.go @@ -11,39 +11,53 @@ import ( wire "github.com/tendermint/go-wire" cmn "github.com/tendermint/tmlibs/common" "github.com/tendermint/tmlibs/log" - "golang.org/x/net/netutil" p2pconn "github.com/tendermint/tendermint/p2p/conn" "github.com/tendermint/tendermint/types" ) const ( - defaultConnDeadlineSeconds = 3 - defaultConnWaitSeconds = 60 - defaultDialRetries = 10 - defaultSignersMax = 1 + defaultAcceptDeadlineSeconds = 3 + defaultConnDeadlineSeconds = 3 + defaultConnHeartBeatSeconds = 30 + defaultConnWaitSeconds = 60 + defaultDialRetries = 10 ) // Socket errors. var ( - ErrDialRetryMax = errors.New("Error max client retries") - ErrConnWaitTimeout = errors.New("Error waiting for external connection") - ErrConnTimeout = errors.New("Error connection timed out") + ErrDialRetryMax = errors.New("dialed maximum retries") + ErrConnWaitTimeout = errors.New("waited for remote signer for too long") + ErrConnTimeout = errors.New("remote signer timed out") ) var ( - connDeadline = time.Second * defaultConnDeadlineSeconds + acceptDeadline = time.Second + defaultAcceptDeadlineSeconds + connDeadline = time.Second * defaultConnDeadlineSeconds + connHeartbeat = time.Second * defaultConnHeartBeatSeconds ) // SocketClientOption sets an optional parameter on the SocketClient. type SocketClientOption func(*SocketClient) +// SocketClientAcceptDeadline sets the deadline for the SocketClient listener. +// A zero time value disables the deadline. +func SocketClientAcceptDeadline(deadline time.Duration) SocketClientOption { + return func(sc *SocketClient) { sc.acceptDeadline = deadline } +} + // SocketClientConnDeadline sets the read and write deadline for connections // from external signing processes. func SocketClientConnDeadline(deadline time.Duration) SocketClientOption { return func(sc *SocketClient) { sc.connDeadline = deadline } } +// SocketClientHeartbeat sets the period on which to check the liveness of the +// connected Signer connections. +func SocketClientHeartbeat(period time.Duration) SocketClientOption { + return func(sc *SocketClient) { sc.connHeartbeat = period } +} + // SocketClientConnWait sets the timeout duration before connection of external // signing processes are considered to be unsuccessful. func SocketClientConnWait(timeout time.Duration) SocketClientOption { @@ -56,9 +70,11 @@ type SocketClient struct { cmn.BaseService addr string + acceptDeadline time.Duration connDeadline time.Duration + connHeartbeat time.Duration connWaitTimeout time.Duration - privKey *crypto.PrivKeyEd25519 + privKey crypto.PrivKeyEd25519 conn net.Conn listener net.Listener @@ -71,11 +87,13 @@ var _ types.PrivValidator2 = (*SocketClient)(nil) func NewSocketClient( logger log.Logger, socketAddr string, - privKey *crypto.PrivKeyEd25519, + privKey crypto.PrivKeyEd25519, ) *SocketClient { sc := &SocketClient{ addr: socketAddr, - connDeadline: time.Second * defaultConnDeadlineSeconds, + acceptDeadline: acceptDeadline, + connDeadline: connDeadline, + connHeartbeat: connHeartbeat, connWaitTimeout: time.Second * defaultConnWaitSeconds, privKey: privKey, } @@ -85,57 +103,6 @@ func NewSocketClient( return sc } -// OnStart implements cmn.Service. -func (sc *SocketClient) OnStart() error { - if sc.listener == nil { - if err := sc.listen(); err != nil { - sc.Logger.Error( - "OnStart", - "err", errors.Wrap(err, "failed to listen"), - ) - - return err - } - } - - conn, err := sc.waitConnection() - if err != nil { - sc.Logger.Error( - "OnStart", - "err", errors.Wrap(err, "failed to accept connection"), - ) - - return err - } - - sc.conn = conn - - return nil -} - -// OnStop implements cmn.Service. -func (sc *SocketClient) OnStop() { - sc.BaseService.OnStop() - - if sc.conn != nil { - if err := sc.conn.Close(); err != nil { - sc.Logger.Error( - "OnStop", - "err", errors.Wrap(err, "failed to close connection"), - ) - } - } - - if sc.listener != nil { - if err := sc.listener.Close(); err != nil { - sc.Logger.Error( - "OnStop", - "err", errors.Wrap(err, "failed to close listener"), - ) - } - } -} - // GetAddress implements PrivValidator. // TODO(xla): Remove when PrivValidator2 replaced PrivValidator. func (sc *SocketClient) GetAddress() types.Address { @@ -240,6 +207,53 @@ func (sc *SocketClient) SignHeartbeat( return nil } +// OnStart implements cmn.Service. +func (sc *SocketClient) OnStart() error { + if err := sc.listen(); err != nil { + sc.Logger.Error( + "OnStart", + "err", errors.Wrap(err, "failed to listen"), + ) + + return err + } + + conn, err := sc.waitConnection() + if err != nil { + sc.Logger.Error( + "OnStart", + "err", errors.Wrap(err, "failed to accept connection"), + ) + + return err + } + + sc.conn = conn + + return nil +} + +// OnStop implements cmn.Service. +func (sc *SocketClient) OnStop() { + if sc.conn != nil { + if err := sc.conn.Close(); err != nil { + sc.Logger.Error( + "OnStop", + "err", errors.Wrap(err, "failed to close connection"), + ) + } + } + + if sc.listener != nil { + if err := sc.listener.Close(); err != nil { + sc.Logger.Error( + "OnStop", + "err", errors.Wrap(err, "failed to close listener"), + ) + } + } +} + func (sc *SocketClient) acceptConnection() (net.Conn, error) { conn, err := sc.listener.Accept() if err != nil { @@ -250,17 +264,11 @@ func (sc *SocketClient) acceptConnection() (net.Conn, error) { } - if err := conn.SetDeadline(time.Now().Add(sc.connDeadline)); err != nil { + conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap()) + if err != nil { return nil, err } - if sc.privKey != nil { - conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap()) - if err != nil { - return nil, err - } - } - return conn, nil } @@ -270,7 +278,12 @@ func (sc *SocketClient) listen() error { return err } - sc.listener = netutil.LimitListener(ln, defaultSignersMax) + sc.listener = newTCPTimeoutListener( + ln, + sc.acceptDeadline, + sc.connDeadline, + sc.connHeartbeat, + ) return nil } @@ -297,6 +310,9 @@ func (sc *SocketClient) waitConnection() (net.Conn, error) { case conn := <-connc: return conn, nil case err := <-errc: + if _, ok := err.(timeoutError); ok { + return nil, errors.Wrap(ErrConnWaitTimeout, err.Error()) + } return nil, err case <-time.After(sc.connWaitTimeout): return nil, ErrConnWaitTimeout @@ -319,8 +335,7 @@ func RemoteSignerConnRetries(retries int) RemoteSignerOption { return func(ss *RemoteSigner) { ss.connRetries = retries } } -// RemoteSigner implements PrivValidator. -// It responds to requests over a socket +// RemoteSigner implements PrivValidator by dialing to a socket. type RemoteSigner struct { cmn.BaseService @@ -328,19 +343,18 @@ type RemoteSigner struct { chainID string connDeadline time.Duration connRetries int - privKey *crypto.PrivKeyEd25519 + privKey crypto.PrivKeyEd25519 privVal PrivValidator conn net.Conn } -// NewRemoteSigner returns an instance of -// RemoteSigner. +// NewRemoteSigner returns an instance of RemoteSigner. func NewRemoteSigner( logger log.Logger, chainID, socketAddr string, privVal PrivValidator, - privKey *crypto.PrivKeyEd25519, + privKey crypto.PrivKeyEd25519, ) *RemoteSigner { rs := &RemoteSigner{ addr: socketAddr, @@ -382,17 +396,12 @@ func (rs *RemoteSigner) OnStop() { } func (rs *RemoteSigner) connect() (net.Conn, error) { - retries := defaultDialRetries - -RETRY_LOOP: - for retries > 0 { + for retries := rs.connRetries; retries > 0; retries-- { // Don't sleep if it is the first retry. - if retries != defaultDialRetries { + if retries != rs.connRetries { time.Sleep(rs.connDeadline) } - retries-- - conn, err := cmn.Connect(rs.addr) if err != nil { rs.Logger.Error( @@ -401,7 +410,7 @@ RETRY_LOOP: "err", errors.Wrap(err, "connection failed"), ) - continue RETRY_LOOP + continue } if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil { @@ -412,16 +421,14 @@ RETRY_LOOP: continue } - if rs.privKey != nil { - conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey.Wrap()) - if err != nil { - rs.Logger.Error( - "sc connect", - "err", errors.Wrap(err, "encrypting connection failed"), - ) + conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey.Wrap()) + if err != nil { + rs.Logger.Error( + "connect", + "err", errors.Wrap(err, "encrypting connection failed"), + ) - continue RETRY_LOOP - } + continue } return conn, nil @@ -444,7 +451,7 @@ func (rs *RemoteSigner) handleConnection(conn net.Conn) { return } - var res PrivValidatorSocketMsg + var res PrivValMsg switch r := req.(type) { case *PubKeyMsg: @@ -487,12 +494,11 @@ const ( msgTypeSignHeartbeat = byte(0x12) ) -// PrivValidatorSocketMsg is a message sent between PrivValidatorSocket client -// and server. -type PrivValidatorSocketMsg interface{} +// PrivValMsg is sent between RemoteSigner and SocketClient. +type PrivValMsg interface{} var _ = wire.RegisterInterface( - struct{ PrivValidatorSocketMsg }{}, + struct{ PrivValMsg }{}, wire.ConcreteType{&PubKeyMsg{}, msgTypePubKey}, wire.ConcreteType{&SignVoteMsg{}, msgTypeSignVote}, wire.ConcreteType{&SignProposalMsg{}, msgTypeSignProposal}, @@ -519,27 +525,27 @@ type SignHeartbeatMsg struct { Heartbeat *types.Heartbeat } -func readMsg(r io.Reader) (PrivValidatorSocketMsg, error) { +func readMsg(r io.Reader) (PrivValMsg, error) { var ( n int err error ) - read := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, &n, &err) + read := wire.ReadBinary(struct{ PrivValMsg }{}, r, 0, &n, &err) if err != nil { - if opErr, ok := err.(*net.OpError); ok { - return nil, errors.Wrapf(ErrConnTimeout, opErr.Addr.String()) + if _, ok := err.(timeoutError); ok { + return nil, errors.Wrap(ErrConnTimeout, err.Error()) } return nil, err } - w, ok := read.(struct{ PrivValidatorSocketMsg }) + w, ok := read.(struct{ PrivValMsg }) if !ok { return nil, errors.New("unknown type") } - return w.PrivValidatorSocketMsg, nil + return w.PrivValMsg, nil } func writeMsg(w io.Writer, msg interface{}) error { @@ -549,9 +555,9 @@ func writeMsg(w io.Writer, msg interface{}) error { ) // TODO(xla): This extra wrap should be gone with the sdk-2 update. - wire.WriteBinary(struct{ PrivValidatorSocketMsg }{msg}, w, &n, &err) - if opErr, ok := err.(*net.OpError); ok { - return errors.Wrapf(ErrConnTimeout, opErr.Addr.String()) + wire.WriteBinary(struct{ PrivValMsg }{msg}, w, &n, &err) + if _, ok := err.(timeoutError); ok { + return errors.Wrap(ErrConnTimeout, err.Error()) } return err diff --git a/types/priv_validator/socket_tcp.go b/types/priv_validator/socket_tcp.go new file mode 100644 index 000000000..2421eb9f4 --- /dev/null +++ b/types/priv_validator/socket_tcp.go @@ -0,0 +1,66 @@ +package types + +import ( + "net" + "time" +) + +// timeoutError can be used to check if an error returned from the netp package +// was due to a timeout. +type timeoutError interface { + Timeout() bool +} + +// tcpTimeoutListener implements net.Listener. +var _ net.Listener = (*tcpTimeoutListener)(nil) + +// tcpTimeoutListener wraps a *net.TCPListener to standardise protocol timeouts +// and potentially other tuning parameters. +type tcpTimeoutListener struct { + *net.TCPListener + + acceptDeadline time.Duration + connDeadline time.Duration + period time.Duration +} + +// newTCPTimeoutListener returns an instance of tcpTimeoutListener. +func newTCPTimeoutListener( + ln net.Listener, + acceptDeadline, connDeadline time.Duration, + period time.Duration, +) tcpTimeoutListener { + return tcpTimeoutListener{ + TCPListener: ln.(*net.TCPListener), + acceptDeadline: acceptDeadline, + connDeadline: connDeadline, + period: period, + } +} + +// Accept implements net.Listener. +func (ln tcpTimeoutListener) Accept() (net.Conn, error) { + err := ln.SetDeadline(time.Now().Add(ln.acceptDeadline)) + if err != nil { + return nil, err + } + + tc, err := ln.AcceptTCP() + if err != nil { + return nil, err + } + + if err := tc.SetDeadline(time.Now().Add(ln.connDeadline)); err != nil { + return nil, err + } + + if err := tc.SetKeepAlive(true); err != nil { + return nil, err + } + + if err := tc.SetKeepAlivePeriod(ln.period); err != nil { + return nil, err + } + + return tc, nil +} diff --git a/types/priv_validator/socket_tcp_test.go b/types/priv_validator/socket_tcp_test.go new file mode 100644 index 000000000..cd95ab0b9 --- /dev/null +++ b/types/priv_validator/socket_tcp_test.go @@ -0,0 +1,64 @@ +package types + +import ( + "net" + "testing" + "time" +) + +func TestTCPTimeoutListenerAcceptDeadline(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + ln = newTCPTimeoutListener(ln, time.Millisecond, time.Second, time.Second) + + _, err = ln.Accept() + opErr, ok := err.(*net.OpError) + if !ok { + t.Fatalf("have %v, want *net.OpError", err) + } + + if have, want := opErr.Op, "accept"; have != want { + t.Errorf("have %v, want %v", have, want) + } +} + +func TestTCPTimeoutListenerConnDeadline(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + ln = newTCPTimeoutListener(ln, time.Second, time.Millisecond, time.Second) + + donec := make(chan struct{}) + go func(ln net.Listener) { + defer close(donec) + + c, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + + time.Sleep(2 * time.Millisecond) + + _, err = c.Write([]byte("foo")) + opErr, ok := err.(*net.OpError) + if !ok { + t.Fatalf("have %v, want *net.OpError", err) + } + + if have, want := opErr.Op, "write"; have != want { + t.Errorf("have %v, want %v", have, want) + } + }(ln) + + _, err = net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + <-donec +} diff --git a/types/priv_validator/socket_test.go b/types/priv_validator/socket_test.go index 36f09f40c..2859c9452 100644 --- a/types/priv_validator/socket_test.go +++ b/types/priv_validator/socket_test.go @@ -1,6 +1,8 @@ package types import ( + "fmt" + "net" "testing" "time" @@ -12,57 +14,55 @@ import ( cmn "github.com/tendermint/tmlibs/common" "github.com/tendermint/tmlibs/log" + p2pconn "github.com/tendermint/tendermint/p2p/conn" "github.com/tendermint/tendermint/types" ) func TestSocketClientAddress(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ) defer sc.Stop() defer rs.Stop() serverAddr, err := rs.privVal.Address() - require.NoError(err) + require.NoError(t, err) clientAddr, err := sc.Address() - require.NoError(err) + require.NoError(t, err) - assert.Equal(serverAddr, clientAddr) + assert.Equal(t, serverAddr, clientAddr) // TODO(xla): Remove when PrivValidator2 replaced PrivValidator. - assert.Equal(serverAddr, sc.GetAddress()) + assert.Equal(t, serverAddr, sc.GetAddress()) } func TestSocketClientPubKey(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ) defer sc.Stop() defer rs.Stop() clientKey, err := sc.PubKey() - require.NoError(err) + require.NoError(t, err) privKey, err := rs.privVal.PubKey() - require.NoError(err) + require.NoError(t, err) - assert.Equal(privKey, clientKey) + assert.Equal(t, privKey, clientKey) // TODO(xla): Remove when PrivValidator2 replaced PrivValidator. - assert.Equal(privKey, sc.GetPubKey()) + assert.Equal(t, privKey, sc.GetPubKey()) } func TestSocketClientProposal(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ts = time.Now() privProposal = &types.Proposal{Timestamp: ts} @@ -71,16 +71,15 @@ func TestSocketClientProposal(t *testing.T) { defer sc.Stop() defer rs.Stop() - require.NoError(rs.privVal.SignProposal(chainID, privProposal)) - require.NoError(sc.SignProposal(chainID, clientProposal)) - assert.Equal(privProposal.Signature, clientProposal.Signature) + require.NoError(t, rs.privVal.SignProposal(chainID, privProposal)) + require.NoError(t, sc.SignProposal(chainID, clientProposal)) + assert.Equal(t, privProposal.Signature, clientProposal.Signature) } func TestSocketClientVote(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ts = time.Now() vType = types.VoteTypePrecommit @@ -90,16 +89,15 @@ func TestSocketClientVote(t *testing.T) { defer sc.Stop() defer rs.Stop() - require.NoError(rs.privVal.SignVote(chainID, want)) - require.NoError(sc.SignVote(chainID, have)) - assert.Equal(want.Signature, have.Signature) + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) } func TestSocketClientHeartbeat(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - chainID = cmn.RandStr(12) - sc, rs = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) want = &types.Heartbeat{} have = &types.Heartbeat{} @@ -107,79 +105,133 @@ func TestSocketClientHeartbeat(t *testing.T) { defer sc.Stop() defer rs.Stop() - require.NoError(rs.privVal.SignHeartbeat(chainID, want)) - require.NoError(sc.SignHeartbeat(chainID, have)) - assert.Equal(want.Signature, have.Signature) + require.NoError(t, rs.privVal.SignHeartbeat(chainID, want)) + require.NoError(t, sc.SignHeartbeat(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) } -func TestSocketClientDeadline(t *testing.T) { +func TestSocketClientAcceptDeadline(t *testing.T) { var ( - assert, require = assert.New(t), require.New(t) - readyc = make(chan struct{}) - sc = NewSocketClient( + sc = NewSocketClient( log.TestingLogger(), "127.0.0.1:0", - nil, + crypto.GenPrivKeyEd25519(), ) ) defer sc.Stop() - SocketClientConnDeadline(time.Millisecond)(sc) + SocketClientAcceptDeadline(time.Millisecond)(sc) - require.NoError(sc.listen()) + assert.Equal(t, errors.Cause(sc.Start()), ErrConnWaitTimeout) +} + +func TestSocketClientDeadline(t *testing.T) { + var ( + addr = testFreeAddr(t) + listenc = make(chan struct{}) + sc = NewSocketClient( + log.TestingLogger(), + addr, + crypto.GenPrivKeyEd25519(), + ) + ) + + SocketClientConnDeadline(10 * time.Millisecond)(sc) + SocketClientConnWait(500 * time.Millisecond)(sc) go func(sc *SocketClient) { - require.NoError(sc.Start()) - assert.True(sc.IsRunning()) + defer close(listenc) - readyc <- struct{}{} + require.NoError(t, sc.Start()) + + assert.True(t, sc.IsRunning()) }(sc) - _, err := cmn.Connect(sc.listener.Addr().String()) - require.NoError(err) + for { + conn, err := cmn.Connect(addr) + if err != nil { + continue + } - <-readyc + _, err = p2pconn.MakeSecretConnection( + conn, + crypto.GenPrivKeyEd25519().Wrap(), + ) + if err == nil { + break + } + } - _, err = sc.PubKey() - assert.Equal(errors.Cause(err), ErrConnTimeout) + <-listenc + + // Sleep to guarantee deadline has been hit. + time.Sleep(20 * time.Microsecond) + + _, err := sc.PubKey() + assert.Equal(t, errors.Cause(err), ErrConnTimeout) } func TestSocketClientWait(t *testing.T) { - var ( - assert, _ = assert.New(t), require.New(t) - logger = log.TestingLogger() - privKey = crypto.GenPrivKeyEd25519() - sc = NewSocketClient( - logger, - "127.0.0.1:0", - &privKey, - ) + sc := NewSocketClient( + log.TestingLogger(), + "127.0.0.1:0", + crypto.GenPrivKeyEd25519(), ) defer sc.Stop() SocketClientConnWait(time.Millisecond)(sc) - assert.EqualError(sc.Start(), ErrConnWaitTimeout.Error()) + assert.Equal(t, errors.Cause(sc.Start()), ErrConnWaitTimeout) } func TestRemoteSignerRetry(t *testing.T) { var ( - assert, _ = assert.New(t), require.New(t) - privKey = crypto.GenPrivKeyEd25519() - rs = NewRemoteSigner( - log.TestingLogger(), - cmn.RandStr(12), - "127.0.0.1:0", - NewTestPrivValidator(types.GenSigner()), - &privKey, - ) + attemptc = make(chan int) + retries = 2 + ) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + go func(ln net.Listener, attemptc chan<- int) { + attempts := 0 + + for { + conn, err := ln.Accept() + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + + attempts++ + + if attempts == retries { + attemptc <- attempts + break + } + } + }(ln, attemptc) + + rs := NewRemoteSigner( + log.TestingLogger(), + cmn.RandStr(12), + ln.Addr().String(), + NewTestPrivValidator(types.GenSigner()), + crypto.GenPrivKeyEd25519(), ) defer rs.Stop() RemoteSignerConnDeadline(time.Millisecond)(rs) - RemoteSignerConnRetries(2)(rs) + RemoteSignerConnRetries(retries)(rs) - assert.EqualError(rs.Start(), ErrDialRetryMax.Error()) + assert.Equal(t, errors.Cause(rs.Start()), ErrDialRetryMax) + + select { + case attempts := <-attemptc: + assert.Equal(t, retries, attempts) + case <-time.After(100 * time.Millisecond): + t.Error("expected remote to observe connection attempts") + } } func testSetupSocketPair( @@ -187,40 +239,48 @@ func testSetupSocketPair( chainID string, ) (*SocketClient, *RemoteSigner) { var ( - assert, require = assert.New(t), require.New(t) - logger = log.TestingLogger() - signer = types.GenSigner() - clientPrivKey = crypto.GenPrivKeyEd25519() - remotePrivKey = crypto.GenPrivKeyEd25519() - privVal = NewTestPrivValidator(signer) - readyc = make(chan struct{}) - sc = NewSocketClient( + addr = testFreeAddr(t) + logger = log.TestingLogger() + signer = types.GenSigner() + privVal = NewTestPrivValidator(signer) + readyc = make(chan struct{}) + rs = NewRemoteSigner( logger, - "127.0.0.1:0", - &clientPrivKey, + chainID, + addr, + privVal, + crypto.GenPrivKeyEd25519(), + ) + sc = NewSocketClient( + logger, + addr, + crypto.GenPrivKeyEd25519(), ) ) - require.NoError(sc.listen()) - go func(sc *SocketClient) { - require.NoError(sc.Start()) - assert.True(sc.IsRunning()) + require.NoError(t, sc.Start()) + assert.True(t, sc.IsRunning()) readyc <- struct{}{} }(sc) - rs := NewRemoteSigner( - logger, - chainID, - sc.listener.Addr().String(), - privVal, - &remotePrivKey, - ) - require.NoError(rs.Start()) - assert.True(rs.IsRunning()) + RemoteSignerConnDeadline(time.Millisecond)(rs) + RemoteSignerConnRetries(1e6)(rs) + + require.NoError(t, rs.Start()) + assert.True(t, rs.IsRunning()) <-readyc return sc, rs } + +// testFreeAddr claims a free port so we don't block on listener being ready. +func testFreeAddr(t *testing.T) string { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port) +}