diff --git a/internal/grid/connection.go b/internal/grid/connection.go index b9d7f893e..188ee04e1 100644 --- a/internal/grid/connection.go +++ b/internal/grid/connection.go @@ -1587,6 +1587,18 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) { c.clientPingInterval = args[0].(time.Duration) case debugAddToDeadline: c.addDeadline = args[0].(time.Duration) + case debugIsOutgoingClosed: + // params: muxID uint64, isClosed func(bool) + muxID := args[0].(uint64) + resp := args[1].(func(b bool)) + mid, ok := c.outgoing.Load(muxID) + if !ok || mid == nil { + resp(true) + return + } + mid.respMu.Lock() + resp(mid.closed) + mid.respMu.Unlock() } } diff --git a/internal/grid/debug.go b/internal/grid/debug.go index eddb577e7..0172f87e2 100644 --- a/internal/grid/debug.go +++ b/internal/grid/debug.go @@ -49,6 +49,7 @@ const ( debugSetConnPingDuration debugSetClientPingDuration debugAddToDeadline + debugIsOutgoingClosed ) // TestGrid contains a grid of servers for testing purposes. diff --git a/internal/grid/grid_test.go b/internal/grid/grid_test.go index 5c9942c9f..75ce9d35d 100644 --- a/internal/grid/grid_test.go +++ b/internal/grid/grid_test.go @@ -372,6 +372,12 @@ func TestStreamSuite(t *testing.T) { assertNoActive(t, connRemoteLocal) assertNoActive(t, connLocalToRemote) }) + t.Run("testServerStreamResponseBlocked", func(t *testing.T) { + defer timeout(1 * time.Minute)() + testServerStreamResponseBlocked(t, local, remote) + assertNoActive(t, connRemoteLocal) + assertNoActive(t, connLocalToRemote) + }) } func testStreamRoundtrip(t *testing.T, local, remote *Manager) { @@ -929,6 +935,96 @@ func testGenericsStreamRoundtripSubroute(t *testing.T, local, remote *Manager) { t.Log("EOF.", payloads, " Roundtrips:", time.Since(start)) } +// testServerStreamResponseBlocked will test if server can handle a blocked response stream +func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) { + defer testlogger.T.SetErrorTB(t)() + errFatal := func(err error) { + 
t.Helper() + if err != nil { + t.Fatal(err) + } + } + + // We fake a local and remote server. + remoteHost := remote.HostName() + + // 1: Handler that streams up to 100 single-byte responses + serverSent := make(chan struct{}) + serverCanceled := make(chan struct{}) + register := func(manager *Manager) { + errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{ + Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr { + // Send many responses. + // Test that this doesn't block. + for i := byte(0); i < 100; i++ { + select { + case resp <- []byte{i}: + // ok + case <-ctx.Done(): + close(serverCanceled) + return NewRemoteErr(ctx.Err()) + } + if i == 1 { + close(serverSent) + } + } + return nil + }, + OutCapacity: 1, + InCapacity: 0, + })) + } + register(local) + register(remote) + + remoteConn := local.Connection(remoteHost) + const testPayload = "Hello Grid World!" + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + + st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload)) + errFatal(err) + + // Wait until the handler has queued its first two responses. + <-serverSent + + // Read back from the stream and block. + nowBlocking := make(chan struct{}) + stopBlocking := make(chan struct{}) + defer close(stopBlocking) + go func() { + st.Results(func(b []byte) error { + close(nowBlocking) + // Block until test is done. + <-stopBlocking + return nil + }) + }() + + <-nowBlocking + // Wait for the receiver channel to fill. + for len(st.responses) != cap(st.responses) { + time.Sleep(time.Millisecond) + } + cancel() + <-serverCanceled + local.debugMsg(debugIsOutgoingClosed, st.muxID, func(closed bool) { + if !closed { + t.Error("expected outgoing closed") + } else { + t.Log("outgoing was closed") + } + }) + + // Drain responses and check if error propagated. 
+ err = st.Results(func(b []byte) error { + return nil + }) + if !errors.Is(err, context.Canceled) { + t.Error("expected context.Canceled, got", err) + } +} + func timeout(after time.Duration) (cancel func()) { c := time.After(after) cc := make(chan struct{}) diff --git a/internal/grid/muxclient.go b/internal/grid/muxclient.go index 9425a86a6..1b380a341 100644 --- a/internal/grid/muxclient.go +++ b/internal/grid/muxclient.go @@ -50,6 +50,7 @@ type muxClient struct { deadline time.Duration outBlock chan struct{} subroute *subHandlerID + respErr atomic.Pointer[error] } // Response is a response from the server. @@ -250,25 +251,52 @@ func (m *muxClient) RequestStream(h HandlerID, payload []byte, requests chan []b // Spawn simple disconnect if requests == nil { - start := time.Now() - go m.handleOneWayStream(start, responseCh, responses) - return &Stream{responses: responseCh, Requests: nil, ctx: m.ctx, cancel: m.cancelFn}, nil + go m.handleOneWayStream(responseCh, responses) + return &Stream{responses: responseCh, Requests: nil, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil } // Deliver responses and send unblocks back to the server. go m.handleTwowayResponses(responseCh, responses) go m.handleTwowayRequests(responses, requests) - return &Stream{responses: responseCh, Requests: requests, ctx: m.ctx, cancel: m.cancelFn}, nil + return &Stream{responses: responseCh, Requests: requests, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil } -func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Response, respServer <-chan Response) { +func (m *muxClient) addErrorNonBlockingClose(respHandler chan<- Response, err error) { + m.respMu.Lock() + defer m.respMu.Unlock() + if !m.closed { + m.respErr.Store(&err) + // Do not block. 
+ select { + case respHandler <- Response{Err: err}: + xioutil.SafeClose(respHandler) + default: + go func() { + respHandler <- Response{Err: err} + xioutil.SafeClose(respHandler) + }() + } + logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID})) + m.closed = true + } +} + +// handleOneWayStream handles a stream without client requests. respHandler is the channel that goes to the requester; respServer is the channel that comes from the server. +func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <-chan Response) { if debugPrint { + start := time.Now() defer func() { fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond)) }() } - defer xioutil.SafeClose(respHandler) + defer func() { + // addErrorNonBlockingClose will close the response channel + // - maybe async, so we shouldn't do it here. + if m.respErr.Load() == nil { + xioutil.SafeClose(respHandler) + } + }() var pingTimer <-chan time.Time if m.deadline == 0 || m.deadline > clientPingInterval { ticker := time.NewTicker(clientPingInterval) @@ -283,13 +311,7 @@ func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Respo if debugPrint { fmt.Println("Client sending disconnect to mux", m.MuxID) } - m.respMu.Lock() - defer m.respMu.Unlock() // We always return in this path. - if !m.closed { - respHandler <- Response{Err: context.Cause(m.ctx)} - logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID})) - m.closeLocked() - } + m.addErrorNonBlockingClose(respHandler, context.Cause(m.ctx)) return case resp, ok := <-respServer: if !ok { @@ -308,13 +330,7 @@ func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Respo } case <-pingTimer: if time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)) > clientPingInterval*2 { - m.respMu.Lock() - defer m.respMu.Unlock() // We always return in this path. 
- if !m.closed { - respHandler <- Response{Err: ErrDisconnected} - logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID})) - m.closeLocked() - } + m.addErrorNonBlockingClose(respHandler, ErrDisconnected) return } // Send new ping. @@ -323,19 +339,21 @@ func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Respo } } -func (m *muxClient) handleTwowayResponses(responseCh chan Response, responses chan Response) { +// responseCh is the channel that goes to the requester. +// internalResp is the channel that comes from the server. +func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) { defer m.parent.deleteMux(false, m.MuxID) defer xioutil.SafeClose(responseCh) - for resp := range responses { + for resp := range internalResp { responseCh <- resp m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID}) } } -func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests chan []byte) { +func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) { var errState bool - start := time.Now() if debugPrint { + start := time.Now() defer func() { fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond)) }() @@ -343,19 +361,22 @@ func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests cha // Listen for client messages. for { + if errState { + go func() { + // Drain requests. 
+ for range requests { + } + }() + return + } select { case <-m.ctx.Done(): if debugPrint { fmt.Println("Client sending disconnect to mux", m.MuxID) } - m.respMu.Lock() - defer m.respMu.Unlock() - logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID})) - if !m.closed { - responses <- Response{Err: context.Cause(m.ctx)} - m.closeLocked() - } - return + m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx)) + errState = true + continue case req, ok := <-requests: if !ok { // Done send EOF @@ -371,19 +392,14 @@ func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests cha msg.setZeroPayloadFlag() err := m.send(msg) if err != nil { - m.respMu.Lock() - responses <- Response{Err: err} - m.closeLocked() - m.respMu.Unlock() + m.addErrorNonBlockingClose(internalResp, err) } return } - if errState { - continue - } // Grab a send token. select { case <-m.ctx.Done(): + m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx)) errState = true continue case <-m.outBlock: @@ -398,8 +414,7 @@ func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests cha err := m.send(msg) PutByteBuffer(req) if err != nil { - responses <- Response{Err: err} - m.close() + m.addErrorNonBlockingClose(internalResp, err) errState = true continue } @@ -534,6 +549,7 @@ func (m *muxClient) closeLocked() { if m.closed { return } + // We hold the lock, so nobody can modify m.respWait while we're closing. if m.respWait != nil { xioutil.SafeClose(m.respWait) m.respWait = nil diff --git a/internal/grid/stream.go b/internal/grid/stream.go index d65313d77..a99b66643 100644 --- a/internal/grid/stream.go +++ b/internal/grid/stream.go @@ -41,7 +41,8 @@ type Stream struct { // Requests sent cannot be used any further by the called. Requests chan<- []byte - ctx context.Context + muxID uint64 + ctx context.Context } // Send a payload to the remote server.