Files
versitygw/s3api/utils/utils.go

431 lines
10 KiB
Go

// Copyright 2023 Versity Software
// This file is licensed under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package utils
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go/encoding/httpbinding"
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"
"github.com/versity/versitygw/s3err"
"github.com/versity/versitygw/s3response"
)
var (
bucketNameRegexp = regexp.MustCompile(`^[a-z0-9][a-z0-9.-]+[a-z0-9]$`)
bucketNameIpRegexp = regexp.MustCompile(`^(?:[0-9]{1,3}\.){3}[0-9]{1,3}$`)
)
const (
upperhex = "0123456789ABCDEF"
)
func GetUserMetaData(headers *fasthttp.RequestHeader) (metadata map[string]string) {
metadata = make(map[string]string)
headers.DisableNormalizing()
headers.VisitAllInOrder(func(key, value []byte) {
hKey := string(key)
if strings.HasPrefix(strings.ToLower(hKey), "x-amz-meta-") {
trimmedKey := hKey[11:]
headerValue := string(value)
metadata[trimmedKey] = headerValue
}
})
headers.EnableNormalizing()
return
}
func createHttpRequestFromCtx(ctx *fiber.Ctx, signedHdrs []string, contentLength int64) (*http.Request, error) {
req := ctx.Request()
var body io.Reader
if IsBigDataAction(ctx) {
body = req.BodyStream()
} else {
body = bytes.NewReader(req.Body())
}
escapedURI := escapeOriginalURI(ctx)
httpReq, err := http.NewRequest(string(req.Header.Method()), escapedURI, body)
if err != nil {
return nil, errors.New("error in creating an http request")
}
// Set the request headers
req.Header.VisitAll(func(key, value []byte) {
keyStr := string(key)
if includeHeader(keyStr, signedHdrs) {
httpReq.Header.Add(keyStr, string(value))
}
})
// make sure all headers in the signed headers are present
for _, header := range signedHdrs {
if httpReq.Header.Get(header) == "" {
httpReq.Header.Set(header, "")
}
}
// Check if Content-Length in signed headers
// If content length is non 0, then the header will be included
if !includeHeader("Content-Length", signedHdrs) {
httpReq.ContentLength = 0
} else {
httpReq.ContentLength = contentLength
}
// Set the Host header
httpReq.Host = string(req.Header.Host())
return httpReq, nil
}
var (
signedQueryArgs = map[string]bool{
"X-Amz-Algorithm": true,
"X-Amz-Credential": true,
"X-Amz-Date": true,
"X-Amz-SignedHeaders": true,
"X-Amz-Signature": true,
}
)
func createPresignedHttpRequestFromCtx(ctx *fiber.Ctx, signedHdrs []string, contentLength int64) (*http.Request, error) {
req := ctx.Request()
var body io.Reader
if IsBigDataAction(ctx) {
body = req.BodyStream()
} else {
body = bytes.NewReader(req.Body())
}
uri := string(ctx.Request().URI().Path())
uri = httpbinding.EscapePath(uri, false)
isFirst := true
ctx.Request().URI().QueryArgs().VisitAll(func(key, value []byte) {
_, ok := signedQueryArgs[string(key)]
if !ok {
escapeValue := url.QueryEscape(string(value))
if isFirst {
uri += fmt.Sprintf("?%s=%s", key, escapeValue)
isFirst = false
} else {
uri += fmt.Sprintf("&%s=%s", key, escapeValue)
}
}
})
httpReq, err := http.NewRequest(string(req.Header.Method()), uri, body)
if err != nil {
return nil, errors.New("error in creating an http request")
}
// Set the request headers
req.Header.VisitAll(func(key, value []byte) {
keyStr := string(key)
if includeHeader(keyStr, signedHdrs) {
httpReq.Header.Add(keyStr, string(value))
}
})
// Check if Content-Length in signed headers
// If content length is non 0, then the header will be included
if !includeHeader("Content-Length", signedHdrs) {
httpReq.ContentLength = 0
} else {
httpReq.ContentLength = contentLength
}
// Set the Host header
httpReq.Host = string(req.Header.Host())
return httpReq, nil
}
func SetMetaHeaders(ctx *fiber.Ctx, meta map[string]string) {
ctx.Response().Header.DisableNormalizing()
for key, val := range meta {
ctx.Response().Header.Set(fmt.Sprintf("X-Amz-Meta-%s", key), val)
}
ctx.Response().Header.EnableNormalizing()
}
func ParseUint(str string) (int32, error) {
if str == "" {
return 1000, nil
}
num, err := strconv.ParseUint(str, 10, 16)
if err != nil {
return 1000, fmt.Errorf("invalid uint: %w", err)
}
return int32(num), nil
}
type CustomHeader struct {
Key string
Value string
}
func SetResponseHeaders(ctx *fiber.Ctx, headers []CustomHeader) {
for _, header := range headers {
ctx.Set(header.Key, header.Value)
}
}
func IsValidBucketName(bucket string) bool {
if len(bucket) < 3 || len(bucket) > 63 {
return false
}
// Checks to contain only digits, lowercase letters, dot, hyphen.
// Checks to start and end with only digits and lowercase letters.
if !bucketNameRegexp.MatchString(bucket) {
return false
}
// Checks not to be a valid IP address
if bucketNameIpRegexp.MatchString(bucket) {
return false
}
return true
}
func includeHeader(hdr string, signedHdrs []string) bool {
for _, shdr := range signedHdrs {
if strings.EqualFold(hdr, shdr) {
return true
}
}
return false
}
func IsBigDataAction(ctx *fiber.Ctx) bool {
if ctx.Method() == http.MethodPut && len(strings.Split(ctx.Path(), "/")) >= 3 {
if !ctx.Request().URI().QueryArgs().Has("tagging") && ctx.Get("X-Amz-Copy-Source") == "" && !ctx.Request().URI().QueryArgs().Has("acl") {
return true
}
}
return false
}
// expiration time window
// https://docs.aws.amazon.com/AmazonS3/latest/userguide/RESTAuthentication.html#RESTAuthenticationTimeStamp
const timeExpirationSec = 15 * 60
func ValidateDate(date time.Time) error {
now := time.Now().UTC()
diff := date.Unix() - now.Unix()
// Checks the dates difference to be within allotted window
if diff > timeExpirationSec || diff < -timeExpirationSec {
return s3err.GetAPIError(s3err.ErrRequestTimeTooSkewed)
}
return nil
}
func ParseDeleteObjects(objs []types.ObjectIdentifier) (result []string) {
for _, obj := range objs {
result = append(result, *obj.Key)
}
return
}
func FilterObjectAttributes(attrs map[s3response.ObjectAttributes]struct{}, output s3response.GetObjectAttributesResponse) s3response.GetObjectAttributesResponse {
// These properties shouldn't appear in the final response body
output.LastModified = nil
output.VersionId = nil
output.DeleteMarker = nil
if _, ok := attrs[s3response.ObjectAttributesEtag]; !ok {
output.ETag = nil
}
if _, ok := attrs[s3response.ObjectAttributesObjectParts]; !ok {
output.ObjectParts = nil
}
if _, ok := attrs[s3response.ObjectAttributesObjectSize]; !ok {
output.ObjectSize = nil
}
if _, ok := attrs[s3response.ObjectAttributesStorageClass]; !ok {
output.StorageClass = ""
}
fmt.Printf("%+v\n", output)
return output
}
func ParseObjectAttributes(ctx *fiber.Ctx) (map[s3response.ObjectAttributes]struct{}, error) {
attrs := map[s3response.ObjectAttributes]struct{}{}
var err error
ctx.Request().Header.VisitAll(func(key, value []byte) {
if string(key) == "X-Amz-Object-Attributes" {
oattrs := strings.Split(string(value), ",")
for _, a := range oattrs {
attr := s3response.ObjectAttributes(a)
if !attr.IsValid() {
err = s3err.GetAPIError(s3err.ErrInvalidObjectAttributes)
break
}
attrs[attr] = struct{}{}
}
}
})
return attrs, err
}
type objLockCfg struct {
RetainUntilDate time.Time
ObjectLockMode types.ObjectLockMode
LegalHoldStatus types.ObjectLockLegalHoldStatus
}
func ParsObjectLockHdrs(ctx *fiber.Ctx) (*objLockCfg, error) {
legalHoldHdr := ctx.Get("X-Amz-Object-Lock-Legal-Hold")
objLockModeHdr := ctx.Get("X-Amz-Object-Lock-Mode")
objLockDate := ctx.Get("X-Amz-Object-Lock-Retain-Until-Date")
if (objLockDate != "" && objLockModeHdr == "") || (objLockDate == "" && objLockModeHdr != "") {
return nil, s3err.GetAPIError(s3err.ErrObjectLockInvalidHeaders)
}
var retainUntilDate time.Time
if objLockDate != "" {
rDate, err := time.Parse(time.RFC3339, objLockDate)
if err != nil {
return nil, s3err.GetAPIError(s3err.ErrInvalidRequest)
}
if rDate.Before(time.Now()) {
return nil, s3err.GetAPIError(s3err.ErrPastObjectLockRetainDate)
}
retainUntilDate = rDate
}
objLockMode := types.ObjectLockMode(objLockModeHdr)
if objLockMode != "" &&
objLockMode != types.ObjectLockModeCompliance &&
objLockMode != types.ObjectLockModeGovernance {
return nil, s3err.GetAPIError(s3err.ErrInvalidRequest)
}
legalHold := types.ObjectLockLegalHoldStatus(legalHoldHdr)
if legalHold != "" && legalHold != types.ObjectLockLegalHoldStatusOff && legalHold != types.ObjectLockLegalHoldStatusOn {
return nil, s3err.GetAPIError(s3err.ErrInvalidRequest)
}
return &objLockCfg{
RetainUntilDate: retainUntilDate,
ObjectLockMode: objLockMode,
LegalHoldStatus: legalHold,
}, nil
}
func IsValidOwnership(val types.ObjectOwnership) bool {
switch val {
case types.ObjectOwnershipBucketOwnerEnforced:
return true
case types.ObjectOwnershipBucketOwnerPreferred:
return true
case types.ObjectOwnershipObjectWriter:
return true
default:
return false
}
}
func escapeOriginalURI(ctx *fiber.Ctx) string {
path := ctx.Path()
// Escape the URI original path
escapedURI := escapePath(path)
// Add the URI query params
query := string(ctx.Request().URI().QueryArgs().QueryString())
if query != "" {
escapedURI = escapedURI + "?" + query
}
return escapedURI
}
// Escapes the path string
// Most of the parts copied from std url
func escapePath(s string) string {
hexCount := 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEscape(c) {
hexCount++
}
}
if hexCount == 0 {
return s
}
var buf [64]byte
var t []byte
required := len(s) + 2*hexCount
if required <= len(buf) {
t = buf[:required]
} else {
t = make([]byte, required)
}
j := 0
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case shouldEscape(c):
t[j] = '%'
t[j+1] = upperhex[c>>4]
t[j+2] = upperhex[c&15]
j += 3
default:
t[j] = s[i]
j++
}
}
return string(t)
}
// Checks if the character needs to be escaped
func shouldEscape(c byte) bool {
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
return false
}
switch c {
case '-', '_', '.', '~', '/':
return false
}
return true
}