diff --git a/abci/client/socket_client.go b/abci/client/socket_client.go index 8904d557d..7dfcf76cc 100644 --- a/abci/client/socket_client.go +++ b/abci/client/socket_client.go @@ -112,6 +112,11 @@ func (cli *socketClient) sendRequestsRoutine(ctx context.Context, conn io.Writer case <-ctx.Done(): return case reqres := <-cli.reqQueue: + // N.B. We must enqueue before sending out the request, otherwise the + // server may reply before we do it, and the receiver will fail for an + // unsolicited reply. + cli.trackRequest(reqres) + if err := types.WriteMessage(reqres.Request, bw); err != nil { cli.stopForError(fmt.Errorf("write to buffer: %w", err)) return @@ -121,8 +126,6 @@ func (cli *socketClient) sendRequestsRoutine(ctx context.Context, conn io.Writer cli.stopForError(fmt.Errorf("flush buffer: %w", err)) return } - - cli.trackRequest(reqres) } } } @@ -155,13 +158,14 @@ func (cli *socketClient) recvResponseRoutine(ctx context.Context, conn io.Reader } func (cli *socketClient) trackRequest(reqres *requestAndResponse) { - cli.mtx.Lock() - defer cli.mtx.Unlock() - + // N.B. We must NOT hold the client state lock while checking this, or we + // may deadlock with shutdown. if !cli.IsRunning() { return } + cli.mtx.Lock() + defer cli.mtx.Unlock() cli.reqSent.PushBack(reqres) } diff --git a/internal/proxy/client_test.go b/internal/proxy/client_test.go index 09ac3f2c8..41a34bde7 100644 --- a/internal/proxy/client_test.go +++ b/internal/proxy/client_test.go @@ -58,7 +58,7 @@ func (app *appConnTest) Info(ctx context.Context, req *types.RequestInfo) (*type var SOCKET = "socket" func TestEcho(t *testing.T) { - sockPath := fmt.Sprintf("unix:///tmp/echo_%v.sock", tmrand.Str(6)) + sockPath := fmt.Sprintf("unix://%s/echo_%v.sock", t.TempDir(), tmrand.Str(6)) logger := log.NewNopLogger() client, err := abciclient.NewClient(logger, sockPath, SOCKET, true) if err != nil { @@ -98,7 +98,7 @@ func TestEcho(t *testing.T) { func BenchmarkEcho(b *testing.B) { b.StopTimer() // Initialize - sockPath := fmt.Sprintf("unix:///tmp/echo_%v.sock", tmrand.Str(6)) + sockPath := fmt.Sprintf("unix://%s/echo_%v.sock", b.TempDir(), tmrand.Str(6)) logger := log.NewNopLogger() client, err := abciclient.NewClient(logger, sockPath, SOCKET, true) if err != nil { @@ -146,7 +146,7 @@ func TestInfo(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - sockPath := fmt.Sprintf("unix:///tmp/echo_%v.sock", tmrand.Str(6)) + sockPath := fmt.Sprintf("unix://%s/echo_%v.sock", t.TempDir(), tmrand.Str(6)) logger := log.NewNopLogger() client, err := abciclient.NewClient(logger, sockPath, SOCKET, true) if err != nil {