From 5c32058ff34cc3c968adf956f0fe08ec627101d4 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 13 Mar 2024 19:43:58 +0100 Subject: [PATCH] cosmetic: Move request goroutines to methods (#19241) Cosmetic change, but breaks up a big code block and will make a goroutine dumps of streams are more readable, so it is clearer what each goroutine is doing. --- internal/grid/muxserver.go | 184 +++++++++++++++++++++---------------- 1 file changed, 105 insertions(+), 79 deletions(-) diff --git a/internal/grid/muxserver.go b/internal/grid/muxserver.go index f4917bafa..2bec67e17 100644 --- a/internal/grid/muxserver.go +++ b/internal/grid/muxserver.go @@ -21,7 +21,6 @@ import ( "context" "errors" "fmt" - "runtime/debug" "sync" "sync/atomic" "time" @@ -123,109 +122,136 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler Strea if inboundCap > 0 { m.inbound = make(chan []byte, inboundCap) handlerIn = make(chan []byte, 1) - go func(inbound <-chan []byte) { + go func(inbound chan []byte) { wg.Wait() defer xioutil.SafeClose(handlerIn) - // Send unblocks when we have delivered the message to the handler. - for in := range inbound { - handlerIn <- in - m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags}) - } + m.handleInbound(c, inbound, handlerIn) }(m.inbound) } + // Fill outbound block. + // Each token represents a message that can be sent to the client without blocking. + // The client will refill the tokens as they confirm delivery of the messages. for i := 0; i < outboundCap; i++ { m.outBlock <- struct{}{} } // Handler goroutine. - var handlerErr *RemoteErr + var handlerErr atomic.Pointer[RemoteErr] go func() { wg.Wait() - start := time.Now() - defer func() { - if debugPrint { - fmt.Println("Mux", m.ID, "Handler took", time.Since(start).Round(time.Millisecond)) - } - if r := recover(); r != nil { - logger.LogIf(ctx, fmt.Errorf("grid handler (%v) panic: %v", msg.Handler, r)) - debug.PrintStack() - err := RemoteErr(fmt.Sprintf("remote call panic: %v", r)) - handlerErr = &err - } - if debugPrint { - fmt.Println("muxServer: Mux", m.ID, "Returned with", handlerErr) - } - xioutil.SafeClose(send) - }() - // handlerErr is guarded by 'send' channel. - handlerErr = handler.Handle(ctx, msg.Payload, handlerIn, send) + defer xioutil.SafeClose(send) + err := m.handleRequests(ctx, msg, send, handler, handlerIn) + if err != nil { + handlerErr.Store(err) + } }() - // Response sender gorutine... + + // Response sender goroutine... go func(outBlock <-chan struct{}) { wg.Wait() defer m.parent.deleteMux(true, m.ID) - for { - // Process outgoing message. - var payload []byte - var ok bool - select { - case payload, ok = <-send: - case <-ctx.Done(): - return - } - select { - case <-ctx.Done(): - return - case <-outBlock: - } - msg := message{ - MuxID: m.ID, - Op: OpMuxServerMsg, - Flags: c.baseFlags, - } - if !ok { - if debugPrint { - fmt.Println("muxServer: Mux", m.ID, "send EOF", handlerErr) - } - msg.Flags |= FlagEOF - if handlerErr != nil { - msg.Flags |= FlagPayloadIsErr - msg.Payload = []byte(*handlerErr) - } - msg.setZeroPayloadFlag() - m.send(msg) - return - } - msg.Payload = payload - msg.setZeroPayloadFlag() - m.send(msg) - } + m.sendResponses(ctx, send, c, &handlerErr, outBlock) }(m.outBlock) - // Remote aliveness check. + // Remote aliveness check if needed. if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) { go func() { wg.Wait() - t := time.NewTicker(lastPingThreshold / 4) - defer t.Stop() - for { - select { - case <-m.ctx.Done(): - return - case <-t.C: - last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0)) - if last > lastPingThreshold { - logger.LogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last)) - m.close() - return - } - } - } + m.checkRemoteAlive() }() } return &m } +// handleInbound sends unblocks when we have delivered the message to the handler. +func (m *muxServer) handleInbound(c *Connection, inbound <-chan []byte, handlerIn chan<- []byte) { + for in := range inbound { + handlerIn <- in + m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags}) + } +} + +// sendResponses will send responses to the client. +func (m *muxServer) sendResponses(ctx context.Context, toSend <-chan []byte, c *Connection, handlerErr *atomic.Pointer[RemoteErr], outBlock <-chan struct{}) { + for { + // Process outgoing message. + var payload []byte + var ok bool + select { + case payload, ok = <-toSend: + case <-ctx.Done(): + return + } + select { + case <-ctx.Done(): + return + case <-outBlock: + } + msg := message{ + MuxID: m.ID, + Op: OpMuxServerMsg, + Flags: c.baseFlags, + } + if !ok { + hErr := handlerErr.Load() + if debugPrint { + fmt.Println("muxServer: Mux", m.ID, "send EOF", hErr) + } + msg.Flags |= FlagEOF + if hErr != nil { + msg.Flags |= FlagPayloadIsErr + msg.Payload = []byte(*hErr) + } + msg.setZeroPayloadFlag() + m.send(msg) + return + } + msg.Payload = payload + msg.setZeroPayloadFlag() + m.send(msg) + } +} + +// handleRequests will handle the requests from the client and call the handler function. +func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<- []byte, handler StreamHandler, handlerIn <-chan []byte) (handlerErr *RemoteErr) { + start := time.Now() + defer func() { + if debugPrint { + fmt.Println("Mux", m.ID, "Handler took", time.Since(start).Round(time.Millisecond)) + } + if r := recover(); r != nil { + logger.LogIf(ctx, fmt.Errorf("grid handler (%v) panic: %v", msg.Handler, r)) + err := RemoteErr(fmt.Sprintf("handler panic: %v", r)) + handlerErr = &err + } + if debugPrint { + fmt.Println("muxServer: Mux", m.ID, "Returned with", handlerErr) + } + }() + // handlerErr is guarded by 'send' channel. + handlerErr = handler.Handle(ctx, msg.Payload, handlerIn, send) + return handlerErr +} + +// checkRemoteAlive will check if the remote is alive. +func (m *muxServer) checkRemoteAlive() { + t := time.NewTicker(lastPingThreshold / 4) + defer t.Stop() + for { + select { + case <-m.ctx.Done(): + return + case <-t.C: + last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0)) + if last > lastPingThreshold { + logger.LogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last)) + m.close() + return + } + } + } +} + // checkSeq will check if sequence number is correct and increment it by 1. func (m *muxServer) checkSeq(seq uint32) (ok bool) { if seq != m.RecvSeq {