mirror of
https://github.com/samuelncui/yatm.git
synced 2025-12-23 06:15:22 +00:00
117 lines
2.5 KiB
Go
117 lines
2.5 KiB
Go
package entity
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql/driver"
|
|
"fmt"
|
|
"io"
|
|
|
|
"github.com/klauspost/compress/zstd"
|
|
"github.com/modern-go/reflect2"
|
|
"github.com/samuelncui/yatm/tools"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
// compressThreshold is the largest marshaled payload size, in bytes, that
// Value stores as-is; anything longer is zstd-compressed.
const compressThreshold = 1024
|
|
|
|
var (
|
|
magicHeaderV2 = []byte{0xff, 'y', 'm', '\x02'}
|
|
|
|
zstdEncoderPool = tools.NewPool(func() *zstd.Encoder {
|
|
encoder, _ := zstd.NewWriter(nil) // there will be no error without options
|
|
return encoder
|
|
})
|
|
zstdDecoderPool = tools.NewPool(func() *zstd.Decoder {
|
|
decoder, _ := zstd.NewReader(nil) // there will be no error without options
|
|
return decoder
|
|
})
|
|
)
|
|
|
|
// Scan implement database/sql.Scanner
|
|
func Scan(dst proto.Message, src interface{}) error {
|
|
typ := reflect2.TypeOf(dst).(reflect2.PtrType).Elem()
|
|
typ.Set(dst, typ.New())
|
|
|
|
var buf []byte
|
|
switch v := src.(type) {
|
|
case string:
|
|
buf = []byte(v)
|
|
case []byte:
|
|
buf = v
|
|
case nil:
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("process define extra scanner, unexpected type for i18n, %T", v)
|
|
}
|
|
if len(buf) == 0 {
|
|
return nil
|
|
}
|
|
|
|
if bytes.HasPrefix(buf, magicHeaderV2) {
|
|
decoder := zstdDecoderPool.Get()
|
|
|
|
err := decoder.Reset(bytes.NewBuffer(buf[len(magicHeaderV2):]))
|
|
if err != nil {
|
|
return fmt.Errorf("zstd reset decoder fail, %w", err)
|
|
}
|
|
|
|
buf, err = io.ReadAll(decoder)
|
|
if err != nil {
|
|
return fmt.Errorf("zstd read decoder fail, %w", err)
|
|
}
|
|
|
|
decoder.Reset(nil)
|
|
zstdDecoderPool.Put(decoder)
|
|
}
|
|
|
|
if err := proto.Unmarshal(buf, dst); err != nil {
|
|
return fmt.Errorf("process define extra scanner, protobuf unmarshal fail, %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Value implement database/sql/driver.Valuer
|
|
func Value(src proto.Message) (driver.Value, error) {
|
|
buf, err := proto.Marshal(src)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("process define extra valuer, protobuf marshal fail, %w", err)
|
|
}
|
|
|
|
if len(buf) <= compressThreshold {
|
|
return buf, nil
|
|
}
|
|
|
|
buffer := bytes.NewBuffer(make([]byte, 0, len(buf)))
|
|
buffer.Write(magicHeaderV2)
|
|
|
|
encoder := zstdEncoderPool.Get()
|
|
encoder.Reset(buffer)
|
|
_, err = encoder.Write(buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("zstd write to encoder fail, %w", err)
|
|
}
|
|
err = encoder.Close()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("zstd close encoder fail, %w", err)
|
|
}
|
|
|
|
buf = buffer.Bytes()
|
|
encoder.Reset(nil)
|
|
zstdEncoderPool.Put(encoder)
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
// ToEnum builds a lookup function over pbmap. The returned function
// converts a name to the enum value T(pbmap[name]) and yields default_ when
// the name is not a key of pbmap.
func ToEnum[T ~int32](pbmap map[string]int32, default_ T) func(str string) T {
	return func(name string) T {
		if num, found := pbmap[name]; found {
			return T(num)
		}
		return default_
	}
}
|