feat: implement logic for s3 select object content stream

This commit is contained in:
Ben McClelland
2023-12-04 09:02:53 -08:00
parent 48818927bb
commit bed1691a93

View File

@@ -17,28 +17,342 @@ package s3select
import ( import (
"bufio" "bufio"
"context" "context"
"encoding/binary"
"encoding/xml"
"fmt"
"hash/crc32"
"sync"
"sync/atomic"
"time"
) )
// Protocol definition for messages can be found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTSelectObjectAppendix.html
var (
// From ptotocol def:
// Enum indicating the header value type.
// For Amazon S3 Select, this is always 7.
headerValueType = byte(7)
)
func intToTwoBytes(i int) []byte {
return []byte{byte(i >> 8), byte(i)}
}
func generateHeader(messages ...string) []byte {
var header []byte
for i, message := range messages {
if i%2 == 1 {
header = append(header, headerValueType)
header = append(header, intToTwoBytes(len(message))...)
} else {
header = append(header, byte(len(message)))
}
header = append(header, message...)
}
return header
}
func generateOctetHeader(message string) []byte {
return generateHeader(
":message-type",
"event",
":content-type",
"application/octet-stream",
":event-type",
message)
}
func generateTextHeader(message string) []byte {
return generateHeader(
":message-type",
"event",
":content-type",
"text/xml",
":event-type",
message)
}
func generateNoContentHeader(message string) []byte {
return generateHeader(
":message-type",
"event",
":event-type",
message)
}
const (
// 4 bytes total byte len +
// 4 bytes headers bytes len +
// 4 bytes prelude CRC
preludeLen = 12
// CRC is uint32
msgCrcLen = 4
)
var (
recordsHeader = generateOctetHeader("Records")
continuationHeader = generateNoContentHeader("Cont")
continuationMessage = genMessage(continuationHeader, []byte{})
progressHeader = generateTextHeader("Progress")
statsHeader = generateTextHeader("Stats")
endHeader = generateNoContentHeader("End")
endMessage = genMessage(endHeader, []byte{})
)
func uintToBytes(n uint32) []byte {
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, n)
return b
}
func generatePrelude(msgLen int, headerLen int) []byte {
prelude := make([]byte, 0, preludeLen)
// 4 bytes total byte len
prelude = append(prelude, uintToBytes(uint32(msgLen+headerLen+preludeLen+msgCrcLen))...)
// 4 bytes headers bytes len
prelude = append(prelude, uintToBytes(uint32(headerLen))...)
// 4 bytes prelude CRC
prelude = append(prelude, uintToBytes(crc32.ChecksumIEEE(prelude))...)
return prelude
}
const (
maxHeaderSize = 1024 * 1024
maxMessageSize = 5 * 1024 * 1024 * 1024
)
func genMessage(header, payload []byte) []byte {
var msg []byte
// below is always true since the size is validated
// in the send record
if len(header) <= maxHeaderSize && len(payload) <= maxMessageSize {
msglen := preludeLen + len(header) + len(payload) + msgCrcLen
msg = make([]byte, 0, msglen)
}
msg = append(msg, generatePrelude(len(payload), len(header))...)
msg = append(msg, header...)
msg = append(msg, payload...)
msg = append(msg, uintToBytes(crc32.ChecksumIEEE(msg))...)
return msg
}
func genRecordsMessage(payload []byte) []byte {
return genMessage(recordsHeader, payload)
}
type progress struct {
XMLName xml.Name `xml:"Progress"`
BytesScanned int64 `xml:"BytesScanned"`
BytesProcessed int64 `xml:"BytesProcessed"`
BytesReturned int64 `xml:"BytesReturned"`
}
func genProgressMessage(bytesScanned, bytesProcessed, bytesReturned int64) []byte {
progress := progress{
BytesScanned: bytesScanned,
BytesProcessed: bytesProcessed,
BytesReturned: bytesReturned,
}
xmlData, _ := xml.MarshalIndent(progress, "", " ")
payload := []byte(xml.Header + string(xmlData))
return genMessage(progressHeader, payload)
}
type stats struct {
XMLName xml.Name `xml:"Stats"`
BytesScanned int64 `xml:"BytesScanned"`
BytesProcessed int64 `xml:"BytesProcessed"`
BytesReturned int64 `xml:"BytesReturned"`
}
func genStatsMessage(bytesScanned, bytesProcessed, bytesReturned int64) []byte {
stats := stats{
BytesScanned: bytesScanned,
BytesProcessed: bytesProcessed,
BytesReturned: bytesReturned,
}
xmlData, _ := xml.MarshalIndent(stats, "", " ")
payload := []byte(xml.Header + string(xmlData))
return genMessage(statsHeader, payload)
}
func genErrorMessage(errorCode, errorMessage string) []byte {
return genMessage(generateHeader(
":error-code",
errorCode,
":error-message",
errorMessage,
":message-type",
"error",
), []byte{})
}
// GetProgress is a callback function that periodically retrieves the current
// values for the following if not nil. This is used to send Progress
// messages back to client.
// BytesScanned => Number of bytes that have been processed before being uncompressed (if the file is compressed).
// BytesProcessed => Number of bytes that have been processed after being uncompressed (if the file is compressed).
type GetProgress func() (bytesScanned int64, bytesProcessed int64) type GetProgress func() (bytesScanned int64, bytesProcessed int64)
type MessageHandler struct{} type MessageHandler struct {
sync.Mutex
ctx context.Context
cancel context.CancelFunc
writer *bufio.Writer
data chan []byte
getProgress GetProgress
stopCh chan bool
resetCh chan bool
bytesReturned int64
}
// Creates a new MessageHandler instance and starts the event streaming // NewMessageHandler creates a new MessageHandler instance and starts the event streaming
func NewMessageHandler(ctx context.Context, w *bufio.Writer, getProgressFunc GetProgress) *MessageHandler { func NewMessageHandler(ctx context.Context, w *bufio.Writer, getProgressFunc GetProgress) *MessageHandler {
return &MessageHandler{} ctx, cancel := context.WithCancel(ctx)
mh := &MessageHandler{
ctx: ctx,
cancel: cancel,
writer: w,
data: make(chan []byte),
getProgress: getProgressFunc,
resetCh: make(chan bool),
stopCh: make(chan bool),
}
go mh.sendBackgroundMessages(mh.resetCh, mh.stopCh)
return mh
}
func (mh *MessageHandler) write(data []byte) error {
mh.Lock()
defer mh.Unlock()
mh.stopCh <- true
defer func() { mh.resetCh <- true }()
_, err := mh.writer.Write(data)
if err != nil {
return err
}
return mh.writer.Flush()
}
const (
continuationInterval = time.Second
progressInterval = time.Minute
)
func (mh *MessageHandler) sendBackgroundMessages(resetCh, stopCh <-chan bool) {
continuationTicker := time.NewTicker(continuationInterval)
defer continuationTicker.Stop()
var progressTicker *time.Ticker
var progressTickerChan <-chan time.Time
if mh.getProgress != nil {
progressTicker = time.NewTicker(progressInterval)
progressTickerChan = progressTicker.C
defer progressTicker.Stop()
}
Loop:
for {
select {
case <-mh.ctx.Done():
break Loop
case <-continuationTicker.C:
err := mh.write(continuationMessage)
if err != nil {
mh.cancel()
break Loop
}
case <-resetCh:
continuationTicker.Reset(continuationInterval)
case <-stopCh:
continuationTicker.Stop()
case <-progressTickerChan:
var bytesScanned, bytesProcessed int64
if mh.getProgress != nil {
bytesScanned, bytesProcessed = mh.getProgress()
}
bytesReturned := atomic.LoadInt64(&mh.bytesReturned)
err := mh.write(genProgressMessage(bytesScanned, bytesProcessed, bytesReturned))
if err != nil {
mh.cancel()
break Loop
}
}
}
} }
// SendRecord sends a single Records message // SendRecord sends a single Records message
func (mh *MessageHandler) SendRecord(payload []byte) error { func (mh *MessageHandler) SendRecord(payload []byte) error {
if mh.ctx.Err() != nil {
return mh.ctx.Err()
}
if len(payload) > maxMessageSize {
return fmt.Errorf("record max size exceeded")
}
err := mh.write(genRecordsMessage(payload))
if err != nil {
return err
}
atomic.AddInt64(&mh.bytesReturned, int64(len(payload)))
return nil return nil
} }
// Finish terminates message stream with Stat and End message // Finish terminates message stream with Stats and End message
func (mh *MessageHandler) Finish() error { // generates stats and end message using function args based on:
// BytesScanned => Number of bytes that have been processed before being uncompressed (if the file is compressed).
// BytesProcessed => Number of bytes that have been processed after being uncompressed (if the file is compressed).
func (mh *MessageHandler) Finish(bytesScanned, bytesProcessed int64) error {
if mh.ctx.Err() != nil {
return mh.ctx.Err()
}
bytesReturned := atomic.LoadInt64(&mh.bytesReturned)
err := mh.write(genStatsMessage(bytesScanned, bytesProcessed, bytesReturned))
if err != nil {
return err
}
err = mh.write(endMessage)
if err != nil {
return err
}
mh.cancel()
return nil return nil
} }
// FinishWithError terminates event stream with error // FinishWithError terminates event stream with error
func (mh *MessageHandler) FinishWithError(errorCode, errorMessage string) error { func (mh *MessageHandler) FinishWithError(errorCode, errorMessage string) error {
if mh.ctx.Err() != nil {
return mh.ctx.Err()
}
err := mh.write(genErrorMessage(errorCode, errorMessage))
if err != nil {
return err
}
mh.cancel()
return nil return nil
} }