From bed1691a93e82f96f4f88acf80a88ce874189c36 Mon Sep 17 00:00:00 2001 From: Ben McClelland Date: Mon, 4 Dec 2023 09:02:53 -0800 Subject: [PATCH] feat: implement logic for s3 select object content stream --- s3select/message-handler.go | 324 +++++++++++++++++++++++++++++++++++- 1 file changed, 319 insertions(+), 5 deletions(-) diff --git a/s3select/message-handler.go b/s3select/message-handler.go index ecde975..fcefc0b 100644 --- a/s3select/message-handler.go +++ b/s3select/message-handler.go @@ -17,28 +17,342 @@ package s3select import ( "bufio" "context" + "encoding/binary" + "encoding/xml" + "fmt" + "hash/crc32" + "sync" + "sync/atomic" + "time" ) +// Protocol definition for messages can be found here: +// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTSelectObjectAppendix.html + +var ( + // From ptotocol def: + // Enum indicating the header value type. + // For Amazon S3 Select, this is always 7. + headerValueType = byte(7) +) + +func intToTwoBytes(i int) []byte { + return []byte{byte(i >> 8), byte(i)} +} + +func generateHeader(messages ...string) []byte { + var header []byte + + for i, message := range messages { + if i%2 == 1 { + header = append(header, headerValueType) + header = append(header, intToTwoBytes(len(message))...) + } else { + header = append(header, byte(len(message))) + } + header = append(header, message...) + } + + return header +} + +func generateOctetHeader(message string) []byte { + return generateHeader( + ":message-type", + "event", + ":content-type", + "application/octet-stream", + ":event-type", + message) +} + +func generateTextHeader(message string) []byte { + return generateHeader( + ":message-type", + "event", + ":content-type", + "text/xml", + ":event-type", + message) +} + +func generateNoContentHeader(message string) []byte { + return generateHeader( + ":message-type", + "event", + ":event-type", + message) +} + +const ( + // 4 bytes total byte len + + // 4 bytes headers bytes len + + // 4 bytes prelude CRC + preludeLen = 12 + // CRC is uint32 + msgCrcLen = 4 +) + +var ( + recordsHeader = generateOctetHeader("Records") + continuationHeader = generateNoContentHeader("Cont") + continuationMessage = genMessage(continuationHeader, []byte{}) + progressHeader = generateTextHeader("Progress") + statsHeader = generateTextHeader("Stats") + endHeader = generateNoContentHeader("End") + endMessage = genMessage(endHeader, []byte{}) +) + +func uintToBytes(n uint32) []byte { + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, n) + return b +} + +func generatePrelude(msgLen int, headerLen int) []byte { + prelude := make([]byte, 0, preludeLen) + + // 4 bytes total byte len + prelude = append(prelude, uintToBytes(uint32(msgLen+headerLen+preludeLen+msgCrcLen))...) + // 4 bytes headers bytes len + prelude = append(prelude, uintToBytes(uint32(headerLen))...) + // 4 bytes prelude CRC + prelude = append(prelude, uintToBytes(crc32.ChecksumIEEE(prelude))...) + + return prelude +} + +const ( + maxHeaderSize = 1024 * 1024 + maxMessageSize = 5 * 1024 * 1024 * 1024 +) + +func genMessage(header, payload []byte) []byte { + var msg []byte + // below is always true since the size is validated + // in the send record + if len(header) <= maxHeaderSize && len(payload) <= maxMessageSize { + msglen := preludeLen + len(header) + len(payload) + msgCrcLen + msg = make([]byte, 0, msglen) + } + + msg = append(msg, generatePrelude(len(payload), len(header))...) + msg = append(msg, header...) + msg = append(msg, payload...) + msg = append(msg, uintToBytes(crc32.ChecksumIEEE(msg))...) + + return msg +} + +func genRecordsMessage(payload []byte) []byte { + return genMessage(recordsHeader, payload) +} + +type progress struct { + XMLName xml.Name `xml:"Progress"` + BytesScanned int64 `xml:"BytesScanned"` + BytesProcessed int64 `xml:"BytesProcessed"` + BytesReturned int64 `xml:"BytesReturned"` +} + +func genProgressMessage(bytesScanned, bytesProcessed, bytesReturned int64) []byte { + progress := progress{ + BytesScanned: bytesScanned, + BytesProcessed: bytesProcessed, + BytesReturned: bytesReturned, + } + + xmlData, _ := xml.MarshalIndent(progress, "", " ") + payload := []byte(xml.Header + string(xmlData)) + return genMessage(progressHeader, payload) +} + +type stats struct { + XMLName xml.Name `xml:"Stats"` + BytesScanned int64 `xml:"BytesScanned"` + BytesProcessed int64 `xml:"BytesProcessed"` + BytesReturned int64 `xml:"BytesReturned"` +} + +func genStatsMessage(bytesScanned, bytesProcessed, bytesReturned int64) []byte { + stats := stats{ + BytesScanned: bytesScanned, + BytesProcessed: bytesProcessed, + BytesReturned: bytesReturned, + } + + xmlData, _ := xml.MarshalIndent(stats, "", " ") + payload := []byte(xml.Header + string(xmlData)) + return genMessage(statsHeader, payload) +} + +func genErrorMessage(errorCode, errorMessage string) []byte { + return genMessage(generateHeader( + ":error-code", + errorCode, + ":error-message", + errorMessage, + ":message-type", + "error", + ), []byte{}) +} + +// GetProgress is a callback function that periodically retrieves the current +// values for the following if not nil. This is used to send Progress +// messages back to client. +// BytesScanned => Number of bytes that have been processed before being uncompressed (if the file is compressed). +// BytesProcessed => Number of bytes that have been processed after being uncompressed (if the file is compressed). type GetProgress func() (bytesScanned int64, bytesProcessed int64) -type MessageHandler struct{} +type MessageHandler struct { + sync.Mutex + ctx context.Context + cancel context.CancelFunc + writer *bufio.Writer + data chan []byte + getProgress GetProgress + stopCh chan bool + resetCh chan bool + bytesReturned int64 +} -// Creates a new MessageHandler instance and starts the event streaming +// NewMessageHandler creates a new MessageHandler instance and starts the event streaming func NewMessageHandler(ctx context.Context, w *bufio.Writer, getProgressFunc GetProgress) *MessageHandler { - return &MessageHandler{} + ctx, cancel := context.WithCancel(ctx) + + mh := &MessageHandler{ + ctx: ctx, + cancel: cancel, + writer: w, + data: make(chan []byte), + getProgress: getProgressFunc, + resetCh: make(chan bool), + stopCh: make(chan bool), + } + + go mh.sendBackgroundMessages(mh.resetCh, mh.stopCh) + return mh +} + +func (mh *MessageHandler) write(data []byte) error { + mh.Lock() + defer mh.Unlock() + + mh.stopCh <- true + defer func() { mh.resetCh <- true }() + + _, err := mh.writer.Write(data) + if err != nil { + return err + } + + return mh.writer.Flush() +} + +const ( + continuationInterval = time.Second + progressInterval = time.Minute +) + +func (mh *MessageHandler) sendBackgroundMessages(resetCh, stopCh <-chan bool) { + continuationTicker := time.NewTicker(continuationInterval) + defer continuationTicker.Stop() + + var progressTicker *time.Ticker + var progressTickerChan <-chan time.Time + if mh.getProgress != nil { + progressTicker = time.NewTicker(progressInterval) + progressTickerChan = progressTicker.C + defer progressTicker.Stop() + } + +Loop: + for { + select { + case <-mh.ctx.Done(): + break Loop + + case <-continuationTicker.C: + err := mh.write(continuationMessage) + if err != nil { + mh.cancel() + break Loop + } + + case <-resetCh: + continuationTicker.Reset(continuationInterval) + + case <-stopCh: + continuationTicker.Stop() + + case <-progressTickerChan: + var bytesScanned, bytesProcessed int64 + if mh.getProgress != nil { + bytesScanned, bytesProcessed = mh.getProgress() + } + bytesReturned := atomic.LoadInt64(&mh.bytesReturned) + err := mh.write(genProgressMessage(bytesScanned, bytesProcessed, bytesReturned)) + if err != nil { + mh.cancel() + break Loop + } + } + } } // SendRecord sends a single Records message func (mh *MessageHandler) SendRecord(payload []byte) error { + if mh.ctx.Err() != nil { + return mh.ctx.Err() + } + + if len(payload) > maxMessageSize { + return fmt.Errorf("record max size exceeded") + } + + err := mh.write(genRecordsMessage(payload)) + if err != nil { + return err + } + + atomic.AddInt64(&mh.bytesReturned, int64(len(payload))) return nil } -// Finish terminates message stream with Stat and End message -func (mh *MessageHandler) Finish() error { +// Finish terminates message stream with Stats and End message +// generates stats and end message using function args based on: +// BytesScanned => Number of bytes that have been processed before being uncompressed (if the file is compressed). +// BytesProcessed => Number of bytes that have been processed after being uncompressed (if the file is compressed). +func (mh *MessageHandler) Finish(bytesScanned, bytesProcessed int64) error { + if mh.ctx.Err() != nil { + return mh.ctx.Err() + } + + bytesReturned := atomic.LoadInt64(&mh.bytesReturned) + err := mh.write(genStatsMessage(bytesScanned, bytesProcessed, bytesReturned)) + if err != nil { + return err + } + + err = mh.write(endMessage) + if err != nil { + return err + } + + mh.cancel() return nil } // FinishWithError terminates event stream with error func (mh *MessageHandler) FinishWithError(errorCode, errorMessage string) error { + if mh.ctx.Err() != nil { + return mh.ctx.Err() + } + err := mh.write(genErrorMessage(errorCode, errorMessage)) + if err != nil { + return err + } + + mh.cancel() return nil }