mirror of
https://github.com/versity/versitygw.git
synced 2025-12-23 05:05:16 +00:00
358 lines
8.9 KiB
Go
358 lines
8.9 KiB
Go
// Copyright 2023 Versity Software
|
|
// This file is licensed under the Apache License, Version 2.0
|
|
// (the "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing,
|
|
// software distributed under the License is distributed on an
|
|
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
// KIND, either express or implied. See the License for the
|
|
// specific language governing permissions and limitations
|
|
// under the License.
|
|
|
|
package s3select
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/binary"
|
|
"encoding/xml"
|
|
"fmt"
|
|
"hash/crc32"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// Protocol definition for messages can be found here:
|
|
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTSelectObjectAppendix.html
|
|
|
|
// headerValueType is the value-type marker written in front of every
// header value. From the protocol definition: it is an enum indicating
// the header value type, and for Amazon S3 Select it is always 7.
var headerValueType byte = 7
|
|
|
|
// intToTwoBytes returns the low 16 bits of i as two bytes, most
// significant byte first. Any higher bits of i are discarded.
func intToTwoBytes(i int) []byte {
	out := make([]byte, 2)
	out[0] = byte(i >> 8)
	out[1] = byte(i)
	return out
}
|
|
|
|
func generateHeader(messages ...string) []byte {
|
|
var header []byte
|
|
|
|
for i, message := range messages {
|
|
if i%2 == 1 {
|
|
header = append(header, headerValueType)
|
|
header = append(header, intToTwoBytes(len(message))...)
|
|
} else {
|
|
header = append(header, byte(len(message)))
|
|
}
|
|
header = append(header, message...)
|
|
}
|
|
|
|
return header
|
|
}
|
|
|
|
func generateOctetHeader(message string) []byte {
|
|
return generateHeader(
|
|
":message-type",
|
|
"event",
|
|
":content-type",
|
|
"application/octet-stream",
|
|
":event-type",
|
|
message)
|
|
}
|
|
|
|
func generateTextHeader(message string) []byte {
|
|
return generateHeader(
|
|
":message-type",
|
|
"event",
|
|
":content-type",
|
|
"text/xml",
|
|
":event-type",
|
|
message)
|
|
}
|
|
|
|
func generateNoContentHeader(message string) []byte {
|
|
return generateHeader(
|
|
":message-type",
|
|
"event",
|
|
":event-type",
|
|
message)
|
|
}
|
|
|
|
const (
	// preludeLen is the size of the message prelude:
	// 4 bytes total byte length +
	// 4 bytes headers byte length +
	// 4 bytes prelude CRC.
	preludeLen = 4 + 4 + 4

	// msgCrcLen is the size of the trailing message CRC (a uint32).
	msgCrcLen = 4
)
|
|
|
|
var (
|
|
recordsHeader = generateOctetHeader("Records")
|
|
continuationHeader = generateNoContentHeader("Cont")
|
|
continuationMessage = genMessage(continuationHeader, []byte{})
|
|
progressHeader = generateTextHeader("Progress")
|
|
statsHeader = generateTextHeader("Stats")
|
|
endHeader = generateNoContentHeader("End")
|
|
endMessage = genMessage(endHeader, []byte{})
|
|
)
|
|
|
|
// uintToBytes returns n encoded as four bytes in big-endian order.
func uintToBytes(n uint32) []byte {
	var buf [4]byte
	binary.BigEndian.PutUint32(buf[:], n)
	return buf[:]
}
|
|
|
|
func generatePrelude(msgLen int, headerLen int) []byte {
|
|
prelude := make([]byte, 0, preludeLen)
|
|
|
|
// 4 bytes total byte len
|
|
prelude = append(prelude, uintToBytes(uint32(msgLen+headerLen+preludeLen+msgCrcLen))...)
|
|
// 4 bytes headers bytes len
|
|
prelude = append(prelude, uintToBytes(uint32(headerLen))...)
|
|
// 4 bytes prelude CRC
|
|
prelude = append(prelude, uintToBytes(crc32.ChecksumIEEE(prelude))...)
|
|
|
|
return prelude
|
|
}
|
|
|
|
// maxHeaderSize is the largest header section (1 MiB) for which
// genMessage preallocates the message buffer.
const maxHeaderSize = 1 << 20
|
|
|
|
func genMessage(header, payload []byte) []byte {
|
|
var msg []byte
|
|
// below is always true since the size is validated
|
|
// in the send record
|
|
if len(header) <= maxHeaderSize && len(payload) <= maxMessageSize {
|
|
msglen := preludeLen + len(header) + len(payload) + msgCrcLen
|
|
msg = make([]byte, 0, msglen)
|
|
}
|
|
|
|
msg = append(msg, generatePrelude(len(payload), len(header))...)
|
|
msg = append(msg, header...)
|
|
msg = append(msg, payload...)
|
|
msg = append(msg, uintToBytes(crc32.ChecksumIEEE(msg))...)
|
|
|
|
return msg
|
|
}
|
|
|
|
func genRecordsMessage(payload []byte) []byte {
|
|
return genMessage(recordsHeader, payload)
|
|
}
|
|
|
|
// progress is the XML payload of a Progress event message. It is
// marshalled as a <Progress> element with one child element per
// counter.
type progress struct {
	// XMLName fixes the root element name to "Progress".
	XMLName xml.Name `xml:"Progress"`

	// byte counters reported to the client
	BytesScanned   int64 `xml:"BytesScanned"`
	BytesProcessed int64 `xml:"BytesProcessed"`
	BytesReturned  int64 `xml:"BytesReturned"`
}
|
|
|
|
func genProgressMessage(bytesScanned, bytesProcessed, bytesReturned int64) []byte {
|
|
progress := progress{
|
|
BytesScanned: bytesScanned,
|
|
BytesProcessed: bytesProcessed,
|
|
BytesReturned: bytesReturned,
|
|
}
|
|
|
|
xmlData, _ := xml.MarshalIndent(progress, "", " ")
|
|
payload := []byte(xml.Header + string(xmlData))
|
|
return genMessage(progressHeader, payload)
|
|
}
|
|
|
|
// stats is the XML payload of a Stats event message. It is marshalled
// as a <Stats> element with one child element per counter.
type stats struct {
	// XMLName fixes the root element name to "Stats".
	XMLName xml.Name `xml:"Stats"`

	// final byte counters reported to the client
	BytesScanned   int64 `xml:"BytesScanned"`
	BytesProcessed int64 `xml:"BytesProcessed"`
	BytesReturned  int64 `xml:"BytesReturned"`
}
|
|
|
|
func genStatsMessage(bytesScanned, bytesProcessed, bytesReturned int64) []byte {
|
|
stats := stats{
|
|
BytesScanned: bytesScanned,
|
|
BytesProcessed: bytesProcessed,
|
|
BytesReturned: bytesReturned,
|
|
}
|
|
|
|
xmlData, _ := xml.MarshalIndent(stats, "", " ")
|
|
payload := []byte(xml.Header + string(xmlData))
|
|
return genMessage(statsHeader, payload)
|
|
}
|
|
|
|
func genErrorMessage(errorCode, errorMessage string) []byte {
|
|
return genMessage(generateHeader(
|
|
":error-code",
|
|
errorCode,
|
|
":error-message",
|
|
errorMessage,
|
|
":message-type",
|
|
"error",
|
|
), []byte{})
|
|
}
|
|
|
|
// GetProgress is an optional callback that reports the current scan
// counters. When a non-nil callback is given to NewMessageHandler it is
// called periodically and the returned values are sent back to the
// client in Progress messages.
//
// The two results are:
//   - bytesScanned: number of bytes that have been processed before
//     being uncompressed (if the file is compressed).
//   - bytesProcessed: number of bytes that have been processed after
//     being uncompressed (if the file is compressed).
type GetProgress func() (bytesScanned int64, bytesProcessed int64)
|
|
|
|
type MessageHandler struct {
|
|
sync.Mutex
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
writer *bufio.Writer
|
|
data chan []byte
|
|
getProgress GetProgress
|
|
stopCh chan bool
|
|
resetCh chan bool
|
|
bytesReturned int64
|
|
}
|
|
|
|
// NewMessageHandler creates a new MessageHandler instance and starts the event streaming
|
|
func NewMessageHandler(ctx context.Context, w *bufio.Writer, getProgressFunc GetProgress) *MessageHandler {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
|
mh := &MessageHandler{
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
writer: w,
|
|
data: make(chan []byte),
|
|
getProgress: getProgressFunc,
|
|
resetCh: make(chan bool),
|
|
stopCh: make(chan bool),
|
|
}
|
|
|
|
go mh.sendBackgroundMessages(mh.resetCh, mh.stopCh)
|
|
return mh
|
|
}
|
|
|
|
func (mh *MessageHandler) write(data []byte) error {
|
|
mh.Lock()
|
|
defer mh.Unlock()
|
|
|
|
mh.stopCh <- true
|
|
defer func() { mh.resetCh <- true }()
|
|
|
|
_, err := mh.writer.Write(data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return mh.writer.Flush()
|
|
}
|
|
|
|
const (
|
|
continuationInterval = time.Second
|
|
progressInterval = time.Minute
|
|
)
|
|
|
|
func (mh *MessageHandler) sendBackgroundMessages(resetCh, stopCh <-chan bool) {
|
|
continuationTicker := time.NewTicker(continuationInterval)
|
|
defer continuationTicker.Stop()
|
|
|
|
var progressTicker *time.Ticker
|
|
var progressTickerChan <-chan time.Time
|
|
if mh.getProgress != nil {
|
|
progressTicker = time.NewTicker(progressInterval)
|
|
progressTickerChan = progressTicker.C
|
|
defer progressTicker.Stop()
|
|
}
|
|
|
|
Loop:
|
|
for {
|
|
select {
|
|
case <-mh.ctx.Done():
|
|
break Loop
|
|
|
|
case <-continuationTicker.C:
|
|
err := mh.write(continuationMessage)
|
|
if err != nil {
|
|
mh.cancel()
|
|
break Loop
|
|
}
|
|
|
|
case <-resetCh:
|
|
continuationTicker.Reset(continuationInterval)
|
|
|
|
case <-stopCh:
|
|
continuationTicker.Stop()
|
|
|
|
case <-progressTickerChan:
|
|
var bytesScanned, bytesProcessed int64
|
|
if mh.getProgress != nil {
|
|
bytesScanned, bytesProcessed = mh.getProgress()
|
|
}
|
|
bytesReturned := atomic.LoadInt64(&mh.bytesReturned)
|
|
err := mh.write(genProgressMessage(bytesScanned, bytesProcessed, bytesReturned))
|
|
if err != nil {
|
|
mh.cancel()
|
|
break Loop
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// SendRecord sends a single Records message
|
|
func (mh *MessageHandler) SendRecord(payload []byte) error {
|
|
if mh.ctx.Err() != nil {
|
|
return mh.ctx.Err()
|
|
}
|
|
|
|
if len(payload) > maxMessageSize {
|
|
return fmt.Errorf("record max size exceeded")
|
|
}
|
|
|
|
err := mh.write(genRecordsMessage(payload))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
atomic.AddInt64(&mh.bytesReturned, int64(len(payload)))
|
|
return nil
|
|
}
|
|
|
|
// Finish terminates message stream with Stats and End message
|
|
// generates stats and end message using function args based on:
|
|
// BytesScanned => Number of bytes that have been processed before being uncompressed (if the file is compressed).
|
|
// BytesProcessed => Number of bytes that have been processed after being uncompressed (if the file is compressed).
|
|
func (mh *MessageHandler) Finish(bytesScanned, bytesProcessed int64) error {
|
|
if mh.ctx.Err() != nil {
|
|
return mh.ctx.Err()
|
|
}
|
|
|
|
bytesReturned := atomic.LoadInt64(&mh.bytesReturned)
|
|
err := mh.write(genStatsMessage(bytesScanned, bytesProcessed, bytesReturned))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = mh.write(endMessage)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
mh.cancel()
|
|
return nil
|
|
}
|
|
|
|
// FinishWithError terminates event stream with error
|
|
func (mh *MessageHandler) FinishWithError(errorCode, errorMessage string) error {
|
|
if mh.ctx.Err() != nil {
|
|
return mh.ctx.Err()
|
|
}
|
|
err := mh.write(genErrorMessage(errorCode, errorMessage))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
mh.cancel()
|
|
return nil
|
|
}
|