service: plumb contexts to all (most) threads (#7363)

This continues the push of plumbing contexts through tendermint. I
attempted to find all goroutines in the production code (non-test) and
made sure that these threads would exit when their contexts were
canceled, and I believe this PR does that.
This commit is contained in:
Sam Kleinman
2021-12-02 16:38:38 -05:00
committed by GitHub
parent b3be1d7d7a
commit 8a991e288c
64 changed files with 964 additions and 727 deletions

View File

@@ -24,6 +24,7 @@ type WSOptions struct {
ReadWait time.Duration // deadline for any read op
WriteWait time.Duration // deadline for any write op
PingPeriod time.Duration // frequency with which pings are sent
SkipMetrics bool // do not keep metrics for ping/pong latency
}
// DefaultWSOptions returns default WS options.
@@ -117,8 +118,6 @@ func NewWSWithOptions(remoteAddr, endpoint string, opts WSOptions) (*WSClient, e
Address: parsedURL.GetTrimmedHostWithPath(),
Dialer: dialFn,
Endpoint: endpoint,
PingPongLatencyTimer: metrics.NewTimer(),
maxReconnectAttempts: opts.MaxReconnectAttempts,
readWait: opts.ReadWait,
writeWait: opts.WriteWait,
@@ -127,6 +126,14 @@ func NewWSWithOptions(remoteAddr, endpoint string, opts WSOptions) (*WSClient, e
// sentIDs: make(map[types.JSONRPCIntID]bool),
}
switch opts.SkipMetrics {
case true:
c.PingPongLatencyTimer = metrics.NilTimer{}
case false:
c.PingPongLatencyTimer = metrics.NewTimer()
}
return c, nil
}
@@ -143,8 +150,8 @@ func (c *WSClient) String() string {
}
// Start dials the specified service address and starts the I/O routines.
func (c *WSClient) Start() error {
if err := c.RunState.Start(); err != nil {
func (c *WSClient) Start(ctx context.Context) error {
if err := c.RunState.Start(ctx); err != nil {
return err
}
err := c.dial()
@@ -162,8 +169,8 @@ func (c *WSClient) Start() error {
// channel is unbuffered.
c.backlog = make(chan rpctypes.RPCRequest, 1)
c.startReadWriteRoutines()
go c.reconnectRoutine()
c.startReadWriteRoutines(ctx)
go c.reconnectRoutine(ctx)
return nil
}
@@ -173,6 +180,7 @@ func (c *WSClient) Stop() error {
if err := c.RunState.Stop(); err != nil {
return err
}
// only close user-facing channels when we can't write to them
c.wg.Wait()
close(c.ResponsesCh)
@@ -253,7 +261,7 @@ func (c *WSClient) dial() error {
// reconnect tries to redial up to maxReconnectAttempts with exponential
// backoff.
func (c *WSClient) reconnect() error {
func (c *WSClient) reconnect(ctx context.Context) error {
attempt := uint(0)
c.mtx.Lock()
@@ -265,13 +273,21 @@ func (c *WSClient) reconnect() error {
c.mtx.Unlock()
}()
timer := time.NewTimer(0)
defer timer.Stop()
for {
// nolint:gosec // G404: Use of weak random number generator
jitter := time.Duration(mrand.Float64() * float64(time.Second)) // 1s == (1e9 ns)
backoffDuration := jitter + ((1 << attempt) * time.Second)
c.Logger.Info("reconnecting", "attempt", attempt+1, "backoff_duration", backoffDuration)
time.Sleep(backoffDuration)
timer.Reset(backoffDuration)
select {
case <-ctx.Done():
return nil
case <-timer.C:
}
err := c.dial()
if err != nil {
@@ -292,11 +308,11 @@ func (c *WSClient) reconnect() error {
}
}
func (c *WSClient) startReadWriteRoutines() {
func (c *WSClient) startReadWriteRoutines(ctx context.Context) {
c.wg.Add(2)
c.readRoutineQuit = make(chan struct{})
go c.readRoutine()
go c.writeRoutine()
go c.readRoutine(ctx)
go c.writeRoutine(ctx)
}
func (c *WSClient) processBacklog() error {
@@ -320,13 +336,15 @@ func (c *WSClient) processBacklog() error {
return nil
}
func (c *WSClient) reconnectRoutine() {
func (c *WSClient) reconnectRoutine(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case originalError := <-c.reconnectAfter:
// wait until writeRoutine and readRoutine finish
c.wg.Wait()
if err := c.reconnect(); err != nil {
if err := c.reconnect(ctx); err != nil {
c.Logger.Error("failed to reconnect", "err", err, "original_err", originalError)
if err = c.Stop(); err != nil {
c.Logger.Error("failed to stop conn", "error", err)
@@ -338,6 +356,8 @@ func (c *WSClient) reconnectRoutine() {
LOOP:
for {
select {
case <-ctx.Done():
return
case <-c.reconnectAfter:
default:
break LOOP
@@ -345,18 +365,15 @@ func (c *WSClient) reconnectRoutine() {
}
err := c.processBacklog()
if err == nil {
c.startReadWriteRoutines()
c.startReadWriteRoutines(ctx)
}
case <-c.Quit():
return
}
}
}
// The client ensures that there is at most one writer to a connection by
// executing all writes from this goroutine.
func (c *WSClient) writeRoutine() {
func (c *WSClient) writeRoutine(ctx context.Context) {
var ticker *time.Ticker
if c.pingPeriod > 0 {
// ticker with a predefined period
@@ -408,7 +425,7 @@ func (c *WSClient) writeRoutine() {
c.Logger.Debug("sent ping")
case <-c.readRoutineQuit:
return
case <-c.Quit():
case <-ctx.Done():
if err := c.conn.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
@@ -422,7 +439,7 @@ func (c *WSClient) writeRoutine() {
// The client ensures that there is at most one reader to a connection by
// executing all reads from this goroutine.
func (c *WSClient) readRoutine() {
func (c *WSClient) readRoutine(ctx context.Context) {
defer func() {
c.conn.Close()
// err != nil {
@@ -494,7 +511,8 @@ func (c *WSClient) readRoutine() {
c.Logger.Info("got response", "id", response.ID, "result", response.Result)
select {
case <-c.Quit():
case <-ctx.Done():
return
case c.ResponsesCh <- response:
}
}

View File

@@ -5,10 +5,11 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"runtime"
"testing"
"time"
"github.com/fortytw2/leaktest"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
@@ -64,25 +65,26 @@ func (h *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
func TestWSClientReconnectsAfterReadFailure(t *testing.T) {
var wg sync.WaitGroup
t.Cleanup(leaktest.Check(t))
// start server
h := &myHandler{}
s := httptest.NewServer(h)
defer s.Close()
c := startClient(t, "//"+s.Listener.Addr().String())
defer c.Stop() // nolint:errcheck // ignore for tests
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
wg.Add(1)
go callWgDoneOnResult(t, c, &wg)
c := startClient(ctx, t, "//"+s.Listener.Addr().String())
go handleResponses(ctx, t, c)
h.mtx.Lock()
h.closeConnAfterRead = true
h.mtx.Unlock()
// results in WS read error, no send retry because write succeeded
call(t, "a", c)
call(ctx, t, "a", c)
// expect to reconnect almost immediately
time.Sleep(10 * time.Millisecond)
@@ -91,23 +93,23 @@ func TestWSClientReconnectsAfterReadFailure(t *testing.T) {
h.mtx.Unlock()
// should succeed
call(t, "b", c)
wg.Wait()
call(ctx, t, "b", c)
}
func TestWSClientReconnectsAfterWriteFailure(t *testing.T) {
var wg sync.WaitGroup
t.Cleanup(leaktest.Check(t))
// start server
h := &myHandler{}
s := httptest.NewServer(h)
defer s.Close()
c := startClient(t, "//"+s.Listener.Addr().String())
defer c.Stop() // nolint:errcheck // ignore for tests
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wg.Add(2)
go callWgDoneOnResult(t, c, &wg)
c := startClient(ctx, t, "//"+s.Listener.Addr().String())
go handleResponses(ctx, t, c)
// hacky way to abort the connection before write
if err := c.conn.Close(); err != nil {
@@ -115,30 +117,32 @@ func TestWSClientReconnectsAfterWriteFailure(t *testing.T) {
}
// results in WS write error, the client should resend on reconnect
call(t, "a", c)
call(ctx, t, "a", c)
// expect to reconnect almost immediately
time.Sleep(10 * time.Millisecond)
// should succeed
call(t, "b", c)
wg.Wait()
call(ctx, t, "b", c)
}
func TestWSClientReconnectFailure(t *testing.T) {
t.Cleanup(leaktest.Check(t))
// start server
h := &myHandler{}
s := httptest.NewServer(h)
c := startClient(t, "//"+s.Listener.Addr().String())
defer c.Stop() // nolint:errcheck // ignore for tests
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := startClient(ctx, t, "//"+s.Listener.Addr().String())
go func() {
for {
select {
case <-c.ResponsesCh:
case <-c.Quit():
case <-ctx.Done():
return
}
}
@@ -152,9 +156,9 @@ func TestWSClientReconnectFailure(t *testing.T) {
// results in WS write error
// provide timeout to avoid blocking
ctx, cancel := context.WithTimeout(context.Background(), wsCallTimeout)
cctx, cancel := context.WithTimeout(ctx, wsCallTimeout)
defer cancel()
if err := c.Call(ctx, "a", make(map[string]interface{})); err != nil {
if err := c.Call(cctx, "a", make(map[string]interface{})); err != nil {
t.Error(err)
}
@@ -164,7 +168,7 @@ func TestWSClientReconnectFailure(t *testing.T) {
done := make(chan struct{})
go func() {
// client should block on this
call(t, "b", c)
call(ctx, t, "b", c)
close(done)
}()
@@ -178,44 +182,68 @@ func TestWSClientReconnectFailure(t *testing.T) {
}
func TestNotBlockingOnStop(t *testing.T) {
timeout := 2 * time.Second
t.Cleanup(leaktest.Check(t))
timeout := 3 * time.Second
s := httptest.NewServer(&myHandler{})
c := startClient(t, "//"+s.Listener.Addr().String())
c.Call(context.Background(), "a", make(map[string]interface{})) // nolint:errcheck // ignore for tests
defer s.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := startClient(ctx, t, "//"+s.Listener.Addr().String())
c.Call(ctx, "a", make(map[string]interface{})) // nolint:errcheck // ignore for tests
// Let the readRoutine get around to blocking
time.Sleep(time.Second)
passCh := make(chan struct{})
go func() {
// Unless we have a non-blocking write to ResponsesCh from readRoutine
// this blocks forever ont the waitgroup
err := c.Stop()
require.NoError(t, err)
passCh <- struct{}{}
cancel()
require.NoError(t, c.Stop())
select {
case <-ctx.Done():
case passCh <- struct{}{}:
}
}()
runtime.Gosched() // hacks: force context switch
select {
case <-passCh:
// Pass
case <-time.After(timeout):
t.Fatalf("WSClient did failed to stop within %v seconds - is one of the read/write routines blocking?",
timeout.Seconds())
if c.IsRunning() {
t.Fatalf("WSClient did failed to stop within %v seconds - is one of the read/write routines blocking?",
timeout.Seconds())
}
}
}
func startClient(t *testing.T, addr string) *WSClient {
c, err := NewWS(addr, "/websocket")
func startClient(ctx context.Context, t *testing.T, addr string) *WSClient {
t.Helper()
opts := DefaultWSOptions()
opts.SkipMetrics = true
c, err := NewWSWithOptions(addr, "/websocket", opts)
require.Nil(t, err)
err = c.Start()
err = c.Start(ctx)
require.Nil(t, err)
c.SetLogger(log.TestingLogger())
c.Logger = log.TestingLogger()
return c
}
func call(t *testing.T, method string, c *WSClient) {
err := c.Call(context.Background(), method, make(map[string]interface{}))
require.NoError(t, err)
func call(ctx context.Context, t *testing.T, method string, c *WSClient) {
t.Helper()
err := c.Call(ctx, method, make(map[string]interface{}))
if ctx.Err() == nil {
require.NoError(t, err)
}
}
func callWgDoneOnResult(t *testing.T, c *WSClient, wg *sync.WaitGroup) {
func handleResponses(ctx context.Context, t *testing.T, c *WSClient) {
t.Helper()
for {
select {
case resp := <-c.ResponsesCh:
@@ -224,9 +252,9 @@ func callWgDoneOnResult(t *testing.T, c *WSClient, wg *sync.WaitGroup) {
return
}
if resp.Result != nil {
wg.Done()
return
}
case <-c.Quit():
case <-ctx.Done():
return
}
}

View File

@@ -35,10 +35,6 @@ const (
testVal = "acbd"
)
var (
ctx = context.Background()
)
type ResultEcho struct {
Value string `json:"value"`
}
@@ -85,13 +81,16 @@ func EchoDataBytesResult(ctx *rpctypes.Context, v tmbytes.HexBytes) (*ResultEcho
}
func TestMain(m *testing.M) {
setup()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
setup(ctx)
code := m.Run()
os.Exit(code)
}
// launch unix and tcp servers
func setup() {
func setup(ctx context.Context) {
logger := log.MustNewDefaultLogger(log.LogFormatPlain, log.LogLevelInfo, false)
cmd := exec.Command("rm", "-f", unixSocket)
@@ -115,7 +114,7 @@ func setup() {
panic(err)
}
go func() {
if err := server.Serve(listener1, mux, tcpLogger, config); err != nil {
if err := server.Serve(ctx, listener1, mux, tcpLogger, config); err != nil {
panic(err)
}
}()
@@ -131,7 +130,7 @@ func setup() {
panic(err)
}
go func() {
if err := server.Serve(listener2, mux2, unixLogger, config); err != nil {
if err := server.Serve(ctx, listener2, mux2, unixLogger, config); err != nil {
panic(err)
}
}()
@@ -140,7 +139,7 @@ func setup() {
time.Sleep(time.Second * 2)
}
func echoViaHTTP(cl client.Caller, val string) (string, error) {
func echoViaHTTP(ctx context.Context, cl client.Caller, val string) (string, error) {
params := map[string]interface{}{
"arg": val,
}
@@ -151,7 +150,7 @@ func echoViaHTTP(cl client.Caller, val string) (string, error) {
return result.Value, nil
}
func echoIntViaHTTP(cl client.Caller, val int) (int, error) {
func echoIntViaHTTP(ctx context.Context, cl client.Caller, val int) (int, error) {
params := map[string]interface{}{
"arg": val,
}
@@ -162,7 +161,7 @@ func echoIntViaHTTP(cl client.Caller, val int) (int, error) {
return result.Value, nil
}
func echoBytesViaHTTP(cl client.Caller, bytes []byte) ([]byte, error) {
func echoBytesViaHTTP(ctx context.Context, cl client.Caller, bytes []byte) ([]byte, error) {
params := map[string]interface{}{
"arg": bytes,
}
@@ -173,7 +172,7 @@ func echoBytesViaHTTP(cl client.Caller, bytes []byte) ([]byte, error) {
return result.Value, nil
}
func echoDataBytesViaHTTP(cl client.Caller, bytes tmbytes.HexBytes) (tmbytes.HexBytes, error) {
func echoDataBytesViaHTTP(ctx context.Context, cl client.Caller, bytes tmbytes.HexBytes) (tmbytes.HexBytes, error) {
params := map[string]interface{}{
"arg": bytes,
}
@@ -184,24 +183,24 @@ func echoDataBytesViaHTTP(cl client.Caller, bytes tmbytes.HexBytes) (tmbytes.Hex
return result.Value, nil
}
func testWithHTTPClient(t *testing.T, cl client.HTTPClient) {
func testWithHTTPClient(ctx context.Context, t *testing.T, cl client.HTTPClient) {
val := testVal
got, err := echoViaHTTP(cl, val)
got, err := echoViaHTTP(ctx, cl, val)
require.Nil(t, err)
assert.Equal(t, got, val)
val2 := randBytes(t)
got2, err := echoBytesViaHTTP(cl, val2)
got2, err := echoBytesViaHTTP(ctx, cl, val2)
require.Nil(t, err)
assert.Equal(t, got2, val2)
val3 := tmbytes.HexBytes(randBytes(t))
got3, err := echoDataBytesViaHTTP(cl, val3)
got3, err := echoDataBytesViaHTTP(ctx, cl, val3)
require.Nil(t, err)
assert.Equal(t, got3, val3)
val4 := mrand.Intn(10000)
got4, err := echoIntViaHTTP(cl, val4)
got4, err := echoIntViaHTTP(ctx, cl, val4)
require.Nil(t, err)
assert.Equal(t, got4, val4)
}
@@ -265,55 +264,70 @@ func testWithWSClient(t *testing.T, cl *client.WSClient) {
//-------------
func TestServersAndClientsBasic(t *testing.T) {
bctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverAddrs := [...]string{tcpAddr, unixAddr}
for _, addr := range serverAddrs {
cl1, err := client.NewURI(addr)
require.Nil(t, err)
fmt.Printf("=== testing server on %s using URI client", addr)
testWithHTTPClient(t, cl1)
t.Run(addr, func(t *testing.T) {
ctx, cancel := context.WithCancel(bctx)
defer cancel()
cl2, err := client.New(addr)
require.Nil(t, err)
fmt.Printf("=== testing server on %s using JSONRPC client", addr)
testWithHTTPClient(t, cl2)
cl1, err := client.NewURI(addr)
require.Nil(t, err)
fmt.Printf("=== testing server on %s using URI client", addr)
testWithHTTPClient(ctx, t, cl1)
cl3, err := client.NewWS(addr, websocketEndpoint)
require.Nil(t, err)
cl3.SetLogger(log.TestingLogger())
err = cl3.Start()
require.Nil(t, err)
fmt.Printf("=== testing server on %s using WS client", addr)
testWithWSClient(t, cl3)
err = cl3.Stop()
require.NoError(t, err)
cl2, err := client.New(addr)
require.Nil(t, err)
fmt.Printf("=== testing server on %s using JSONRPC client", addr)
testWithHTTPClient(ctx, t, cl2)
cl3, err := client.NewWS(addr, websocketEndpoint)
require.Nil(t, err)
cl3.Logger = log.TestingLogger()
err = cl3.Start(ctx)
require.Nil(t, err)
fmt.Printf("=== testing server on %s using WS client", addr)
testWithWSClient(t, cl3)
cancel()
})
}
}
func TestHexStringArg(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cl, err := client.NewURI(tcpAddr)
require.Nil(t, err)
// should NOT be handled as hex
val := "0xabc"
got, err := echoViaHTTP(cl, val)
got, err := echoViaHTTP(ctx, cl, val)
require.Nil(t, err)
assert.Equal(t, got, val)
}
func TestQuotedStringArg(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cl, err := client.NewURI(tcpAddr)
require.Nil(t, err)
// should NOT be unquoted
val := "\"abc\""
got, err := echoViaHTTP(cl, val)
got, err := echoViaHTTP(ctx, cl, val)
require.Nil(t, err)
assert.Equal(t, got, val)
}
func TestWSNewWSRPCFunc(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cl, err := client.NewWS(tcpAddr, websocketEndpoint)
require.Nil(t, err)
cl.SetLogger(log.TestingLogger())
err = cl.Start()
cl.Logger = log.TestingLogger()
err = cl.Start(ctx)
require.Nil(t, err)
t.Cleanup(func() {
if err := cl.Stop(); err != nil {
@@ -340,11 +354,14 @@ func TestWSNewWSRPCFunc(t *testing.T) {
}
func TestWSHandlesArrayParams(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cl, err := client.NewWS(tcpAddr, websocketEndpoint)
require.Nil(t, err)
cl.SetLogger(log.TestingLogger())
err = cl.Start()
require.Nil(t, err)
cl.Logger = log.TestingLogger()
require.Nil(t, cl.Start(ctx))
t.Cleanup(func() {
if err := cl.Stop(); err != nil {
t.Error(err)
@@ -370,10 +387,13 @@ func TestWSHandlesArrayParams(t *testing.T) {
// TestWSClientPingPong checks that a client & server exchange pings
// & pongs so connection stays alive.
func TestWSClientPingPong(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cl, err := client.NewWS(tcpAddr, websocketEndpoint)
require.Nil(t, err)
cl.SetLogger(log.TestingLogger())
err = cl.Start()
cl.Logger = log.TestingLogger()
err = cl.Start(ctx)
require.Nil(t, err)
t.Cleanup(func() {
if err := cl.Stop(); err != nil {

View File

@@ -3,6 +3,7 @@ package server
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
@@ -50,7 +51,13 @@ func DefaultConfig() *Config {
// body size to config.MaxBodyBytes.
//
// NOTE: This function blocks - you may want to call it in a go-routine.
func Serve(listener net.Listener, handler http.Handler, logger log.Logger, config *Config) error {
func Serve(
ctx context.Context,
listener net.Listener,
handler http.Handler,
logger log.Logger,
config *Config,
) error {
logger.Info(fmt.Sprintf("Starting RPC HTTP server on %s", listener.Addr()))
s := &http.Server{
Handler: RecoverAndLogHandler(maxBytesHandler{h: handler, n: config.MaxBodyBytes}, logger),
@@ -58,9 +65,23 @@ func Serve(listener net.Listener, handler http.Handler, logger log.Logger, confi
WriteTimeout: config.WriteTimeout,
MaxHeaderBytes: config.MaxHeaderBytes,
}
err := s.Serve(listener)
logger.Info("RPC HTTP server stopped", "err", err)
return err
sig := make(chan struct{})
go func() {
select {
case <-ctx.Done():
sctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_ = s.Shutdown(sctx)
case <-sig:
}
}()
if err := s.Serve(listener); err != nil {
logger.Info("RPC HTTP server stopped", "err", err)
close(sig)
return err
}
return nil
}
// Serve creates a http.Server and calls ServeTLS with the given listener,
@@ -69,6 +90,7 @@ func Serve(listener net.Listener, handler http.Handler, logger log.Logger, confi
//
// NOTE: This function blocks - you may want to call it in a go-routine.
func ServeTLS(
ctx context.Context,
listener net.Listener,
handler http.Handler,
certFile, keyFile string,
@@ -83,10 +105,23 @@ func ServeTLS(
WriteTimeout: config.WriteTimeout,
MaxHeaderBytes: config.MaxHeaderBytes,
}
err := s.ServeTLS(listener, certFile, keyFile)
sig := make(chan struct{})
go func() {
select {
case <-ctx.Done():
sctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_ = s.Shutdown(sctx)
case <-sig:
}
}()
logger.Error("RPC HTTPS server stopped", "err", err)
return err
if err := s.ServeTLS(listener, certFile, keyFile); err != nil {
logger.Error("RPC HTTPS server stopped", "err", err)
close(sig)
return err
}
return nil
}
// WriteRPCResponseHTTPError marshals res as JSON (with indent) and writes it

View File

@@ -1,6 +1,7 @@
package server
import (
"context"
"crypto/tls"
"errors"
"fmt"
@@ -27,6 +28,9 @@ type sampleResult struct {
func TestMaxOpenConnections(t *testing.T) {
const max = 5 // max simultaneous connections
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start the server.
var open int32
mux := http.NewServeMux()
@@ -42,7 +46,7 @@ func TestMaxOpenConnections(t *testing.T) {
l, err := Listen("tcp://127.0.0.1:0", max)
require.NoError(t, err)
defer l.Close()
go Serve(l, mux, log.TestingLogger(), config) //nolint:errcheck // ignore for tests
go Serve(ctx, l, mux, log.TestingLogger(), config) //nolint:errcheck // ignore for tests
// Make N GET calls to the server.
attempts := max * 2
@@ -80,10 +84,12 @@ func TestServeTLS(t *testing.T) {
fmt.Fprint(w, "some body")
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
chErr := make(chan error, 1)
go func() {
// FIXME This goroutine leaks
chErr <- ServeTLS(ln, mux, "test.crt", "test.key", log.TestingLogger(), DefaultConfig())
chErr <- ServeTLS(ctx, ln, mux, "test.crt", "test.key", log.TestingLogger(), DefaultConfig())
}()
select {

View File

@@ -87,14 +87,16 @@ func (wm *WebsocketManager) WebsocketHandler(w http.ResponseWriter, r *http.Requ
// register connection
logger := wm.logger.With("remote", wsConn.RemoteAddr())
con := newWSConnection(wsConn, wm.funcMap, logger, wm.wsConnOptions...)
wm.logger.Info("New websocket connection", "remote", con.remoteAddr)
err = con.Start() // BLOCKING
if err != nil {
conn := newWSConnection(wsConn, wm.funcMap, logger, wm.wsConnOptions...)
wm.logger.Info("New websocket connection", "remote", conn.remoteAddr)
// starting the conn is blocking
if err = conn.Start(r.Context()); err != nil {
wm.logger.Error("Failed to start connection", "err", err)
return
}
if err := con.Stop(); err != nil {
if err := conn.Stop(); err != nil {
wm.logger.Error("error while stopping connection", "error", err)
}
}
@@ -220,16 +222,16 @@ func ReadLimit(readLimit int64) func(*wsConnection) {
}
// Start starts the client service routines and blocks until there is an error.
func (wsc *wsConnection) Start() error {
if err := wsc.RunState.Start(); err != nil {
func (wsc *wsConnection) Start(ctx context.Context) error {
if err := wsc.RunState.Start(ctx); err != nil {
return err
}
wsc.writeChan = make(chan rpctypes.RPCResponse, wsc.writeChanCapacity)
// Read subscriptions/unsubscriptions to events
go wsc.readRoutine()
go wsc.readRoutine(ctx)
// Write responses, BLOCKING.
wsc.writeRoutine()
wsc.writeRoutine(ctx)
return nil
}
@@ -259,8 +261,6 @@ func (wsc *wsConnection) GetRemoteAddr() string {
// It implements WSRPCConnection. It is Goroutine-safe.
func (wsc *wsConnection) WriteRPCResponse(ctx context.Context, resp rpctypes.RPCResponse) error {
select {
case <-wsc.Quit():
return errors.New("connection was stopped")
case <-ctx.Done():
return ctx.Err()
case wsc.writeChan <- resp:
@@ -271,9 +271,9 @@ func (wsc *wsConnection) WriteRPCResponse(ctx context.Context, resp rpctypes.RPC
// TryWriteRPCResponse attempts to push a response to the writeChan, but does
// not block.
// It implements WSRPCConnection. It is Goroutine-safe
func (wsc *wsConnection) TryWriteRPCResponse(resp rpctypes.RPCResponse) bool {
func (wsc *wsConnection) TryWriteRPCResponse(ctx context.Context, resp rpctypes.RPCResponse) bool {
select {
case <-wsc.Quit():
case <-ctx.Done():
return false
case wsc.writeChan <- resp:
return true
@@ -293,7 +293,7 @@ func (wsc *wsConnection) Context() context.Context {
}
// Read from the socket and subscribe to or unsubscribe from events
func (wsc *wsConnection) readRoutine() {
func (wsc *wsConnection) readRoutine(ctx context.Context) {
// readRoutine will block until response is written or WS connection is closed
writeCtx := context.Background()
@@ -307,7 +307,7 @@ func (wsc *wsConnection) readRoutine() {
if err := wsc.WriteRPCResponse(writeCtx, rpctypes.RPCInternalError(rpctypes.JSONRPCIntID(-1), err)); err != nil {
wsc.Logger.Error("Error writing RPC response", "err", err)
}
go wsc.readRoutine()
go wsc.readRoutine(ctx)
}
}()
@@ -317,7 +317,7 @@ func (wsc *wsConnection) readRoutine() {
for {
select {
case <-wsc.Quit():
case <-ctx.Done():
return
default:
// reset deadline for every type of message (control or data)
@@ -422,7 +422,7 @@ func (wsc *wsConnection) readRoutine() {
}
// receives on a write channel and writes out on the socket
func (wsc *wsConnection) writeRoutine() {
func (wsc *wsConnection) writeRoutine(ctx context.Context) {
pingTicker := time.NewTicker(wsc.pingPeriod)
defer pingTicker.Stop()
@@ -438,7 +438,7 @@ func (wsc *wsConnection) writeRoutine() {
for {
select {
case <-wsc.Quit():
case <-ctx.Done():
return
case <-wsc.readRoutineQuit: // error in readRoutine
return

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"net/http"
"os"
@@ -29,8 +30,11 @@ func main() {
logger = log.MustNewDefaultLogger(log.LogFormatPlain, log.LogLevelInfo, false)
)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Stop upon receiving SIGTERM or CTRL-C.
tmos.TrapSignal(logger, func() {})
tmos.TrapSignal(ctx, logger, func() {})
rpcserver.RegisterRPCFuncs(mux, routes, logger)
config := rpcserver.DefaultConfig()
@@ -40,7 +44,7 @@ func main() {
os.Exit(1)
}
if err = rpcserver.Serve(listener, mux, logger, config); err != nil {
if err = rpcserver.Serve(ctx, listener, mux, logger, config); err != nil {
logger.Error("rpc serve", "err", err)
os.Exit(1)
}

View File

@@ -253,7 +253,7 @@ type WSRPCConnection interface {
// WriteRPCResponse writes the response onto connection (BLOCKING).
WriteRPCResponse(context.Context, RPCResponse) error
// TryWriteRPCResponse tries to write the response onto connection (NON-BLOCKING).
TryWriteRPCResponse(RPCResponse) bool
TryWriteRPCResponse(context.Context, RPCResponse) bool
// Context returns the connection's context.
Context() context.Context
}