mirror of
https://github.com/tendermint/tendermint.git
synced 2026-01-05 13:05:09 +00:00
service: plumb contexts to all (most) threads (#7363)
This continues the push of plumbing contexts through tendermint. I attempted to find all goroutines in the production code (non-test) and made sure that these threads would exit when their contexts were canceled, and I believe this PR does that.
This commit is contained in:
@@ -24,6 +24,7 @@ type WSOptions struct {
|
||||
ReadWait time.Duration // deadline for any read op
|
||||
WriteWait time.Duration // deadline for any write op
|
||||
PingPeriod time.Duration // frequency with which pings are sent
|
||||
SkipMetrics bool // do not keep metrics for ping/pong latency
|
||||
}
|
||||
|
||||
// DefaultWSOptions returns default WS options.
|
||||
@@ -117,8 +118,6 @@ func NewWSWithOptions(remoteAddr, endpoint string, opts WSOptions) (*WSClient, e
|
||||
Address: parsedURL.GetTrimmedHostWithPath(),
|
||||
Dialer: dialFn,
|
||||
Endpoint: endpoint,
|
||||
PingPongLatencyTimer: metrics.NewTimer(),
|
||||
|
||||
maxReconnectAttempts: opts.MaxReconnectAttempts,
|
||||
readWait: opts.ReadWait,
|
||||
writeWait: opts.WriteWait,
|
||||
@@ -127,6 +126,14 @@ func NewWSWithOptions(remoteAddr, endpoint string, opts WSOptions) (*WSClient, e
|
||||
|
||||
// sentIDs: make(map[types.JSONRPCIntID]bool),
|
||||
}
|
||||
|
||||
switch opts.SkipMetrics {
|
||||
case true:
|
||||
c.PingPongLatencyTimer = metrics.NilTimer{}
|
||||
case false:
|
||||
c.PingPongLatencyTimer = metrics.NewTimer()
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -143,8 +150,8 @@ func (c *WSClient) String() string {
|
||||
}
|
||||
|
||||
// Start dials the specified service address and starts the I/O routines.
|
||||
func (c *WSClient) Start() error {
|
||||
if err := c.RunState.Start(); err != nil {
|
||||
func (c *WSClient) Start(ctx context.Context) error {
|
||||
if err := c.RunState.Start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
err := c.dial()
|
||||
@@ -162,8 +169,8 @@ func (c *WSClient) Start() error {
|
||||
// channel is unbuffered.
|
||||
c.backlog = make(chan rpctypes.RPCRequest, 1)
|
||||
|
||||
c.startReadWriteRoutines()
|
||||
go c.reconnectRoutine()
|
||||
c.startReadWriteRoutines(ctx)
|
||||
go c.reconnectRoutine(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -173,6 +180,7 @@ func (c *WSClient) Stop() error {
|
||||
if err := c.RunState.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// only close user-facing channels when we can't write to them
|
||||
c.wg.Wait()
|
||||
close(c.ResponsesCh)
|
||||
@@ -253,7 +261,7 @@ func (c *WSClient) dial() error {
|
||||
|
||||
// reconnect tries to redial up to maxReconnectAttempts with exponential
|
||||
// backoff.
|
||||
func (c *WSClient) reconnect() error {
|
||||
func (c *WSClient) reconnect(ctx context.Context) error {
|
||||
attempt := uint(0)
|
||||
|
||||
c.mtx.Lock()
|
||||
@@ -265,13 +273,21 @@ func (c *WSClient) reconnect() error {
|
||||
c.mtx.Unlock()
|
||||
}()
|
||||
|
||||
timer := time.NewTimer(0)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
// nolint:gosec // G404: Use of weak random number generator
|
||||
jitter := time.Duration(mrand.Float64() * float64(time.Second)) // 1s == (1e9 ns)
|
||||
backoffDuration := jitter + ((1 << attempt) * time.Second)
|
||||
|
||||
c.Logger.Info("reconnecting", "attempt", attempt+1, "backoff_duration", backoffDuration)
|
||||
time.Sleep(backoffDuration)
|
||||
timer.Reset(backoffDuration)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
err := c.dial()
|
||||
if err != nil {
|
||||
@@ -292,11 +308,11 @@ func (c *WSClient) reconnect() error {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WSClient) startReadWriteRoutines() {
|
||||
func (c *WSClient) startReadWriteRoutines(ctx context.Context) {
|
||||
c.wg.Add(2)
|
||||
c.readRoutineQuit = make(chan struct{})
|
||||
go c.readRoutine()
|
||||
go c.writeRoutine()
|
||||
go c.readRoutine(ctx)
|
||||
go c.writeRoutine(ctx)
|
||||
}
|
||||
|
||||
func (c *WSClient) processBacklog() error {
|
||||
@@ -320,13 +336,15 @@ func (c *WSClient) processBacklog() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WSClient) reconnectRoutine() {
|
||||
func (c *WSClient) reconnectRoutine(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case originalError := <-c.reconnectAfter:
|
||||
// wait until writeRoutine and readRoutine finish
|
||||
c.wg.Wait()
|
||||
if err := c.reconnect(); err != nil {
|
||||
if err := c.reconnect(ctx); err != nil {
|
||||
c.Logger.Error("failed to reconnect", "err", err, "original_err", originalError)
|
||||
if err = c.Stop(); err != nil {
|
||||
c.Logger.Error("failed to stop conn", "error", err)
|
||||
@@ -338,6 +356,8 @@ func (c *WSClient) reconnectRoutine() {
|
||||
LOOP:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-c.reconnectAfter:
|
||||
default:
|
||||
break LOOP
|
||||
@@ -345,18 +365,15 @@ func (c *WSClient) reconnectRoutine() {
|
||||
}
|
||||
err := c.processBacklog()
|
||||
if err == nil {
|
||||
c.startReadWriteRoutines()
|
||||
c.startReadWriteRoutines(ctx)
|
||||
}
|
||||
|
||||
case <-c.Quit():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The client ensures that there is at most one writer to a connection by
|
||||
// executing all writes from this goroutine.
|
||||
func (c *WSClient) writeRoutine() {
|
||||
func (c *WSClient) writeRoutine(ctx context.Context) {
|
||||
var ticker *time.Ticker
|
||||
if c.pingPeriod > 0 {
|
||||
// ticker with a predefined period
|
||||
@@ -408,7 +425,7 @@ func (c *WSClient) writeRoutine() {
|
||||
c.Logger.Debug("sent ping")
|
||||
case <-c.readRoutineQuit:
|
||||
return
|
||||
case <-c.Quit():
|
||||
case <-ctx.Done():
|
||||
if err := c.conn.WriteMessage(
|
||||
websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
|
||||
@@ -422,7 +439,7 @@ func (c *WSClient) writeRoutine() {
|
||||
|
||||
// The client ensures that there is at most one reader to a connection by
|
||||
// executing all reads from this goroutine.
|
||||
func (c *WSClient) readRoutine() {
|
||||
func (c *WSClient) readRoutine(ctx context.Context) {
|
||||
defer func() {
|
||||
c.conn.Close()
|
||||
// err != nil {
|
||||
@@ -494,7 +511,8 @@ func (c *WSClient) readRoutine() {
|
||||
c.Logger.Info("got response", "id", response.ID, "result", response.Result)
|
||||
|
||||
select {
|
||||
case <-c.Quit():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case c.ResponsesCh <- response:
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,10 +5,11 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fortytw2/leaktest"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -64,25 +65,26 @@ func (h *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func TestWSClientReconnectsAfterReadFailure(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(leaktest.Check(t))
|
||||
|
||||
// start server
|
||||
h := &myHandler{}
|
||||
s := httptest.NewServer(h)
|
||||
defer s.Close()
|
||||
|
||||
c := startClient(t, "//"+s.Listener.Addr().String())
|
||||
defer c.Stop() // nolint:errcheck // ignore for tests
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
wg.Add(1)
|
||||
go callWgDoneOnResult(t, c, &wg)
|
||||
c := startClient(ctx, t, "//"+s.Listener.Addr().String())
|
||||
|
||||
go handleResponses(ctx, t, c)
|
||||
|
||||
h.mtx.Lock()
|
||||
h.closeConnAfterRead = true
|
||||
h.mtx.Unlock()
|
||||
|
||||
// results in WS read error, no send retry because write succeeded
|
||||
call(t, "a", c)
|
||||
call(ctx, t, "a", c)
|
||||
|
||||
// expect to reconnect almost immediately
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
@@ -91,23 +93,23 @@ func TestWSClientReconnectsAfterReadFailure(t *testing.T) {
|
||||
h.mtx.Unlock()
|
||||
|
||||
// should succeed
|
||||
call(t, "b", c)
|
||||
|
||||
wg.Wait()
|
||||
call(ctx, t, "b", c)
|
||||
}
|
||||
|
||||
func TestWSClientReconnectsAfterWriteFailure(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
t.Cleanup(leaktest.Check(t))
|
||||
|
||||
// start server
|
||||
h := &myHandler{}
|
||||
s := httptest.NewServer(h)
|
||||
defer s.Close()
|
||||
|
||||
c := startClient(t, "//"+s.Listener.Addr().String())
|
||||
defer c.Stop() // nolint:errcheck // ignore for tests
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
wg.Add(2)
|
||||
go callWgDoneOnResult(t, c, &wg)
|
||||
c := startClient(ctx, t, "//"+s.Listener.Addr().String())
|
||||
|
||||
go handleResponses(ctx, t, c)
|
||||
|
||||
// hacky way to abort the connection before write
|
||||
if err := c.conn.Close(); err != nil {
|
||||
@@ -115,30 +117,32 @@ func TestWSClientReconnectsAfterWriteFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
// results in WS write error, the client should resend on reconnect
|
||||
call(t, "a", c)
|
||||
call(ctx, t, "a", c)
|
||||
|
||||
// expect to reconnect almost immediately
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// should succeed
|
||||
call(t, "b", c)
|
||||
|
||||
wg.Wait()
|
||||
call(ctx, t, "b", c)
|
||||
}
|
||||
|
||||
func TestWSClientReconnectFailure(t *testing.T) {
|
||||
t.Cleanup(leaktest.Check(t))
|
||||
|
||||
// start server
|
||||
h := &myHandler{}
|
||||
s := httptest.NewServer(h)
|
||||
|
||||
c := startClient(t, "//"+s.Listener.Addr().String())
|
||||
defer c.Stop() // nolint:errcheck // ignore for tests
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
c := startClient(ctx, t, "//"+s.Listener.Addr().String())
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ResponsesCh:
|
||||
case <-c.Quit():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -152,9 +156,9 @@ func TestWSClientReconnectFailure(t *testing.T) {
|
||||
|
||||
// results in WS write error
|
||||
// provide timeout to avoid blocking
|
||||
ctx, cancel := context.WithTimeout(context.Background(), wsCallTimeout)
|
||||
cctx, cancel := context.WithTimeout(ctx, wsCallTimeout)
|
||||
defer cancel()
|
||||
if err := c.Call(ctx, "a", make(map[string]interface{})); err != nil {
|
||||
if err := c.Call(cctx, "a", make(map[string]interface{})); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
@@ -164,7 +168,7 @@ func TestWSClientReconnectFailure(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// client should block on this
|
||||
call(t, "b", c)
|
||||
call(ctx, t, "b", c)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -178,44 +182,68 @@ func TestWSClientReconnectFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNotBlockingOnStop(t *testing.T) {
|
||||
timeout := 2 * time.Second
|
||||
t.Cleanup(leaktest.Check(t))
|
||||
|
||||
timeout := 3 * time.Second
|
||||
s := httptest.NewServer(&myHandler{})
|
||||
c := startClient(t, "//"+s.Listener.Addr().String())
|
||||
c.Call(context.Background(), "a", make(map[string]interface{})) // nolint:errcheck // ignore for tests
|
||||
defer s.Close()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
c := startClient(ctx, t, "//"+s.Listener.Addr().String())
|
||||
c.Call(ctx, "a", make(map[string]interface{})) // nolint:errcheck // ignore for tests
|
||||
// Let the readRoutine get around to blocking
|
||||
time.Sleep(time.Second)
|
||||
passCh := make(chan struct{})
|
||||
go func() {
|
||||
// Unless we have a non-blocking write to ResponsesCh from readRoutine
|
||||
// this blocks forever ont the waitgroup
|
||||
err := c.Stop()
|
||||
require.NoError(t, err)
|
||||
passCh <- struct{}{}
|
||||
cancel()
|
||||
require.NoError(t, c.Stop())
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case passCh <- struct{}{}:
|
||||
}
|
||||
}()
|
||||
|
||||
runtime.Gosched() // hacks: force context switch
|
||||
|
||||
select {
|
||||
case <-passCh:
|
||||
// Pass
|
||||
case <-time.After(timeout):
|
||||
t.Fatalf("WSClient did failed to stop within %v seconds - is one of the read/write routines blocking?",
|
||||
timeout.Seconds())
|
||||
if c.IsRunning() {
|
||||
t.Fatalf("WSClient did failed to stop within %v seconds - is one of the read/write routines blocking?",
|
||||
timeout.Seconds())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func startClient(t *testing.T, addr string) *WSClient {
|
||||
c, err := NewWS(addr, "/websocket")
|
||||
func startClient(ctx context.Context, t *testing.T, addr string) *WSClient {
|
||||
t.Helper()
|
||||
opts := DefaultWSOptions()
|
||||
opts.SkipMetrics = true
|
||||
c, err := NewWSWithOptions(addr, "/websocket", opts)
|
||||
|
||||
require.Nil(t, err)
|
||||
err = c.Start()
|
||||
err = c.Start(ctx)
|
||||
require.Nil(t, err)
|
||||
c.SetLogger(log.TestingLogger())
|
||||
c.Logger = log.TestingLogger()
|
||||
return c
|
||||
}
|
||||
|
||||
func call(t *testing.T, method string, c *WSClient) {
|
||||
err := c.Call(context.Background(), method, make(map[string]interface{}))
|
||||
require.NoError(t, err)
|
||||
func call(ctx context.Context, t *testing.T, method string, c *WSClient) {
|
||||
t.Helper()
|
||||
|
||||
err := c.Call(ctx, method, make(map[string]interface{}))
|
||||
if ctx.Err() == nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func callWgDoneOnResult(t *testing.T, c *WSClient, wg *sync.WaitGroup) {
|
||||
func handleResponses(ctx context.Context, t *testing.T, c *WSClient) {
|
||||
t.Helper()
|
||||
|
||||
for {
|
||||
select {
|
||||
case resp := <-c.ResponsesCh:
|
||||
@@ -224,9 +252,9 @@ func callWgDoneOnResult(t *testing.T, c *WSClient, wg *sync.WaitGroup) {
|
||||
return
|
||||
}
|
||||
if resp.Result != nil {
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
case <-c.Quit():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,10 +35,6 @@ const (
|
||||
testVal = "acbd"
|
||||
)
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
)
|
||||
|
||||
type ResultEcho struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
@@ -85,13 +81,16 @@ func EchoDataBytesResult(ctx *rpctypes.Context, v tmbytes.HexBytes) (*ResultEcho
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
setup()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
setup(ctx)
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// launch unix and tcp servers
|
||||
func setup() {
|
||||
func setup(ctx context.Context) {
|
||||
logger := log.MustNewDefaultLogger(log.LogFormatPlain, log.LogLevelInfo, false)
|
||||
|
||||
cmd := exec.Command("rm", "-f", unixSocket)
|
||||
@@ -115,7 +114,7 @@ func setup() {
|
||||
panic(err)
|
||||
}
|
||||
go func() {
|
||||
if err := server.Serve(listener1, mux, tcpLogger, config); err != nil {
|
||||
if err := server.Serve(ctx, listener1, mux, tcpLogger, config); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
@@ -131,7 +130,7 @@ func setup() {
|
||||
panic(err)
|
||||
}
|
||||
go func() {
|
||||
if err := server.Serve(listener2, mux2, unixLogger, config); err != nil {
|
||||
if err := server.Serve(ctx, listener2, mux2, unixLogger, config); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
@@ -140,7 +139,7 @@ func setup() {
|
||||
time.Sleep(time.Second * 2)
|
||||
}
|
||||
|
||||
func echoViaHTTP(cl client.Caller, val string) (string, error) {
|
||||
func echoViaHTTP(ctx context.Context, cl client.Caller, val string) (string, error) {
|
||||
params := map[string]interface{}{
|
||||
"arg": val,
|
||||
}
|
||||
@@ -151,7 +150,7 @@ func echoViaHTTP(cl client.Caller, val string) (string, error) {
|
||||
return result.Value, nil
|
||||
}
|
||||
|
||||
func echoIntViaHTTP(cl client.Caller, val int) (int, error) {
|
||||
func echoIntViaHTTP(ctx context.Context, cl client.Caller, val int) (int, error) {
|
||||
params := map[string]interface{}{
|
||||
"arg": val,
|
||||
}
|
||||
@@ -162,7 +161,7 @@ func echoIntViaHTTP(cl client.Caller, val int) (int, error) {
|
||||
return result.Value, nil
|
||||
}
|
||||
|
||||
func echoBytesViaHTTP(cl client.Caller, bytes []byte) ([]byte, error) {
|
||||
func echoBytesViaHTTP(ctx context.Context, cl client.Caller, bytes []byte) ([]byte, error) {
|
||||
params := map[string]interface{}{
|
||||
"arg": bytes,
|
||||
}
|
||||
@@ -173,7 +172,7 @@ func echoBytesViaHTTP(cl client.Caller, bytes []byte) ([]byte, error) {
|
||||
return result.Value, nil
|
||||
}
|
||||
|
||||
func echoDataBytesViaHTTP(cl client.Caller, bytes tmbytes.HexBytes) (tmbytes.HexBytes, error) {
|
||||
func echoDataBytesViaHTTP(ctx context.Context, cl client.Caller, bytes tmbytes.HexBytes) (tmbytes.HexBytes, error) {
|
||||
params := map[string]interface{}{
|
||||
"arg": bytes,
|
||||
}
|
||||
@@ -184,24 +183,24 @@ func echoDataBytesViaHTTP(cl client.Caller, bytes tmbytes.HexBytes) (tmbytes.Hex
|
||||
return result.Value, nil
|
||||
}
|
||||
|
||||
func testWithHTTPClient(t *testing.T, cl client.HTTPClient) {
|
||||
func testWithHTTPClient(ctx context.Context, t *testing.T, cl client.HTTPClient) {
|
||||
val := testVal
|
||||
got, err := echoViaHTTP(cl, val)
|
||||
got, err := echoViaHTTP(ctx, cl, val)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, got, val)
|
||||
|
||||
val2 := randBytes(t)
|
||||
got2, err := echoBytesViaHTTP(cl, val2)
|
||||
got2, err := echoBytesViaHTTP(ctx, cl, val2)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, got2, val2)
|
||||
|
||||
val3 := tmbytes.HexBytes(randBytes(t))
|
||||
got3, err := echoDataBytesViaHTTP(cl, val3)
|
||||
got3, err := echoDataBytesViaHTTP(ctx, cl, val3)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, got3, val3)
|
||||
|
||||
val4 := mrand.Intn(10000)
|
||||
got4, err := echoIntViaHTTP(cl, val4)
|
||||
got4, err := echoIntViaHTTP(ctx, cl, val4)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, got4, val4)
|
||||
}
|
||||
@@ -265,55 +264,70 @@ func testWithWSClient(t *testing.T, cl *client.WSClient) {
|
||||
//-------------
|
||||
|
||||
func TestServersAndClientsBasic(t *testing.T) {
|
||||
bctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
serverAddrs := [...]string{tcpAddr, unixAddr}
|
||||
for _, addr := range serverAddrs {
|
||||
cl1, err := client.NewURI(addr)
|
||||
require.Nil(t, err)
|
||||
fmt.Printf("=== testing server on %s using URI client", addr)
|
||||
testWithHTTPClient(t, cl1)
|
||||
t.Run(addr, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(bctx)
|
||||
defer cancel()
|
||||
|
||||
cl2, err := client.New(addr)
|
||||
require.Nil(t, err)
|
||||
fmt.Printf("=== testing server on %s using JSONRPC client", addr)
|
||||
testWithHTTPClient(t, cl2)
|
||||
cl1, err := client.NewURI(addr)
|
||||
require.Nil(t, err)
|
||||
fmt.Printf("=== testing server on %s using URI client", addr)
|
||||
testWithHTTPClient(ctx, t, cl1)
|
||||
|
||||
cl3, err := client.NewWS(addr, websocketEndpoint)
|
||||
require.Nil(t, err)
|
||||
cl3.SetLogger(log.TestingLogger())
|
||||
err = cl3.Start()
|
||||
require.Nil(t, err)
|
||||
fmt.Printf("=== testing server on %s using WS client", addr)
|
||||
testWithWSClient(t, cl3)
|
||||
err = cl3.Stop()
|
||||
require.NoError(t, err)
|
||||
cl2, err := client.New(addr)
|
||||
require.Nil(t, err)
|
||||
fmt.Printf("=== testing server on %s using JSONRPC client", addr)
|
||||
testWithHTTPClient(ctx, t, cl2)
|
||||
|
||||
cl3, err := client.NewWS(addr, websocketEndpoint)
|
||||
require.Nil(t, err)
|
||||
cl3.Logger = log.TestingLogger()
|
||||
err = cl3.Start(ctx)
|
||||
require.Nil(t, err)
|
||||
fmt.Printf("=== testing server on %s using WS client", addr)
|
||||
testWithWSClient(t, cl3)
|
||||
cancel()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHexStringArg(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
cl, err := client.NewURI(tcpAddr)
|
||||
require.Nil(t, err)
|
||||
// should NOT be handled as hex
|
||||
val := "0xabc"
|
||||
got, err := echoViaHTTP(cl, val)
|
||||
got, err := echoViaHTTP(ctx, cl, val)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, got, val)
|
||||
}
|
||||
|
||||
func TestQuotedStringArg(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
cl, err := client.NewURI(tcpAddr)
|
||||
require.Nil(t, err)
|
||||
// should NOT be unquoted
|
||||
val := "\"abc\""
|
||||
got, err := echoViaHTTP(cl, val)
|
||||
got, err := echoViaHTTP(ctx, cl, val)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, got, val)
|
||||
}
|
||||
|
||||
func TestWSNewWSRPCFunc(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
cl, err := client.NewWS(tcpAddr, websocketEndpoint)
|
||||
require.Nil(t, err)
|
||||
cl.SetLogger(log.TestingLogger())
|
||||
err = cl.Start()
|
||||
cl.Logger = log.TestingLogger()
|
||||
err = cl.Start(ctx)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
if err := cl.Stop(); err != nil {
|
||||
@@ -340,11 +354,14 @@ func TestWSNewWSRPCFunc(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWSHandlesArrayParams(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
cl, err := client.NewWS(tcpAddr, websocketEndpoint)
|
||||
require.Nil(t, err)
|
||||
cl.SetLogger(log.TestingLogger())
|
||||
err = cl.Start()
|
||||
require.Nil(t, err)
|
||||
|
||||
cl.Logger = log.TestingLogger()
|
||||
require.Nil(t, cl.Start(ctx))
|
||||
t.Cleanup(func() {
|
||||
if err := cl.Stop(); err != nil {
|
||||
t.Error(err)
|
||||
@@ -370,10 +387,13 @@ func TestWSHandlesArrayParams(t *testing.T) {
|
||||
// TestWSClientPingPong checks that a client & server exchange pings
|
||||
// & pongs so connection stays alive.
|
||||
func TestWSClientPingPong(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
cl, err := client.NewWS(tcpAddr, websocketEndpoint)
|
||||
require.Nil(t, err)
|
||||
cl.SetLogger(log.TestingLogger())
|
||||
err = cl.Start()
|
||||
cl.Logger = log.TestingLogger()
|
||||
err = cl.Start(ctx)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
if err := cl.Stop(); err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -50,7 +51,13 @@ func DefaultConfig() *Config {
|
||||
// body size to config.MaxBodyBytes.
|
||||
//
|
||||
// NOTE: This function blocks - you may want to call it in a go-routine.
|
||||
func Serve(listener net.Listener, handler http.Handler, logger log.Logger, config *Config) error {
|
||||
func Serve(
|
||||
ctx context.Context,
|
||||
listener net.Listener,
|
||||
handler http.Handler,
|
||||
logger log.Logger,
|
||||
config *Config,
|
||||
) error {
|
||||
logger.Info(fmt.Sprintf("Starting RPC HTTP server on %s", listener.Addr()))
|
||||
s := &http.Server{
|
||||
Handler: RecoverAndLogHandler(maxBytesHandler{h: handler, n: config.MaxBodyBytes}, logger),
|
||||
@@ -58,9 +65,23 @@ func Serve(listener net.Listener, handler http.Handler, logger log.Logger, confi
|
||||
WriteTimeout: config.WriteTimeout,
|
||||
MaxHeaderBytes: config.MaxHeaderBytes,
|
||||
}
|
||||
err := s.Serve(listener)
|
||||
logger.Info("RPC HTTP server stopped", "err", err)
|
||||
return err
|
||||
sig := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
sctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_ = s.Shutdown(sctx)
|
||||
case <-sig:
|
||||
}
|
||||
}()
|
||||
|
||||
if err := s.Serve(listener); err != nil {
|
||||
logger.Info("RPC HTTP server stopped", "err", err)
|
||||
close(sig)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve creates a http.Server and calls ServeTLS with the given listener,
|
||||
@@ -69,6 +90,7 @@ func Serve(listener net.Listener, handler http.Handler, logger log.Logger, confi
|
||||
//
|
||||
// NOTE: This function blocks - you may want to call it in a go-routine.
|
||||
func ServeTLS(
|
||||
ctx context.Context,
|
||||
listener net.Listener,
|
||||
handler http.Handler,
|
||||
certFile, keyFile string,
|
||||
@@ -83,10 +105,23 @@ func ServeTLS(
|
||||
WriteTimeout: config.WriteTimeout,
|
||||
MaxHeaderBytes: config.MaxHeaderBytes,
|
||||
}
|
||||
err := s.ServeTLS(listener, certFile, keyFile)
|
||||
sig := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
sctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_ = s.Shutdown(sctx)
|
||||
case <-sig:
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Error("RPC HTTPS server stopped", "err", err)
|
||||
return err
|
||||
if err := s.ServeTLS(listener, certFile, keyFile); err != nil {
|
||||
logger.Error("RPC HTTPS server stopped", "err", err)
|
||||
close(sig)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteRPCResponseHTTPError marshals res as JSON (with indent) and writes it
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -27,6 +28,9 @@ type sampleResult struct {
|
||||
func TestMaxOpenConnections(t *testing.T) {
|
||||
const max = 5 // max simultaneous connections
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Start the server.
|
||||
var open int32
|
||||
mux := http.NewServeMux()
|
||||
@@ -42,7 +46,7 @@ func TestMaxOpenConnections(t *testing.T) {
|
||||
l, err := Listen("tcp://127.0.0.1:0", max)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
go Serve(l, mux, log.TestingLogger(), config) //nolint:errcheck // ignore for tests
|
||||
go Serve(ctx, l, mux, log.TestingLogger(), config) //nolint:errcheck // ignore for tests
|
||||
|
||||
// Make N GET calls to the server.
|
||||
attempts := max * 2
|
||||
@@ -80,10 +84,12 @@ func TestServeTLS(t *testing.T) {
|
||||
fmt.Fprint(w, "some body")
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
chErr := make(chan error, 1)
|
||||
go func() {
|
||||
// FIXME This goroutine leaks
|
||||
chErr <- ServeTLS(ln, mux, "test.crt", "test.key", log.TestingLogger(), DefaultConfig())
|
||||
chErr <- ServeTLS(ctx, ln, mux, "test.crt", "test.key", log.TestingLogger(), DefaultConfig())
|
||||
}()
|
||||
|
||||
select {
|
||||
|
||||
@@ -87,14 +87,16 @@ func (wm *WebsocketManager) WebsocketHandler(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
// register connection
|
||||
logger := wm.logger.With("remote", wsConn.RemoteAddr())
|
||||
con := newWSConnection(wsConn, wm.funcMap, logger, wm.wsConnOptions...)
|
||||
wm.logger.Info("New websocket connection", "remote", con.remoteAddr)
|
||||
err = con.Start() // BLOCKING
|
||||
if err != nil {
|
||||
conn := newWSConnection(wsConn, wm.funcMap, logger, wm.wsConnOptions...)
|
||||
wm.logger.Info("New websocket connection", "remote", conn.remoteAddr)
|
||||
|
||||
// starting the conn is blocking
|
||||
if err = conn.Start(r.Context()); err != nil {
|
||||
wm.logger.Error("Failed to start connection", "err", err)
|
||||
return
|
||||
}
|
||||
if err := con.Stop(); err != nil {
|
||||
|
||||
if err := conn.Stop(); err != nil {
|
||||
wm.logger.Error("error while stopping connection", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -220,16 +222,16 @@ func ReadLimit(readLimit int64) func(*wsConnection) {
|
||||
}
|
||||
|
||||
// Start starts the client service routines and blocks until there is an error.
|
||||
func (wsc *wsConnection) Start() error {
|
||||
if err := wsc.RunState.Start(); err != nil {
|
||||
func (wsc *wsConnection) Start(ctx context.Context) error {
|
||||
if err := wsc.RunState.Start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
wsc.writeChan = make(chan rpctypes.RPCResponse, wsc.writeChanCapacity)
|
||||
|
||||
// Read subscriptions/unsubscriptions to events
|
||||
go wsc.readRoutine()
|
||||
go wsc.readRoutine(ctx)
|
||||
// Write responses, BLOCKING.
|
||||
wsc.writeRoutine()
|
||||
wsc.writeRoutine(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -259,8 +261,6 @@ func (wsc *wsConnection) GetRemoteAddr() string {
|
||||
// It implements WSRPCConnection. It is Goroutine-safe.
|
||||
func (wsc *wsConnection) WriteRPCResponse(ctx context.Context, resp rpctypes.RPCResponse) error {
|
||||
select {
|
||||
case <-wsc.Quit():
|
||||
return errors.New("connection was stopped")
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case wsc.writeChan <- resp:
|
||||
@@ -271,9 +271,9 @@ func (wsc *wsConnection) WriteRPCResponse(ctx context.Context, resp rpctypes.RPC
|
||||
// TryWriteRPCResponse attempts to push a response to the writeChan, but does
|
||||
// not block.
|
||||
// It implements WSRPCConnection. It is Goroutine-safe
|
||||
func (wsc *wsConnection) TryWriteRPCResponse(resp rpctypes.RPCResponse) bool {
|
||||
func (wsc *wsConnection) TryWriteRPCResponse(ctx context.Context, resp rpctypes.RPCResponse) bool {
|
||||
select {
|
||||
case <-wsc.Quit():
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case wsc.writeChan <- resp:
|
||||
return true
|
||||
@@ -293,7 +293,7 @@ func (wsc *wsConnection) Context() context.Context {
|
||||
}
|
||||
|
||||
// Read from the socket and subscribe to or unsubscribe from events
|
||||
func (wsc *wsConnection) readRoutine() {
|
||||
func (wsc *wsConnection) readRoutine(ctx context.Context) {
|
||||
// readRoutine will block until response is written or WS connection is closed
|
||||
writeCtx := context.Background()
|
||||
|
||||
@@ -307,7 +307,7 @@ func (wsc *wsConnection) readRoutine() {
|
||||
if err := wsc.WriteRPCResponse(writeCtx, rpctypes.RPCInternalError(rpctypes.JSONRPCIntID(-1), err)); err != nil {
|
||||
wsc.Logger.Error("Error writing RPC response", "err", err)
|
||||
}
|
||||
go wsc.readRoutine()
|
||||
go wsc.readRoutine(ctx)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -317,7 +317,7 @@ func (wsc *wsConnection) readRoutine() {
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-wsc.Quit():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
// reset deadline for every type of message (control or data)
|
||||
@@ -422,7 +422,7 @@ func (wsc *wsConnection) readRoutine() {
|
||||
}
|
||||
|
||||
// receives on a write channel and writes out on the socket
|
||||
func (wsc *wsConnection) writeRoutine() {
|
||||
func (wsc *wsConnection) writeRoutine(ctx context.Context) {
|
||||
pingTicker := time.NewTicker(wsc.pingPeriod)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
@@ -438,7 +438,7 @@ func (wsc *wsConnection) writeRoutine() {
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-wsc.Quit():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-wsc.readRoutineQuit: // error in readRoutine
|
||||
return
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -29,8 +30,11 @@ func main() {
|
||||
logger = log.MustNewDefaultLogger(log.LogFormatPlain, log.LogLevelInfo, false)
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Stop upon receiving SIGTERM or CTRL-C.
|
||||
tmos.TrapSignal(logger, func() {})
|
||||
tmos.TrapSignal(ctx, logger, func() {})
|
||||
|
||||
rpcserver.RegisterRPCFuncs(mux, routes, logger)
|
||||
config := rpcserver.DefaultConfig()
|
||||
@@ -40,7 +44,7 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err = rpcserver.Serve(listener, mux, logger, config); err != nil {
|
||||
if err = rpcserver.Serve(ctx, listener, mux, logger, config); err != nil {
|
||||
logger.Error("rpc serve", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -253,7 +253,7 @@ type WSRPCConnection interface {
|
||||
// WriteRPCResponse writes the response onto connection (BLOCKING).
|
||||
WriteRPCResponse(context.Context, RPCResponse) error
|
||||
// TryWriteRPCResponse tries to write the response onto connection (NON-BLOCKING).
|
||||
TryWriteRPCResponse(RPCResponse) bool
|
||||
TryWriteRPCResponse(context.Context, RPCResponse) bool
|
||||
// Context returns the connection's context.
|
||||
Context() context.Context
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user