Files
versitygw/tests/integration/utils.go
niksis02 9f786b3c2c feat: global error refactoring
Fixes #2123
Fixes #2120
Fixes #2116
Fixes #2111
Fixes #2108
Fixes #2086
Fixes #2085
Fixes #2083
Fixes #2081
Fixes #2080
Fixes #2073
Fixes #2072
Fixes #2071
Fixes #2069
Fixes #2044
Fixes #2043
Fixes #2042
Fixes #2041
Fixes #2040
Fixes #2039
Fixes #2036
Fixes #2035
Fixes #2034
Fixes #2028
Fixes #2020
Fixes #1842
Fixes #1810
Fixes #1780
Fixes #1775
Fixes #1736
Fixes #1705
Fixes #1663
Fixes #1645
Fixes #1583
Fixes #1526
Fixes #1514
Fixes #1493
Fixes #1487
Fixes #959
Fixes #779
Closes #823
Closes #85

Refactor global S3 error handling around structured error types and centralized XML response generation.

All S3 errors now share the common APIError base, which holds the fields every error has: Code, HTTP status code, and Message. Errors that must carry additional AWS-compatible XML fields beyond these now have dedicated typed errors in the s3err package. Each typed error implements the shared S3Error interface, so controllers and middleware can handle errors consistently while still emitting error-specific XML fields.

Add a dedicated InvalidArgumentError type because InvalidArgument is used widely across request validation, auth, copy source handling, object lock validation, multipart validation, and header parsing. The new InvalidArgument path uses explicit InvalidArgErrorCode constants with predefined descriptions and ArgumentName values, keeping call sites readable while preserving the correct InvalidArgument XML shape and optional ArgumentValue.

New structured errors added in s3err:
- `AccessForbiddenError`: Method, ResourceType
- `BadDigestError`: CalculatedDigest, ExpectedDigest
- `BucketError`: BucketName
- `ContentSHA256MismatchError`: ClientComputedContentSHA256, S3ComputedContentSHA256
- `EntityTooLargeError`: ProposedSize, MaxSizeAllowed
- `EntityTooSmallError`: ProposedSize, MinSizeAllowed
- `ExpiredPresignedURLError`: ServerTime, XAmzExpires, Expires
- `InvalidAccessKeyIdError`: AWSAccessKeyId
- `InvalidArgumentError`: Description, ArgumentName, ArgumentValue
- `InvalidChunkSizeError`: Chunk, BadChunkSize
- `InvalidDigestError`: ContentMD5
- `InvalidLocationConstraintError`: LocationConstraint
- `InvalidPartError`: UploadId, PartNumber, ETag
- `InvalidRangeError`: RangeRequested, ActualObjectSize
- `InvalidTagError`: TagKey, TagValue
- `KeyTooLongError`: Size, MaxSizeAllowed
- `MetadataTooLargeError`: Size, MaxSizeAllowed
- `MethodNotAllowedError`: Method, ResourceType, AllowedMethods
- `NoSuchUploadError`: UploadId
- `NoSuchVersionError`: Key, VersionId
- `NotImplementedError`: Header, AdditionalMessage
- `PreconditionFailedError`: Condition
- `RequestTimeTooSkewedError`: RequestTime, ServerTime, MaxAllowedSkewMilliseconds
- `SignatureDoesNotMatchError`: AWSAccessKeyId, StringToSign, SignatureProvided, StringToSignBytes, CanonicalRequest, CanonicalRequestBytes

Fix CompleteMultipartUpload validation in the Azure backend so missing or empty `ETag` values return the appropriate S3 error instead of allowing a gateway panic.

Fix presigned authentication expiration validation to compare server time in `UTC`, matching the `UTC` timestamp used by presigned URL signing.

Add request ID and host ID support across S3 requests. Each request now receives AWS S3-like identifiers, returned in response headers as `x-amz-request-id` and `x-amz-id-2` and included in all XML error responses as RequestId and HostId. The generated ID structure is designed to resemble AWS S3 request IDs and host IDs.

The request signature calculation/validation for streaming uploads was previously delayed until the request body was fully read, both for Authorization header authentication and presigned URLs.
Now, the signature is validated immediately in the authorization middlewares without reading the request body, since the signature calculation itself does not depend on the request body. Instead, only the `x-amz-content-sha256` SHA-256 hash calculation is delayed.
2026-05-21 23:49:34 +04:00

3376 lines
97 KiB
Go

// Copyright 2023 Versity Software
// This file is licensed under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package integration
import (
"bytes"
"context"
"crypto/hmac"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"hash"
"hash/crc32"
"hash/crc64"
"io"
"math/big"
"math/bits"
rnd "math/rand"
"mime/multipart"
"net/http"
"net/url"
"os/exec"
"slices"
"sort"
"strings"
"sync/atomic"
"time"
"unicode"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"github.com/cespare/xxhash/v2"
"github.com/versity/versitygw/s3err"
"github.com/zeebo/xxh3"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)
// emptySHA256Hash is the hex-encoded SHA-256 digest of a zero-length payload.
const emptySHA256Hash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"

// bcktCount is incremented atomically to derive unique test bucket names.
var bcktCount atomic.Uint64

// adminErrorPrefix holds the string "XAdmin"; presumably the code prefix used
// by admin API errors — confirm against its users.
var adminErrorPrefix = "XAdmin"

// user bundles the credentials and role of a test account.
type user struct {
	access, secret, role string
}
// getBucketName returns a bucket name of the form "test-bucket-N", where N is
// taken from the atomically incremented bcktCount, so every call yields a
// distinct name within the process.
func getBucketName() string {
	return fmt.Sprintf("test-bucket-%v", bcktCount.Add(1))
}
// getUser builds a test user with the given role and randomly suffixed access
// key and secret ("test-user-<rand>" / "test-secret-<rand>").
func getUser(role string) user {
	var u user
	u.access = fmt.Sprintf("test-user-%v", genRandString(16))
	u.secret = fmt.Sprintf("test-secret-%v", genRandString(16))
	u.role = role
	return u
}
// setup creates bucket with the settings collected from opts (object lock,
// object ownership) and, when a versioning status was requested, applies it
// with PutBucketVersioning. Each API call runs under shortTimeout. The first
// error encountered is returned unchanged.
func setup(s *S3Conf, bucket string, opts ...setupOpt) error {
	client := s.GetClient()

	cfg := new(setupCfg)
	for _, apply := range opts {
		apply(cfg)
	}

	createCtx, cancelCreate := context.WithTimeout(context.Background(), shortTimeout)
	_, err := client.CreateBucket(createCtx, &s3.CreateBucketInput{
		Bucket:                     &bucket,
		ObjectLockEnabledForBucket: &cfg.LockEnabled,
		ObjectOwnership:            cfg.Ownership,
	})
	cancelCreate()
	if err != nil {
		return err
	}

	// Versioning is only configured when explicitly requested.
	if cfg.VersioningStatus == "" {
		return nil
	}

	versionCtx, cancelVersion := context.WithTimeout(context.Background(), shortTimeout)
	_, err = client.PutBucketVersioning(versionCtx, &s3.PutBucketVersioningInput{
		Bucket: &bucket,
		VersioningConfiguration: &types.VersioningConfiguration{
			Status: cfg.VersioningStatus,
		},
	})
	cancelVersion()
	return err
}
// teardown empties bucket and then deletes it.
//
// When s.versioningEnabled is set, every object version and delete marker is
// removed (paging through ListObjectVersions). Otherwise every object returned
// by ListObjectsV2 is removed. Each object deletion bypasses governance
// retention and is retried up to maxRetryAttempts times, one second apart.
// The first listing or deletion failure aborts the teardown and is returned;
// otherwise the DeleteBucket result is returned.
func teardown(s *S3Conf, bucket string) error {
	s3client := s.GetClient()

	deleteObject := func(bucket, key, versionId *string) error {
		var attempts int
		var err error
		for attempts < maxRetryAttempts {
			ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
			_, err = s3client.DeleteObject(ctx, &s3.DeleteObjectInput{
				Bucket:                    bucket,
				Key:                       key,
				VersionId:                 versionId,
				BypassGovernanceRetention: aws.Bool(true),
			})
			cancel()
			if err == nil {
				return nil
			}
			attempts++
			time.Sleep(time.Second)
		}
		return fmt.Errorf("delete object %s: %w", *key, err)
	}

	if s.versioningEnabled {
		in := &s3.ListObjectVersionsInput{Bucket: &bucket}
		for {
			ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
			out, err := s3client.ListObjectVersions(ctx, in)
			cancel()
			if err != nil {
				return fmt.Errorf("failed to list objects: %w", err)
			}
			for _, item := range out.Versions {
				err = deleteObject(&bucket, item.Key, item.VersionId)
				if err != nil {
					return err
				}
			}
			for _, item := range out.DeleteMarkers {
				err = deleteObject(&bucket, item.Key, item.VersionId)
				if err != nil {
					return err
				}
			}
			if out.IsTruncated == nil || !*out.IsTruncated {
				break
			}
			// Resume after the last entry of this page. The next-page position
			// is reported in NextKeyMarker/NextVersionIdMarker; out.KeyMarker
			// only echoes the marker that was sent with this request.
			in.KeyMarker = out.NextKeyMarker
			in.VersionIdMarker = out.NextVersionIdMarker
		}
	} else {
		for {
			ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
			out, err := s3client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
				Bucket: &bucket,
			})
			cancel()
			if err != nil {
				return fmt.Errorf("failed to list objects: %w", err)
			}
			for _, item := range out.Contents {
				err = deleteObject(&bucket, item.Key, nil)
				if err != nil {
					return err
				}
			}
			// Every listed object was just deleted, so listing again from the
			// start returns the next batch; no continuation token is needed.
			if out.IsTruncated != nil && *out.IsTruncated {
				continue
			}
			break
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
	_, err := s3client.DeleteBucket(ctx, &s3.DeleteBucketInput{
		Bucket: &bucket,
	})
	cancel()
	return err
}
// setupCfg collects the optional settings consumed by setup and actionHandler.
type setupCfg struct {
	LockEnabled      bool
	VersioningStatus types.BucketVersioningStatus
	Ownership        types.ObjectOwnership
	Anonymous        bool
	SkipTearDown     bool
}

// setupOpt mutates a setupCfg; pass any number of them to setup/actionHandler.
type setupOpt func(*setupCfg)

// withLock requests object lock to be enabled on the created bucket.
func withLock() setupOpt {
	return func(cfg *setupCfg) {
		cfg.LockEnabled = true
	}
}

// withOwnership sets the object ownership of the created bucket.
func withOwnership(o types.ObjectOwnership) setupOpt {
	return func(cfg *setupCfg) {
		cfg.Ownership = o
	}
}

// withVersioning sets the versioning status applied after bucket creation.
func withVersioning(v types.BucketVersioningStatus) setupOpt {
	return func(cfg *setupCfg) {
		cfg.VersioningStatus = v
	}
}

// withAnonymousClient makes actionHandler pass the anonymous client to the handler.
func withAnonymousClient() setupOpt {
	return func(cfg *setupCfg) {
		cfg.Anonymous = true
	}
}

// withSkipTearDown makes actionHandler leave the bucket in place afterwards.
func withSkipTearDown() setupOpt {
	return func(cfg *setupCfg) {
		cfg.SkipTearDown = true
	}
}
// actionHandler runs one integration test case:
//   - creates a fresh bucket via setup, applying opts,
//   - calls handler with the anonymous client (withAnonymousClient) or the
//     regular client,
//   - tears the bucket down unless withSkipTearDown was given,
//   - reports progress through runF / passF / failF.
//
// It returns the handler's error. If the handler succeeded but teardown
// failed, the teardown error is returned instead.
func actionHandler(s *S3Conf, testName string, handler func(s3client *s3.Client, bucket string) error, opts ...setupOpt) error {
	runF(testName)
	bucketName := getBucketName()
	cfg := new(setupCfg)
	for _, opt := range opts {
		opt(cfg)
	}
	err := setup(s, bucketName, opts...)
	if err != nil {
		failF("%v: failed to create a bucket: %v", testName, err)
		return fmt.Errorf("%v: failed to create a bucket: %w", testName, err)
	}
	var client *s3.Client
	if cfg.Anonymous {
		client = s.GetAnonymousClient()
	} else {
		client = s.GetClient()
	}
	handlerErr := handler(client, bucketName)
	if handlerErr != nil {
		failF("%v: %v", testName, handlerErr)
	}
	if !cfg.SkipTearDown {
		err = teardown(s, bucketName)
		if err != nil {
			// Terminate the line so subsequent test output does not run onto it.
			fmt.Printf(colorRed+"%v: failed to delete the bucket: %v\n", testName, err)
			if handlerErr == nil {
				return fmt.Errorf("%v: failed to delete the bucket: %w", testName, err)
			}
		}
	}
	if handlerErr == nil {
		passF(testName)
	}
	return handlerErr
}
// actionHandlerNoSetup runs handler against a freshly generated bucket name
// without creating or deleting the bucket. Setup options are accepted only so
// the signature matches actionHandler; they are ignored.
func actionHandlerNoSetup(s *S3Conf, testName string, handler func(s3client *s3.Client, bucket string) error, _ ...setupOpt) error {
	runF(testName)
	client := s.GetClient()
	bucket := getBucketName()

	err := handler(client, bucket)
	if err != nil {
		failF("%v: %v", testName, err)
		return err
	}

	passF(testName)
	return nil
}
// authConfig describes a raw, SigV4-signed HTTP request built by authHandler
// through createSignedReq.
type authConfig struct {
	// testName labels pass/fail output; path is appended to the endpoint.
	testName, path, method string
	// overrideSha256, when non-empty, replaces the computed payload hash.
	overrideSha256 string
	body           []byte
	service        string
	// date is the signing time passed to the signer.
	date    time.Time
	headers map[string]string
}
// authHandler builds a signed request from cfg using the credentials, region
// and endpoint in s, then passes it to handler. Any failure is reported via
// failF and returned wrapped with the test name; success is reported via passF.
func authHandler(s *S3Conf, cfg *authConfig, handler func(req *http.Request) error) error {
	runF(cfg.testName)

	fail := func(cause error) error {
		failF("%v: %v", cfg.testName, cause)
		return fmt.Errorf("%v: %w", cfg.testName, cause)
	}

	req, err := createSignedReq(cfg.method, s.endpoint, cfg.path, s.awsID, s.awsSecret, cfg.service, s.awsRegion, cfg.overrideSha256, cfg.body, cfg.date, cfg.headers)
	if err != nil {
		return fail(err)
	}
	if err := handler(req); err != nil {
		return fail(err)
	}

	passF(cfg.testName)
	return nil
}
// presignedAuthHandler creates a fresh bucket, runs handler with the presign
// client, and then deletes the bucket.
//
// The bucket is torn down even when the handler fails, so a failing test does
// not leak its bucket. Handler and teardown failures are reported via failF
// and returned (joined when both occur) wrapped with the test name. passF is
// reported only when both succeed.
func presignedAuthHandler(s *S3Conf, testName string, handler func(client *s3.PresignClient, bucket string) error) error {
	runF(testName)
	bucket := getBucketName()
	if err := setup(s, bucket); err != nil {
		failF("%v: %v", testName, err)
		return fmt.Errorf("%v: %w", testName, err)
	}

	handlerErr := handler(s.GetPresignClient(), bucket)
	if handlerErr != nil {
		failF("%v: %v", testName, handlerErr)
	}

	teardownErr := teardown(s, bucket)
	if teardownErr != nil {
		failF("%v: %v", testName, teardownErr)
	}

	// errors.Join returns nil when both are nil and the lone error otherwise.
	if err := errors.Join(handlerErr, teardownErr); err != nil {
		return fmt.Errorf("%v: %w", testName, err)
	}
	passF(testName)
	return nil
}
// createSignedReq builds an HTTP request for "<endpoint>/<path>" with body and
// signs it with SigV4 using the given credentials, service, region and date.
//
// X-Amz-Content-Sha256 is set to overrideSha256 when non-empty, otherwise to
// the hex SHA-256 of body. The same value is used as the payload hash for
// signing. Entries in headers are added (not set) before signing, so they are
// covered by the signature. The request is only built, not sent.
func createSignedReq(method, endpoint, path, access, secret, service, region, overrideSha256 string, body []byte, date time.Time, headers map[string]string) (*http.Request, error) {
	req, err := http.NewRequest(method, fmt.Sprintf("%v/%v", endpoint, path), bytes.NewReader(body))
	if err != nil {
		// Only the request construction can fail here; nothing is sent yet.
		return nil, fmt.Errorf("failed to create the request: %w", err)
	}
	signer := v4.NewSigner()
	hexPayload := overrideSha256
	if hexPayload == "" {
		hashedPayload := sha256.Sum256(body)
		hexPayload = hex.EncodeToString(hashedPayload[:])
	}
	req.Header.Set("X-Amz-Content-Sha256", hexPayload)
	for key, val := range headers {
		req.Header.Add(key, val)
	}
	signErr := signer.SignHTTP(req.Context(), aws.Credentials{AccessKeyID: access, SecretAccessKey: secret}, req, hexPayload, service, region, date)
	if signErr != nil {
		return nil, fmt.Errorf("failed to sign the request: %w", signErr)
	}
	return req, nil
}
// APIErrorResponse decodes the XML <Error> document returned for a failed
// request. Besides the common Code and Message elements, it carries the union
// of the error-specific elements compared by compareS3ApiErr, so one struct
// can decode any typed s3err error. Elements absent from a given response stay
// at their zero value.
type APIErrorResponse struct {
	XMLName xml.Name `xml:"Error"`
	Code    string
	Message string

	// InvalidArgument.
	ArgumentName  string `xml:"ArgumentName,omitempty"`
	ArgumentValue string `xml:"ArgumentValue,omitempty"`

	// Method / resource details (Method and ResourceType are compared for
	// AccessForbidden and MethodNotAllowed; Resource is decoded but not compared).
	Method       string `xml:"Method,omitempty"`
	Resource     string `xml:"Resource,omitempty"`
	ResourceType s3err.ResourceType

	// BadDigest.
	CalculatedDigest string `xml:"CalculatedDigest,omitempty"`
	ExpectedDigest   string `xml:"ExpectedDigest,omitempty"`

	// Bucket errors.
	BucketName string `xml:"BucketName,omitempty"`

	// ContentSHA256Mismatch.
	ClientComputedContentSHA256 string `xml:"ClientComputedContentSHA256,omitempty"`
	S3ComputedContentSHA256     string `xml:"S3ComputedContentSHA256,omitempty"`

	// EntityTooLarge / EntityTooSmall (MaxSizeAllowed is also used by
	// KeyTooLong and MetadataTooLarge).
	ProposedSize   int64 `xml:"ProposedSize,omitempty"`
	MaxSizeAllowed int64 `xml:"MaxSizeAllowed,omitempty"`
	MinSizeAllowed int64 `xml:"MinSizeAllowed,omitempty"`

	// ExpiredPresignedURL (ServerTime is also used by RequestTimeTooSkewed).
	ServerTime  string `xml:"ServerTime,omitempty"`
	XAmzExpires int    `xml:"X-Amz-Expires,omitempty"`
	Expires     string `xml:"Expires,omitempty"`

	// InvalidAccessKeyId / SignatureDoesNotMatch.
	AWSAccessKeyId string `xml:"AWSAccessKeyId,omitempty"`

	// InvalidChunkSize.
	Chunk        int   `xml:"Chunk,omitempty"`
	BadChunkSize int64 `xml:"BadChunkSize,omitempty"`

	// InvalidDigest.
	ContentMD5 string `xml:"Content-MD5,omitempty"`

	// InvalidLocationConstraint.
	LocationConstraint string `xml:"LocationConstraint,omitempty"`

	// InvalidPart (UploadId is also used by NoSuchUpload).
	UploadId   string `xml:"UploadId,omitempty"`
	PartNumber int32  `xml:"PartNumber,omitempty"`
	ETag       string `xml:"ETag,omitempty"`

	// InvalidPartNumberRange.
	ActualPartCount     int32 `xml:"ActualPartCount,omitempty"`
	PartNumberRequested int32 `xml:"PartNumberRequested,omitempty"`

	// InvalidRange.
	RangeRequested   string `xml:"RangeRequested,omitempty"`
	ActualObjectSize int64  `xml:"ActualObjectSize,omitempty"`

	// InvalidTag.
	TagKey   string `xml:"TagKey,omitempty"`
	TagValue string `xml:"TagValue,omitempty"`

	// KeyTooLong / MetadataTooLarge.
	Size int64 `xml:"Size,omitempty"`

	// NotImplemented.
	Header            string                     `xml:"Header,omitempty"`
	AdditionalMessage s3err.NmpAdditionalMessage `xml:"additionalMessage,omitempty"`

	// NoSuchVersion.
	Key       string `xml:"Key,omitempty"`
	VersionId string `xml:"VersionId,omitempty"`

	// PreconditionFailed.
	Condition s3err.Condition `xml:"Condition,omitempty"`

	// RequestTimeTooSkewed.
	RequestTime                string `xml:"RequestTime,omitempty"`
	MaxAllowedSkewMilliseconds int    `xml:"MaxAllowedSkewMilliseconds,omitempty"`

	// MalformedAuth.
	Region string `xml:"Region,omitempty"`

	// SignatureDoesNotMatch diagnostics.
	StringToSign          string `xml:"StringToSign,omitempty"`
	SignatureProvided     string `xml:"SignatureProvided,omitempty"`
	StringToSignBytes     string `xml:"StringToSignBytes,omitempty"`
	CanonicalRequest      string `xml:"CanonicalRequest,omitempty"`
	CanonicalRequestBytes string `xml:"CanonicalRequestBytes,omitempty"`

	// Request identifiers included in error responses.
	RequestID string `xml:"RequestId,omitempty"`
	HostID    string `xml:"HostId,omitempty"`
}
// checkHTTPResponseApiErr verifies that resp is the error response described
// by expected. The HTTP status code must equal expected's status code, and the
// XML <Error> body must match expected's code, message and error-specific
// fields (see compareS3ApiError). The response body is always closed,
// including when reading it fails.
func checkHTTPResponseApiErr(resp *http.Response, expected s3err.S3Error) error {
	defer resp.Body.Close()

	apiErr := expected.BaseError()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("read response body: %w", err)
	}

	// Check the status first. A mismatched status often comes with a body that
	// is not an XML error document at all, and the status mismatch is the more
	// informative failure.
	if resp.StatusCode != apiErr.HTTPStatusCode {
		return fmt.Errorf("expected response status code to be %v, instead got %v", apiErr.HTTPStatusCode, resp.StatusCode)
	}

	var errResp APIErrorResponse
	if err := xml.Unmarshal(body, &errResp); err != nil {
		return fmt.Errorf("unmarshal error response %s: %w", body, err)
	}
	return compareS3ApiError(expected, &errResp)
}
// compareS3ApiError checks that received has the same Code and Message as
// expected's base error. It then delegates comparison of the error-specific
// fields to compareS3ApiErr. A nil received is reported as an error.
func compareS3ApiError(expected s3err.S3Error, received *APIErrorResponse) error {
	base := expected.BaseError()
	switch {
	case received == nil:
		return fmt.Errorf("expected %w, received nil", base)
	case received.Code != base.Code:
		return fmt.Errorf("expected error code to be %v, instead got %v", base.Code, received.Code)
	case received.Message != base.Description:
		return fmt.Errorf("expected error message to be %v, instead got %v", base.Description, received.Message)
	default:
		return compareS3ApiErr(expected, received)
	}
}
// compareS3ApiErr compares the error-specific XML fields of received against
// expected, dispatching on expected's concrete s3err type. It returns the first
// mismatch found.
//
// Values that depend on the request or the clock (ServerTime, Expires,
// RequestTime, and the SignatureDoesNotMatch string-to-sign and
// canonical-request details) are only required to be present in received.
// Plain s3err.APIError, and any type not listed here, has no extra fields and
// yields nil.
func compareS3ApiErr(expected s3err.S3Error, received *APIErrorResponse) error {
	switch err := expected.(type) {
	case s3err.APIError:
		return nil
	case s3err.AccessForbiddenError:
		return compareS3ApiErrFields(
			compareErrField("Method", err.Method, received.Method),
			compareErrField("ResourceType", err.ResourceType, received.ResourceType),
		)
	case s3err.BadDigestError:
		return compareS3ApiErrFields(
			compareErrField("CalculatedDigest", err.CalculatedDigest, received.CalculatedDigest),
			compareErrField("ExpectedDigest", err.ExpectedDigest, received.ExpectedDigest),
		)
	case s3err.BucketError:
		return compareErrField("BucketName", err.BucketName, received.BucketName)
	case s3err.ContentSHA256MismatchError:
		return compareS3ApiErrFields(
			compareErrField("ClientComputedContentSHA256", err.ClientComputedContentSHA256, received.ClientComputedContentSHA256),
			compareErrField("S3ComputedContentSHA256", err.S3ComputedContentSHA256, received.S3ComputedContentSHA256),
		)
	case s3err.EntityTooLargeError:
		return compareS3ApiErrFields(
			compareErrField("ProposedSize", err.ProposedSize, received.ProposedSize),
			compareErrField("MaxSizeAllowed", err.MaxSizeAllowed, received.MaxSizeAllowed),
		)
	case s3err.EntityTooSmallError:
		return compareS3ApiErrFields(
			compareErrField("ProposedSize", err.ProposedSize, received.ProposedSize),
			compareErrField("MinSizeAllowed", err.MinSizeAllowed, received.MinSizeAllowed),
		)
	case s3err.ExpiredPresignedURLError:
		return compareS3ApiErrFields(
			checkErrFieldEmptiness("ServerTime", received.ServerTime, true),
			compareErrField("X-Amz-Expires", err.XAmzExpires, received.XAmzExpires),
			checkErrFieldEmptiness("Expires", received.Expires, true),
		)
	case s3err.InvalidAccessKeyIdError:
		return compareErrField("AWSAccessKeyId", err.AWSAccessKeyId, received.AWSAccessKeyId)
	case s3err.InvalidArgumentError:
		return compareS3ApiErrFields(
			compareErrField("ArgumentName", err.ArgumentName, received.ArgumentName),
			compareErrField("ArgumentValue", err.ArgumentValue, received.ArgumentValue),
		)
	case s3err.InvalidChunkSizeError:
		return compareS3ApiErrFields(
			compareErrField("Chunk", err.Chunk, received.Chunk),
			compareErrField("BadChunkSize", err.BadChunkSize, received.BadChunkSize),
		)
	case s3err.InvalidDigestError:
		return compareErrField("Content-MD5", err.ContentMD5, received.ContentMD5)
	case s3err.InvalidLocationConstraintError:
		return compareErrField("LocationConstraint", err.LocationConstraint, received.LocationConstraint)
	case s3err.InvalidPartError:
		return compareS3ApiErrFields(
			compareErrField("UploadId", err.UploadId, received.UploadId),
			compareErrField("PartNumber", err.PartNumber, received.PartNumber),
			compareErrField("ETag", err.ETag, received.ETag),
		)
	case s3err.InvalidPartNumberRangeError:
		return compareS3ApiErrFields(
			compareErrField("ActualPartCount", err.ActualPartCount, received.ActualPartCount),
			compareErrField("PartNumberRequested", err.PartNumberRequested, received.PartNumberRequested),
		)
	case s3err.InvalidRangeError:
		return compareS3ApiErrFields(
			compareErrField("RangeRequested", err.RangeRequested, received.RangeRequested),
			compareErrField("ActualObjectSize", err.ActualObjectSize, received.ActualObjectSize),
		)
	case s3err.InvalidTagError:
		return compareS3ApiErrFields(
			compareErrField("TagKey", err.TagKey, received.TagKey),
			compareErrField("TagValue", err.TagValue, received.TagValue),
		)
	case s3err.KeyTooLongError:
		return compareS3ApiErrFields(
			compareErrField("Size", err.Size, received.Size),
			compareErrField("MaxSizeAllowed", err.MaxSizeAllowed, received.MaxSizeAllowed),
		)
	case s3err.MetadataTooLargeError:
		return compareS3ApiErrFields(
			compareErrField("Size", int64(err.Size), received.Size),
			compareErrField("MaxSizeAllowed", int64(err.MaxSizeAllowed), received.MaxSizeAllowed),
		)
	case s3err.MethodNotAllowedError:
		return compareS3ApiErrFields(
			compareErrField("Method", err.Method, received.Method),
			compareErrField("ResourceType", err.ResourceType, received.ResourceType),
		)
	case s3err.MalformedAuthError:
		return compareErrField("Region", err.Region, received.Region)
	case s3err.NoSuchUploadError:
		return compareErrField("UploadId", err.UploadId, received.UploadId)
	case s3err.NoSuchVersionError:
		return compareS3ApiErrFields(
			compareErrField("Key", err.Key, received.Key),
			compareErrField("VersionId", err.VersionId, received.VersionId),
		)
	case s3err.NotImplementedError:
		return compareS3ApiErrFields(
			compareErrField("Header", err.Header, received.Header),
			compareErrField("additionalMessage", err.AdditionalMessage, received.AdditionalMessage),
		)
	case s3err.PreconditionFailedError:
		return compareErrField("Condition", err.Condition, received.Condition)
	case s3err.RequestTimeTooSkewedError:
		return compareS3ApiErrFields(
			checkErrFieldEmptiness("RequestTime", received.RequestTime, true),
			checkErrFieldEmptiness("ServerTime", received.ServerTime, true),
			compareErrField("MaxAllowedSkewMilliseconds", err.MaxAllowedSkewMilliseconds, received.MaxAllowedSkewMilliseconds),
		)
	case s3err.SignatureDoesNotMatchError:
		// The string-to-sign and canonical-request diagnostics depend on the
		// exact request, so only their presence in the *received* response is
		// verified.
		return compareS3ApiErrFields(
			compareErrField("AWSAccessKeyId", err.AWSAccessKeyId, received.AWSAccessKeyId),
			checkErrFieldEmptiness("StringToSign", received.StringToSign, true),
			compareErrField("SignatureProvided", err.SignatureProvided, received.SignatureProvided),
			checkErrFieldEmptiness("StringToSignBytes", received.StringToSignBytes, true),
			checkErrFieldEmptiness("CanonicalRequest", received.CanonicalRequest, true),
			checkErrFieldEmptiness("CanonicalRequestBytes", received.CanonicalRequestBytes, true),
		)
	}
	return nil
}
// compareS3ApiErrFields returns the first non-nil error among errs, or nil
// when every field comparison succeeded (including when errs is empty).
func compareS3ApiErrFields(errs ...error) error {
	for i := 0; i < len(errs); i++ {
		if e := errs[i]; e != nil {
			return e
		}
	}
	return nil
}
// compareErrField reports a descriptive error when received differs from
// expected for the named XML field, and nil when they are equal.
func compareErrField[T comparable](field string, expected, received T) error {
	if received == expected {
		return nil
	}
	return fmt.Errorf("expected error %s to be %v, instead got %v", field, expected, received)
}
// checkErrFieldEmptiness verifies the presence of a field's value. When
// require is true, received must be non-empty; when false, it must be empty.
func checkErrFieldEmptiness(field, received string, require bool) error {
	isEmpty := received == ""
	switch {
	case require && isEmpty:
		return fmt.Errorf("expected error %s to be non-empty", field)
	case !require && !isEmpty:
		return fmt.Errorf("expected error %s to be empty, instead got %s", field, received)
	}
	return nil
}
// checkApiErr verifies that err is an SDK API error whose code and message
// match expected's base error. A nil err, or an err that does not unwrap to a
// smithy.APIError, is reported as a failure.
func checkApiErr(err error, expected s3err.S3Error) error {
	base := expected.BaseError()
	if err == nil {
		return fmt.Errorf("expected %v, instead got nil", base.Code)
	}

	var ae smithy.APIError
	if !errors.As(err, &ae) {
		return fmt.Errorf("expected aws api error, instead got: %w", err)
	}
	if code := ae.ErrorCode(); code != base.Code {
		return fmt.Errorf("expected error code to be %v, instead got %v", base.Code, code)
	}
	if msg := ae.ErrorMessage(); msg != base.Description {
		return fmt.Errorf("expected error message to be %v, instead got %v", base.Description, msg)
	}
	return nil
}
// checkSdkApiErr verifies that err is an SDK API error with the given error
// code. Only the code is compared, not the message.
//
// A nil err is a failure, since the caller expected an error, which matches
// checkApiErr. An err that is not a smithy.APIError is returned unchanged.
func checkSdkApiErr(err error, code string) error {
	if err == nil {
		return fmt.Errorf("expected %v, instead got nil", code)
	}
	var ae smithy.APIError
	if !errors.As(err, &ae) {
		return err
	}
	if ae.ErrorCode() != code {
		return fmt.Errorf("expected %v, instead got %v", code, ae.ErrorCode())
	}
	return nil
}
// putObjects uploads an empty object for every key in objs. It returns the
// expected listing entries (key, ETag, STANDARD storage class, size 0) sorted
// by key. The first PutObject error aborts the loop and is returned.
func putObjects(client *s3.Client, objs []string, bucket string) ([]types.Object, error) {
	var (
		contents []types.Object
		size     int64 // shared zero size: every uploaded object is empty
	)
	for i := range objs {
		objKey := objs[i]

		ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
		out, err := client.PutObject(ctx, &s3.PutObjectInput{
			Key:    &objKey,
			Bucket: &bucket,
		})
		cancel()
		if err != nil {
			return nil, err
		}

		contents = append(contents, types.Object{
			Key:          aws.String(objKey),
			ETag:         out.ETag,
			StorageClass: types.ObjectStorageClassStandard,
			Size:         &size,
		})
	}

	slices.SortStableFunc(contents, func(a, b types.Object) int {
		return strings.Compare(*a.Key, *b.Key)
	})
	return contents, nil
}
// listObjects pages through ListObjectsV2 for bucket with the given prefix,
// delimiter and page size. It returns every object and common prefix across
// all pages, in listing order.
//
// A response with a nil IsTruncated is treated as the last page. A truncated
// response without a NextContinuationToken is reported as an error rather than
// re-requesting the first page forever.
func listObjects(client *s3.Client, bucket, prefix, delimiter string, maxKeys int32) ([]types.Object, []types.CommonPrefix, error) {
	var contents []types.Object
	var commonPrefixes []types.CommonPrefix
	var continuationToken *string
	for {
		ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
		res, err := client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
			Bucket:            &bucket,
			ContinuationToken: continuationToken,
			Prefix:            &prefix,
			Delimiter:         &delimiter,
			MaxKeys:           &maxKeys,
		})
		cancel()
		if err != nil {
			return nil, nil, err
		}
		contents = append(contents, res.Contents...)
		commonPrefixes = append(commonPrefixes, res.CommonPrefixes...)
		continuationToken = res.NextContinuationToken
		if res.IsTruncated == nil || !*res.IsTruncated {
			break
		}
		if continuationToken == nil {
			return nil, nil, errors.New("truncated ListObjectsV2 response has no NextContinuationToken")
		}
	}
	return contents, commonPrefixes, nil
}
// constructObjectLocation builds the URL of object in bucket on endpoint.
//
// Trailing slashes are trimmed from endpoint first. Path style yields
// "<endpoint>/<bucket>/<object>". Host style prefixes the endpoint host with
// "<bucket>." and yields "<scheme>://<bucket>.<host>/<object>". When endpoint
// does not parse to a URL with a host (e.g. a bare "127.0.0.1:7070"), it falls
// back to "http://<bucket>.<endpoint>/<object>".
func constructObjectLocation(endpoint, bucket, object string, hostStyle bool) string {
	base := strings.TrimRight(endpoint, "/")

	if !hostStyle {
		return base + "/" + bucket + "/" + object
	}

	parsed, err := url.Parse(base)
	if err == nil && parsed.Host != "" {
		parsed.Host = bucket + "." + parsed.Host
		return parsed.String() + "/" + object
	}
	return "http://" + bucket + "." + base + "/" + object
}
// hasObjNames reports whether objs has exactly len(names) entries and every
// object key appears in names. Order is ignored.
func hasObjNames(objs []types.Object, names []string) bool {
	if len(objs) != len(names) {
		return false
	}
	return !slices.ContainsFunc(objs, func(o types.Object) bool {
		return !slices.Contains(names, *o.Key)
	})
}
// hasPrefixName reports whether prefixes has exactly len(names) entries and
// every common prefix appears in names. Order is ignored.
func hasPrefixName(prefixes []types.CommonPrefix, names []string) bool {
	if len(prefixes) != len(names) {
		return false
	}
	for i := range prefixes {
		if !slices.Contains(names, *prefixes[i].Prefix) {
			return false
		}
	}
	return true
}
// putObjectCfg holds optional settings for putObjectWithData.
type putObjectCfg struct {
	// checksumAlgorithm, when set, selects the checksum computed over the
	// generated data and attached to the request.
	checksumAlgorithm types.ChecksumAlgorithm
}

// putObjectOpt mutates a putObjectCfg.
type putObjectOpt func(*putObjectCfg)

// withPutObjectChecksumAlgo makes putObjectWithData send a checksum computed
// with algo.
func withPutObjectChecksumAlgo(algo types.ChecksumAlgorithm) putObjectOpt {
	return func(cfg *putObjectCfg) {
		cfg.checksumAlgorithm = algo
	}
}

// putObjectOutput is the result of putObjectWithData: the SHA-256 and bytes of
// any generated body, plus the raw PutObject response.
type putObjectOutput struct {
	csum [32]byte
	data []byte
	res  *s3.PutObjectOutput
}
// putObjectWithData uploads the object described by input using client.
//
// When input.Body is nil and lgth is non-zero, lgth random bytes are generated
// and used as the body. Their SHA-256 digest is returned in csum and the bytes
// in data. If withPutObjectChecksumAlgo was supplied, a checksum of that
// algorithm over the generated bytes is also set on input (via
// setPutObjectChecksum). When input.Body is already set, or lgth is 0, nothing
// is generated and csum/data are returned as zero values.
//
// input is modified in place (Body, and possibly the checksum field). The
// request runs under longTimeout. On error a nil output is returned.
func putObjectWithData(lgth int64, input *s3.PutObjectInput, client *s3.Client, opts ...putObjectOpt) (*putObjectOutput, error) {
	cfg := &putObjectCfg{}
	for _, opt := range opts {
		opt(cfg)
	}
	var csum [32]byte
	var data []byte
	if input.Body == nil && lgth != 0 {
		data = make([]byte, lgth)
		rand.Read(data)
		csum = sha256.Sum256(data)
		if cfg.checksumAlgorithm != "" {
			hasher, err := NewHasher(cfg.checksumAlgorithm)
			if err != nil {
				return nil, err
			}
			hasher.Write(data)
			sum := base64.StdEncoding.EncodeToString(hasher.Sum(nil))
			setPutObjectChecksum(input, cfg.checksumAlgorithm, &sum)
		}
		input.Body = bytes.NewReader(data)
	}
	ctx, cancel := context.WithTimeout(context.Background(), longTimeout)
	res, err := client.PutObject(ctx, input, func(o *s3.Options) {
		// if input.Body is not nil, aws sdk hardcodes Content-Type: application/octet-stream
		// this adds a new middleware in the stack to remove the Content-Type header, if
		// it isn't explicitly provided as 'PutObject' input. Place the middleware
		// right before "Signing" middleware to avoid incorrect request signature calculation
		if input.ContentType == nil {
			o.APIOptions = append(o.APIOptions, func(stack *middleware.Stack) error {
				return stack.Finalize.Insert(
					middleware.FinalizeMiddlewareFunc("UnsetContentType",
						func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
							out middleware.FinalizeOutput, md middleware.Metadata, err error,
						) {
							if req, ok := in.Request.(*smithyhttp.Request); ok {
								req.Header.Del("Content-Type")
							}
							return next.HandleFinalize(ctx, in)
						}),
					"Signing",
					middleware.Before,
				)
			})
		}
	})
	cancel()
	if err != nil {
		return nil, err
	}
	return &putObjectOutput{
		csum: csum,
		data: data,
		res:  res,
	}, nil
}
// mpCfg holds optional settings for createMp.
type mpCfg struct {
	checksumAlgorithm types.ChecksumAlgorithm
	checksumType      types.ChecksumType
	metadata          map[string]string
}

// mpOpt mutates an mpCfg.
type mpOpt func(*mpCfg)

// withChecksum sets the checksum algorithm of the multipart upload.
func withChecksum(algo types.ChecksumAlgorithm) mpOpt {
	return func(cfg *mpCfg) {
		cfg.checksumAlgorithm = algo
	}
}

// withChecksumType sets the checksum type of the multipart upload.
func withChecksumType(t types.ChecksumType) mpOpt {
	return func(cfg *mpCfg) {
		cfg.checksumType = t
	}
}

// withMetadata sets the user metadata of the multipart upload.
func withMetadata(m map[string]string) mpOpt {
	return func(cfg *mpCfg) {
		cfg.metadata = m
	}
}
// createMp starts a multipart upload for bucket/key with the checksum
// algorithm, checksum type and metadata selected by opts. The call runs under
// shortTimeout, and the SDK output and error are returned as-is.
func createMp(s3client *s3.Client, bucket, key string, opts ...mpOpt) (*s3.CreateMultipartUploadOutput, error) {
	cfg := &mpCfg{}
	for _, apply := range opts {
		apply(cfg)
	}

	input := &s3.CreateMultipartUploadInput{
		Bucket:            &bucket,
		Key:               &key,
		ChecksumAlgorithm: cfg.checksumAlgorithm,
		ChecksumType:      cfg.checksumType,
		Metadata:          cfg.metadata,
	}

	ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
	defer cancel()
	return s3client.CreateMultipartUpload(ctx, input)
}
// isSameData reports whether a and b hold the same bytes. A nil slice and an
// empty slice are considered equal.
func isSameData(a, b []byte) bool {
	return string(a) == string(b)
}
// compareMultipartUploads reports whether both lists contain the same uploads
// in the same order. Key, UploadId, StorageClass, ChecksumAlgorithm and
// ChecksumType are compared; Key and UploadId must be non-nil.
func compareMultipartUploads(list1, list2 []types.MultipartUpload) bool {
	if len(list1) != len(list2) {
		return false
	}
	for i := range list1 {
		a, b := list1[i], list2[i]
		if *a.Key != *b.Key ||
			*a.UploadId != *b.UploadId ||
			a.StorageClass != b.StorageClass ||
			a.ChecksumAlgorithm != b.ChecksumAlgorithm ||
			a.ChecksumType != b.ChecksumType {
			return false
		}
	}
	return true
}
// compareParts reports whether parts1 and parts2 describe the same parts in the
// same order. The first difference found is printed to stdout.
//
// PartNumber, ETag and Size must match (and be non-nil). A checksum is only
// compared when it is set in parts1; the matching parts2 checksum is then
// treated as "" when nil.
func compareParts(parts1, parts2 []types.Part) bool {
	if len(parts1) != len(parts2) {
		fmt.Printf("list length are not equal: %v != %v\n", len(parts1), len(parts2))
		return false
	}

	for i := range parts1 {
		a, b := parts1[i], parts2[i]

		if *a.PartNumber != *b.PartNumber {
			fmt.Printf("partNumbers are not equal, %v != %v\n", *a.PartNumber, *b.PartNumber)
			return false
		}
		if *a.ETag != *b.ETag {
			fmt.Printf("etags are not equal, %v != %v\n", *a.ETag, *b.ETag)
			return false
		}
		if *a.Size != *b.Size {
			fmt.Printf("sizes are not equal, %v != %v\n", *a.Size, *b.Size)
			return false
		}

		// Checked in a fixed order so the reported mismatch is deterministic.
		checksums := []struct {
			name        string
			left, right *string
		}{
			{"crc32", a.ChecksumCRC32, b.ChecksumCRC32},
			{"crc32c", a.ChecksumCRC32C, b.ChecksumCRC32C},
			{"sha1", a.ChecksumSHA1, b.ChecksumSHA1},
			{"sha256", a.ChecksumSHA256, b.ChecksumSHA256},
			{"crc64nvme", a.ChecksumCRC64NVME, b.ChecksumCRC64NVME},
			{"sha512", a.ChecksumSHA512, b.ChecksumSHA512},
			{"md5", a.ChecksumMD5, b.ChecksumMD5},
			{"xxhash64", a.ChecksumXXHASH64, b.ChecksumXXHASH64},
			{"xxhash3", a.ChecksumXXHASH3, b.ChecksumXXHASH3},
			{"xxhash128", a.ChecksumXXHASH128, b.ChecksumXXHASH128},
		}
		for _, c := range checksums {
			if c.left == nil {
				continue
			}
			if other := getString(c.right); *c.left != other {
				fmt.Printf("%s checksums are not equal, %v != %v\n", c.name, *c.left, other)
				return false
			}
		}
	}
	return true
}
// areTagsSame reports whether tags1 and tags2 have the same length and every
// tag in tags1 has a key/value match somewhere in tags2. Order is ignored.
func areTagsSame(tags1, tags2 []types.Tag) bool {
	allFound := len(tags1) == len(tags2)
	for i := 0; allFound && i < len(tags1); i++ {
		allFound = containsTag(tags1[i], tags2)
	}
	return allFound
}
// containsTag reports whether list holds an entry whose Key and Value both
// equal those of tag. Key and Value are dereferenced, so they must be non-nil.
func containsTag(tag types.Tag, list []types.Tag) bool {
	return slices.ContainsFunc(list, func(item types.Tag) bool {
		return *item.Key == *tag.Key && *item.Value == *tag.Value
	})
}
// compareGrants reports whether grts1 and grts2 contain the same grants in the
// same order, comparing Permission, Grantee.ID and Grantee.Type.
//
// Grantee and Grantee.ID are compared nil-safely. A grantee may legitimately
// have no ID (for example an AWS Group grantee identified by URI), and such
// grants must not crash the comparison.
func compareGrants(grts1, grts2 []types.Grant) bool {
	if len(grts1) != len(grts2) {
		return false
	}
	for i, grt := range grts1 {
		other := grts2[i]
		if grt.Permission != other.Permission {
			return false
		}
		g1, g2 := grt.Grantee, other.Grantee
		if g1 == nil || g2 == nil {
			// Equal only when both grants lack a grantee.
			if g1 != g2 {
				return false
			}
			continue
		}
		if (g1.ID == nil) != (g2.ID == nil) || (g1.ID != nil && *g1.ID != *g2.ID) {
			return false
		}
		if g1.Type != g2.Type {
			return false
		}
	}
	return true
}
// execCommand runs the ./versitygw binary, resolved relative to the current
// working directory, with args and returns its combined stdout and stderr.
func execCommand(args ...string) ([]byte, error) {
	return exec.Command("./versitygw", args...).CombinedOutput()
}
// getString dereferences str, yielding "" for a nil pointer.
func getString(str *string) string {
	var s string
	if str != nil {
		s = *str
	}
	return s
}
// getPtr returns a pointer to a fresh copy of v, which is handy for filling
// optional pointer fields of SDK inputs.
func getPtr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
// checksumHeaderName returns the lowercase x-amz-checksum-* header name that
// carries a checksum of algorithm algo.
func checksumHeaderName(algo types.ChecksumAlgorithm) string {
	lowered := strings.ToLower(string(algo))
	return fmt.Sprintf("x-amz-checksum-%s", lowered)
}
// checksumFields holds the addresses of the per-algorithm checksum fields of
// one SDK struct (a part, a request input, a response output, ...). Because
// each entry points at a *string field, selectChecksum's result can be used
// both to read (getChecksum) and to overwrite (setChecksum) that field.
type checksumFields struct {
	CRC32     **string
	CRC32C    **string
	SHA1      **string
	SHA256    **string
	CRC64NVME **string
	SHA512    **string
	MD5       **string
	XXHASH64  **string
	XXHASH3   **string
	XXHASH128 **string
}
// selectChecksum returns the entry of fields that stores the checksum for
// algo, or nil when algo is not one of the supported checksum algorithms.
func selectChecksum(algo types.ChecksumAlgorithm, fields checksumFields) **string {
	byAlgo := map[types.ChecksumAlgorithm]**string{
		types.ChecksumAlgorithmCrc32:     fields.CRC32,
		types.ChecksumAlgorithmCrc32c:    fields.CRC32C,
		types.ChecksumAlgorithmSha1:      fields.SHA1,
		types.ChecksumAlgorithmSha256:    fields.SHA256,
		types.ChecksumAlgorithmCrc64nvme: fields.CRC64NVME,
		types.ChecksumAlgorithmSha512:    fields.SHA512,
		types.ChecksumAlgorithmMd5:       fields.MD5,
		types.ChecksumAlgorithmXxhash64:  fields.XXHASH64,
		types.ChecksumAlgorithmXxhash3:   fields.XXHASH3,
		types.ChecksumAlgorithmXxhash128: fields.XXHASH128,
	}
	// A missing key yields the zero value, i.e. a nil **string.
	return byAlgo[algo]
}
// getChecksum reads the checksum for algo out of fields. It returns nil when
// algo is unsupported or the selected field is itself nil.
func getChecksum(algo types.ChecksumAlgorithm, fields checksumFields) *string {
	checksum := selectChecksum(algo, fields)
	if checksum == nil {
		return nil
	}
	return *checksum
}
// setChecksum stores checksum into the field of fields selected by algo. It is
// a no-op when algo is unsupported.
func setChecksum(algo types.ChecksumAlgorithm, fields checksumFields, checksum *string) {
	target := selectChecksum(algo, fields)
	if target == nil {
		return
	}
	*target = checksum
}
// getPartChecksum returns part's checksum for algo, or nil when algo is
// unsupported or the part carries no such checksum.
func getPartChecksum(part types.Part, algo types.ChecksumAlgorithm) *string {
	fields := checksumFields{
		XXHASH128: &part.ChecksumXXHASH128,
		XXHASH3:   &part.ChecksumXXHASH3,
		XXHASH64:  &part.ChecksumXXHASH64,
		MD5:       &part.ChecksumMD5,
		SHA512:    &part.ChecksumSHA512,
		CRC64NVME: &part.ChecksumCRC64NVME,
		SHA256:    &part.ChecksumSHA256,
		SHA1:      &part.ChecksumSHA1,
		CRC32C:    &part.ChecksumCRC32C,
		CRC32:     &part.ChecksumCRC32,
	}
	return getChecksum(algo, fields)
}
// setPartChecksum writes checksum into the field of *part that matches algo.
// Nothing is written for an unsupported algo.
func setPartChecksum(part *types.Part, algo types.ChecksumAlgorithm, checksum *string) {
	fields := checksumFields{
		XXHASH128: &part.ChecksumXXHASH128,
		XXHASH3:   &part.ChecksumXXHASH3,
		XXHASH64:  &part.ChecksumXXHASH64,
		MD5:       &part.ChecksumMD5,
		SHA512:    &part.ChecksumSHA512,
		CRC64NVME: &part.ChecksumCRC64NVME,
		SHA256:    &part.ChecksumSHA256,
		SHA1:      &part.ChecksumSHA1,
		CRC32C:    &part.ChecksumCRC32C,
		CRC32:     &part.ChecksumCRC32,
	}
	setChecksum(algo, fields, checksum)
}
// getCompletedPartChecksum returns part's checksum for algo, or nil when algo
// is unsupported or the field is unset.
func getCompletedPartChecksum(part types.CompletedPart, algo types.ChecksumAlgorithm) *string {
	fields := checksumFields{
		XXHASH128: &part.ChecksumXXHASH128,
		XXHASH3:   &part.ChecksumXXHASH3,
		XXHASH64:  &part.ChecksumXXHASH64,
		MD5:       &part.ChecksumMD5,
		SHA512:    &part.ChecksumSHA512,
		CRC64NVME: &part.ChecksumCRC64NVME,
		SHA256:    &part.ChecksumSHA256,
		SHA1:      &part.ChecksumSHA1,
		CRC32C:    &part.ChecksumCRC32C,
		CRC32:     &part.ChecksumCRC32,
	}
	return getChecksum(algo, fields)
}
// completedPartFromPart converts part into a types.CompletedPart, copying its
// ETag, part number and every checksum field.
func completedPartFromPart(part types.Part) types.CompletedPart {
	var cp types.CompletedPart
	cp.ETag = part.ETag
	cp.PartNumber = part.PartNumber
	cp.ChecksumCRC32, cp.ChecksumCRC32C = part.ChecksumCRC32, part.ChecksumCRC32C
	cp.ChecksumCRC64NVME = part.ChecksumCRC64NVME
	cp.ChecksumSHA1, cp.ChecksumSHA256, cp.ChecksumSHA512 = part.ChecksumSHA1, part.ChecksumSHA256, part.ChecksumSHA512
	cp.ChecksumMD5 = part.ChecksumMD5
	cp.ChecksumXXHASH64, cp.ChecksumXXHASH3, cp.ChecksumXXHASH128 = part.ChecksumXXHASH64, part.ChecksumXXHASH3, part.ChecksumXXHASH128
	return cp
}
// getPutObjectChecksum returns the checksum for algo reported in a PutObject
// response. It returns nil if out is nil, algo is unsupported, or the response
// carries no such checksum.
func getPutObjectChecksum(out *s3.PutObjectOutput, algo types.ChecksumAlgorithm) *string {
	// Taking &out.Field on a nil pointer panics, so a nil response is guarded
	// in the same way as the other nil-checking checksum accessors.
	if out == nil {
		return nil
	}
	return getChecksum(algo, checksumFields{
		CRC32:     &out.ChecksumCRC32,
		CRC32C:    &out.ChecksumCRC32C,
		SHA1:      &out.ChecksumSHA1,
		SHA256:    &out.ChecksumSHA256,
		CRC64NVME: &out.ChecksumCRC64NVME,
		SHA512:    &out.ChecksumSHA512,
		MD5:       &out.ChecksumMD5,
		XXHASH64:  &out.ChecksumXXHASH64,
		XXHASH3:   &out.ChecksumXXHASH3,
		XXHASH128: &out.ChecksumXXHASH128,
	})
}
// setPutObjectChecksum sets the checksum field of the PutObject input in that
// matches algo. Nothing is written for an unsupported algo.
func setPutObjectChecksum(in *s3.PutObjectInput, algo types.ChecksumAlgorithm, checksum *string) {
	fields := checksumFields{
		XXHASH128: &in.ChecksumXXHASH128,
		XXHASH3:   &in.ChecksumXXHASH3,
		XXHASH64:  &in.ChecksumXXHASH64,
		MD5:       &in.ChecksumMD5,
		SHA512:    &in.ChecksumSHA512,
		CRC64NVME: &in.ChecksumCRC64NVME,
		SHA256:    &in.ChecksumSHA256,
		SHA1:      &in.ChecksumSHA1,
		CRC32C:    &in.ChecksumCRC32C,
		CRC32:     &in.ChecksumCRC32,
	}
	setChecksum(algo, fields, checksum)
}
// getGetObjectChecksum returns the checksum for algo reported in a GetObject
// response. It returns nil if out is nil, algo is unsupported, or the response
// carries no such checksum.
func getGetObjectChecksum(out *s3.GetObjectOutput, algo types.ChecksumAlgorithm) *string {
	// Taking &out.Field on a nil pointer panics, so a nil response is guarded
	// in the same way as the other nil-checking checksum accessors.
	if out == nil {
		return nil
	}
	return getChecksum(algo, checksumFields{
		CRC32:     &out.ChecksumCRC32,
		CRC32C:    &out.ChecksumCRC32C,
		SHA1:      &out.ChecksumSHA1,
		SHA256:    &out.ChecksumSHA256,
		CRC64NVME: &out.ChecksumCRC64NVME,
		SHA512:    &out.ChecksumSHA512,
		MD5:       &out.ChecksumMD5,
		XXHASH64:  &out.ChecksumXXHASH64,
		XXHASH3:   &out.ChecksumXXHASH3,
		XXHASH128: &out.ChecksumXXHASH128,
	})
}
// getHeadObjectChecksum returns the checksum for algo reported in a HeadObject
// response. It returns nil if out is nil, algo is unsupported, or the response
// carries no such checksum.
func getHeadObjectChecksum(out *s3.HeadObjectOutput, algo types.ChecksumAlgorithm) *string {
	// Taking &out.Field on a nil pointer panics, so a nil response is guarded
	// in the same way as the other nil-checking checksum accessors.
	if out == nil {
		return nil
	}
	return getChecksum(algo, checksumFields{
		CRC32:     &out.ChecksumCRC32,
		CRC32C:    &out.ChecksumCRC32C,
		SHA1:      &out.ChecksumSHA1,
		SHA256:    &out.ChecksumSHA256,
		CRC64NVME: &out.ChecksumCRC64NVME,
		SHA512:    &out.ChecksumSHA512,
		MD5:       &out.ChecksumMD5,
		XXHASH64:  &out.ChecksumXXHASH64,
		XXHASH3:   &out.ChecksumXXHASH3,
		XXHASH128: &out.ChecksumXXHASH128,
	})
}
// getObjectAttributesChecksum returns the checksum for algo from the Checksum
// section of a GetObjectAttributes response, or nil when that section is
// absent, algo is unsupported, or the field is unset.
func getObjectAttributesChecksum(out *types.Checksum, algo types.ChecksumAlgorithm) *string {
	if out != nil {
		return getChecksum(algo, checksumFields{
			XXHASH128: &out.ChecksumXXHASH128,
			XXHASH3:   &out.ChecksumXXHASH3,
			XXHASH64:  &out.ChecksumXXHASH64,
			MD5:       &out.ChecksumMD5,
			SHA512:    &out.ChecksumSHA512,
			CRC64NVME: &out.ChecksumCRC64NVME,
			SHA256:    &out.ChecksumSHA256,
			SHA1:      &out.ChecksumSHA1,
			CRC32C:    &out.ChecksumCRC32C,
			CRC32:     &out.ChecksumCRC32,
		})
	}
	return nil
}
// getUploadPartChecksum returns the checksum for algo reported in an
// UploadPart response. It returns nil if out is nil, algo is unsupported, or
// the response carries no such checksum.
func getUploadPartChecksum(out *s3.UploadPartOutput, algo types.ChecksumAlgorithm) *string {
	// Taking &out.Field on a nil pointer panics, so a nil response is guarded
	// in the same way as the other nil-checking checksum accessors.
	if out == nil {
		return nil
	}
	return getChecksum(algo, checksumFields{
		CRC32:     &out.ChecksumCRC32,
		CRC32C:    &out.ChecksumCRC32C,
		SHA1:      &out.ChecksumSHA1,
		SHA256:    &out.ChecksumSHA256,
		CRC64NVME: &out.ChecksumCRC64NVME,
		SHA512:    &out.ChecksumSHA512,
		MD5:       &out.ChecksumMD5,
		XXHASH64:  &out.ChecksumXXHASH64,
		XXHASH3:   &out.ChecksumXXHASH3,
		XXHASH128: &out.ChecksumXXHASH128,
	})
}
// setUploadPartChecksum sets the checksum field of the UploadPart input in
// that matches algo. Nothing is written for an unsupported algo.
func setUploadPartChecksum(in *s3.UploadPartInput, algo types.ChecksumAlgorithm, checksum *string) {
	fields := checksumFields{
		XXHASH128: &in.ChecksumXXHASH128,
		XXHASH3:   &in.ChecksumXXHASH3,
		XXHASH64:  &in.ChecksumXXHASH64,
		MD5:       &in.ChecksumMD5,
		SHA512:    &in.ChecksumSHA512,
		CRC64NVME: &in.ChecksumCRC64NVME,
		SHA256:    &in.ChecksumSHA256,
		SHA1:      &in.ChecksumSHA1,
		CRC32C:    &in.ChecksumCRC32C,
		CRC32:     &in.ChecksumCRC32,
	}
	setChecksum(algo, fields, checksum)
}
// getCompleteMultipartUploadChecksum returns the checksum for algo reported in
// a CompleteMultipartUpload response. It returns nil if out is nil, algo is
// unsupported, or the response carries no such checksum.
func getCompleteMultipartUploadChecksum(out *s3.CompleteMultipartUploadOutput, algo types.ChecksumAlgorithm) *string {
	// Taking &out.Field on a nil pointer panics, so a nil response is guarded
	// in the same way as the other nil-checking checksum accessors.
	if out == nil {
		return nil
	}
	return getChecksum(algo, checksumFields{
		CRC32:     &out.ChecksumCRC32,
		CRC32C:    &out.ChecksumCRC32C,
		SHA1:      &out.ChecksumSHA1,
		SHA256:    &out.ChecksumSHA256,
		CRC64NVME: &out.ChecksumCRC64NVME,
		SHA512:    &out.ChecksumSHA512,
		MD5:       &out.ChecksumMD5,
		XXHASH64:  &out.ChecksumXXHASH64,
		XXHASH3:   &out.ChecksumXXHASH3,
		XXHASH128: &out.ChecksumXXHASH128,
	})
}
// setCompleteMultipartUploadChecksum sets the checksum field of the
// CompleteMultipartUpload input in that matches algo. Nothing is written for
// an unsupported algo.
func setCompleteMultipartUploadChecksum(in *s3.CompleteMultipartUploadInput, algo types.ChecksumAlgorithm, checksum *string) {
	fields := checksumFields{
		XXHASH128: &in.ChecksumXXHASH128,
		XXHASH3:   &in.ChecksumXXHASH3,
		XXHASH64:  &in.ChecksumXXHASH64,
		MD5:       &in.ChecksumMD5,
		SHA512:    &in.ChecksumSHA512,
		CRC64NVME: &in.ChecksumCRC64NVME,
		SHA256:    &in.ChecksumSHA256,
		SHA1:      &in.ChecksumSHA1,
		CRC32C:    &in.ChecksumCRC32C,
		CRC32:     &in.ChecksumCRC32,
	}
	setChecksum(algo, fields, checksum)
}
// getCopyObjectChecksum returns the checksum for algo from a CopyObject
// result, or nil when result is nil, algo is unsupported, or the field is
// unset.
func getCopyObjectChecksum(result *types.CopyObjectResult, algo types.ChecksumAlgorithm) *string {
	if result != nil {
		return getChecksum(algo, checksumFields{
			XXHASH128: &result.ChecksumXXHASH128,
			XXHASH3:   &result.ChecksumXXHASH3,
			XXHASH64:  &result.ChecksumXXHASH64,
			MD5:       &result.ChecksumMD5,
			SHA512:    &result.ChecksumSHA512,
			CRC64NVME: &result.ChecksumCRC64NVME,
			SHA256:    &result.ChecksumSHA256,
			SHA1:      &result.ChecksumSHA1,
			CRC32C:    &result.ChecksumCRC32C,
			CRC32:     &result.ChecksumCRC32,
		})
	}
	return nil
}
// getUploadPartCopyChecksum returns the checksum for algo from an
// UploadPartCopy result, or nil when result is nil, algo is unsupported, or
// the field is unset.
func getUploadPartCopyChecksum(result *types.CopyPartResult, algo types.ChecksumAlgorithm) *string {
	if result != nil {
		return getChecksum(algo, checksumFields{
			XXHASH128: &result.ChecksumXXHASH128,
			XXHASH3:   &result.ChecksumXXHASH3,
			XXHASH64:  &result.ChecksumXXHASH64,
			MD5:       &result.ChecksumMD5,
			SHA512:    &result.ChecksumSHA512,
			CRC64NVME: &result.ChecksumCRC64NVME,
			SHA256:    &result.ChecksumSHA256,
			SHA1:      &result.ChecksumSHA1,
			CRC32C:    &result.ChecksumCRC32C,
			CRC32:     &result.ChecksumCRC32,
		})
	}
	return nil
}
// areMapsSame reports whether mp1 and mp2 hold exactly the same key/value
// pairs.
//
// mp1 should be the map returned by the server and mp2 the expected values.
// Keys are compared verbatim. The server reports metadata keys in lowercase,
// so the expected keys in mp2 must already be lowercase.
func areMapsSame(mp1, mp2 map[string]string) bool {
	if len(mp1) != len(mp2) {
		return false
	}
	for key, want := range mp2 {
		// The comma-ok form keeps a key missing from mp1 from being mistaken
		// for a present key with an empty value.
		got, ok := mp1[key]
		if !ok || got != want {
			return false
		}
	}
	return true
}
// compareBuckets reports whether list1 and list2 hold the same buckets in the
// same order, comparing Name and BucketRegion and printing the first
// mismatch. Both fields are dereferenced, so they must be non-nil.
func compareBuckets(list1 []types.Bucket, list2 []types.Bucket) bool {
	if len(list1) != len(list2) {
		return false
	}
	for i := range list1 {
		a, b := list1[i], list2[i]
		switch {
		case *a.Name != *b.Name:
			fmt.Printf("bucket names are not equal: %s != %s\n", *a.Name, *b.Name)
			return false
		case *a.BucketRegion != *b.BucketRegion:
			fmt.Printf("bucket regions are not equal: %s != %s\n", *a.BucketRegion, *b.BucketRegion)
			return false
		}
	}
	return true
}
// compareObjects reports whether list1 and list2 describe the same objects in
// the same order, printing the first difference it finds.
//
// Key, ETag, Size and StorageClass are always compared; Key, ETag and Size
// must be non-nil. ChecksumAlgorithm, ChecksumType and Owner are only compared
// when set on the list1 entry, and the list2 entry must then carry a matching
// value.
func compareObjects(list1, list2 []types.Object) bool {
	if len(list1) != len(list2) {
		fmt.Println("list lengths are not equal")
		return false
	}
	// ownerID returns the owner's ID, or "" for a nil owner or a nil ID.
	ownerID := func(o *types.Owner) string {
		if o == nil || o.ID == nil {
			return ""
		}
		return *o.ID
	}
	for i, obj := range list1 {
		other := list2[i]
		if *obj.Key != *other.Key {
			fmt.Printf("keys are not equal: %q != %q\n", *obj.Key, *other.Key)
			return false
		}
		if *obj.ETag != *other.ETag {
			fmt.Printf("etags are not equal: (%q %q) %q != %q\n",
				*obj.Key, *other.Key, *obj.ETag, *other.ETag)
			return false
		}
		if *obj.Size != *other.Size {
			fmt.Printf("sizes are not equal: (%q %q) %v != %v\n",
				*obj.Key, *other.Key, *obj.Size, *other.Size)
			return false
		}
		if obj.StorageClass != other.StorageClass {
			fmt.Printf("storage classes are not equal: (%q %q) %v != %v\n",
				*obj.Key, *other.Key, obj.StorageClass, other.StorageClass)
			return false
		}
		if len(obj.ChecksumAlgorithm) != 0 {
			// Guard the index into list2 so that a missing algorithm is
			// reported as a mismatch instead of panicking.
			if len(other.ChecksumAlgorithm) == 0 || obj.ChecksumAlgorithm[0] != other.ChecksumAlgorithm[0] {
				fmt.Printf("checksum algorithms are not equal: (%q %q) %v != %v\n",
					*obj.Key, *other.Key, obj.ChecksumAlgorithm, other.ChecksumAlgorithm)
				return false
			}
		}
		// Compare the whole checksum type. Indexing only its first byte
		// ignores the rest of the value and panics on an empty string.
		if obj.ChecksumType != "" && obj.ChecksumType != other.ChecksumType {
			fmt.Printf("checksum types are not equal: (%q %q) %v != %v\n",
				*obj.Key, *other.Key, obj.ChecksumType, other.ChecksumType)
			return false
		}
		if obj.Owner != nil {
			if other.Owner == nil || ownerID(obj.Owner) != ownerID(other.Owner) {
				fmt.Printf("object owner IDs not equal: (%q %q) %v != %v\n",
					*obj.Key, *other.Key, ownerID(obj.Owner), ownerID(other.Owner))
				// An owner mismatch is a real difference and must fail.
				return false
			}
		}
	}
	return true
}
// comparePrefixes reports whether the common prefixes in list2 match the
// expected prefix strings in list1, in the same order. It prints the first
// mismatch, including a nil Prefix in list2.
func comparePrefixes(list1 []string, list2 []types.CommonPrefix) bool {
	if len(list1) != len(list2) {
		return false
	}
	for i, prefix := range list1 {
		if list2[i].Prefix == nil {
			// Terminate diagnostics with a newline like the other helpers do,
			// so consecutive messages do not run together.
			fmt.Printf("unexpected nil prefix on index %v\n", i)
			return false
		}
		if *list2[i].Prefix != prefix {
			fmt.Printf("prefix mismatch on index %v: expected %s, got %v\n", i, prefix, *list2[i].Prefix)
			return false
		}
	}
	return true
}
// compareDelObjects reports whether list1 and list2 hold the same deleted
// objects in the same order. Key is always compared and must be non-nil.
// VersionId, DeleteMarkerVersionId and DeleteMarker are only compared when set
// on the list1 entry, in which case the list2 entry must carry the same value.
func compareDelObjects(list1, list2 []types.DeletedObject) bool {
	if len(list1) != len(list2) {
		return false
	}
	for i := range list1 {
		a, b := list1[i], list2[i]
		if *a.Key != *b.Key {
			return false
		}
		switch {
		case a.VersionId != nil && (b.VersionId == nil || *a.VersionId != *b.VersionId):
			return false
		case a.DeleteMarkerVersionId != nil && (b.DeleteMarkerVersionId == nil || *a.DeleteMarkerVersionId != *b.DeleteMarkerVersionId):
			return false
		case a.DeleteMarker != nil && (b.DeleteMarker == nil || *a.DeleteMarker != *b.DeleteMarker):
			return false
		}
	}
	return true
}
// uploadParts uploads partCount parts of random data to the multipart upload
// identified by bucket, key and uploadId. It returns the uploaded parts and a
// checksum of all uploaded bytes.
//
// Each part is partSize = size/partCount bytes. NOTE(review): when size is not
// a multiple of partCount, the trailing size%partCount bytes are never
// generated or uploaded, because the clamp against size-1 below can never
// trigger. Callers should pass sizes divisible by partCount; confirm this is
// intended.
//
// opts may set a checksum algorithm (mpCfg.checksumAlgorithm). When one is set,
// each part is sent with its own base64 checksum of that algorithm, and csum is
// the base64 checksum of the concatenated data. Otherwise csum is the
// hex-encoded SHA-256 of the data. If an UploadPart fails, the parts uploaded
// so far are returned together with the error.
func uploadParts(client *s3.Client, size, partCount int64, bucket, key, uploadId string, opts ...mpOpt) (parts []types.Part, csum string, err error) {
	partSize := size / partCount
	var objHasher hash.Hash
	var partHasher hash.Hash
	cfg := new(mpCfg)
	for _, opt := range opts {
		opt(cfg)
	}
	// objHasher covers the data of all parts.
	switch cfg.checksumAlgorithm {
	case types.ChecksumAlgorithmCrc32:
		objHasher = crc32.NewIEEE()
	case types.ChecksumAlgorithmCrc32c:
		objHasher = crc32.New(crc32.MakeTable(crc32.Castagnoli))
	case types.ChecksumAlgorithmMd5:
		objHasher = md5.New()
	case types.ChecksumAlgorithmSha1:
		objHasher = sha1.New()
	case types.ChecksumAlgorithmSha256:
		objHasher = sha256.New()
	case types.ChecksumAlgorithmSha512:
		objHasher = sha512.New()
	case types.ChecksumAlgorithmCrc64nvme:
		objHasher = crc64.New(crc64.MakeTable(bits.Reverse64(0xad93d23594c93659)))
	case types.ChecksumAlgorithmXxhash64:
		objHasher = xxhash.New()
	case types.ChecksumAlgorithmXxhash3:
		objHasher = xxh3.New()
	case types.ChecksumAlgorithmXxhash128:
		objHasher = xxh3.New128()
	default:
		objHasher = sha256.New()
	}
	// partHasher computes the per-part checksum sent with each UploadPart.
	if cfg.checksumAlgorithm != "" {
		partHasher, err = NewHasher(cfg.checksumAlgorithm)
		if err != nil {
			return nil, "", err
		}
	}
	for partNumber := int64(1); partNumber <= partCount; partNumber++ {
		partStart := (partNumber - 1) * partSize
		partEnd := partStart + partSize - 1
		if partEnd > size-1 {
			partEnd = size - 1
		}
		partBuffer := make([]byte, partEnd-partStart+1)
		rand.Read(partBuffer)
		objHasher.Write(partBuffer)
		if partHasher != nil {
			partHasher.Write(partBuffer)
		}
		pn := int32(partNumber)
		input := &s3.UploadPartInput{
			Bucket:            &bucket,
			Key:               &key,
			UploadId:          &uploadId,
			Body:              bytes.NewReader(partBuffer),
			PartNumber:        &pn,
			ChecksumAlgorithm: cfg.checksumAlgorithm,
		}
		if partHasher != nil {
			partChecksum := base64.StdEncoding.EncodeToString(partHasher.Sum(nil))
			setUploadPartChecksum(input, cfg.checksumAlgorithm, &partChecksum)
			// Start the next part's checksum from scratch.
			partHasher.Reset()
		}
		ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
		out, err := client.UploadPart(ctx, input)
		cancel()
		if err != nil {
			return parts, "", err
		}
		// Give every part its own Size value instead of a pointer to the
		// shared partSize variable, so the returned parts do not alias.
		partLen := int64(len(partBuffer))
		part := types.Part{
			ETag:       out.ETag,
			PartNumber: &pn,
			Size:       &partLen,
		}
		setPartChecksum(&part, cfg.checksumAlgorithm, getUploadPartChecksum(out, cfg.checksumAlgorithm))
		parts = append(parts, part)
	}
	sum := objHasher.Sum(nil)
	if cfg.checksumAlgorithm == "" {
		csum = hex.EncodeToString(sum)
	} else {
		csum = base64.StdEncoding.EncodeToString(sum)
	}
	return parts, csum, nil
}
// createUsers creates every account in users through the admin CLI, using the
// root credentials and endpoint from s. It stops at the first failure, which
// is either a command error or output that contains adminErrorPrefix.
func createUsers(s *S3Conf, users []user) error {
	for i := range users {
		u := users[i]
		args := s.getAdminCommand("-a", s.awsID, "-s", s.awsSecret, "-er", s.endpoint, "create-user", "-a", u.access, "-s", u.secret, "-r", u.role)
		out, err := execCommand(args...)
		switch {
		case err != nil:
			return err
		case strings.Contains(string(out), adminErrorPrefix):
			return fmt.Errorf("failed to create user account: %s", out)
		}
	}
	return nil
}
// changeBucketsOwner reassigns each bucket in buckets to owner through the
// admin CLI. It stops at the first failure, which is either a command error or
// output that contains adminErrorPrefix.
func changeBucketsOwner(s *S3Conf, buckets []string, owner string) error {
	for i := range buckets {
		args := s.getAdminCommand("-a", s.awsID, "-s", s.awsSecret, "-er", s.endpoint, "change-bucket-owner", "-b", buckets[i], "-o", owner)
		out, err := execCommand(args...)
		switch {
		case err != nil:
			return err
		case strings.Contains(string(out), adminErrorPrefix):
			return fmt.Errorf("failed to change the bucket owner: %s", out)
		}
	}
	return nil
}
// listBuckets runs the admin list-buckets command and reports only whether it
// succeeded. The listing itself is discarded.
func listBuckets(s *S3Conf) error {
	args := s.getAdminCommand("-a", s.awsID, "-s", s.awsSecret, "-er", s.endpoint, "list-buckets")
	out, err := execCommand(args...)
	switch {
	case err != nil:
		return err
	case strings.Contains(string(out), adminErrorPrefix):
		return fmt.Errorf("failed to list buckets, %s", out)
	}
	return nil
}
// charset is the alphabet genRandString draws from: ASCII letters and digits.
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// genRandString returns a string of length characters, each chosen uniformly
// from charset. It returns "" for length 0 and panics for a negative length.
//
// It uses the package-level math/rand generator, which is seeded automatically
// (Go 1.20+) and safe for concurrent use. A fresh generator seeded from
// time.Now().UnixNano() on every call would return identical strings for
// calls made within the same clock tick.
func genRandString(length int) string {
	result := make([]byte, length)
	for i := range result {
		result[i] = charset[rnd.Intn(len(charset))]
	}
	return string(result)
}
// Indexes of the "/"-separated components of the X-Amz-Credential query
// parameter (access key / date / region / service / terminator), for use as
// the index argument of changeAuthCred. The order must match that layout.
const (
	credAccess int = iota
	credDate
	credRegion
	credService
	credTerminator
)
// changeAuthCred replaces the index-th "/"-separated component of the
// X-Amz-Credential query parameter of uri with newVal and returns the
// re-encoded URL. index is intended to be one of the cred* constants.
//
// It returns an error if uri cannot be parsed or if index is outside the
// credential components present, rather than panicking on the slice access.
func changeAuthCred(uri, newVal string, index int) (string, error) {
	urlParsed, err := url.Parse(uri)
	if err != nil {
		return "", err
	}
	queries := urlParsed.Query()
	creds := strings.Split(queries.Get("X-Amz-Credential"), "/")
	if index < 0 || index >= len(creds) {
		return "", fmt.Errorf("credential component index %d out of range [0, %d)", index, len(creds))
	}
	creds[index] = newVal
	queries.Set("X-Amz-Credential", strings.Join(creds, "/"))
	urlParsed.RawQuery = queries.Encode()
	return urlParsed.String(), nil
}
// genPolicyDoc renders a bucket policy with a single statement. effect is
// quoted by the template. principal, action and resource are inserted
// verbatim, so callers pass them already JSON-encoded (e.g. `"*"`).
func genPolicyDoc(effect, principal, action, resource string) string {
	const jsonTemplate = `{
"Statement": [
{
"Effect": "%s",
"Principal": %s,
"Action": %s,
"Resource": %s
}
]
}
`
	return fmt.Sprintf(jsonTemplate, effect, principal, action, resource)
}
// policyType selects which resources grantPublicBucketPolicy opens to everyone.
type policyType string

const (
	policyTypeBucket policyType = "bucket" // arn:aws:s3:::<bucket>
	policyTypeObject policyType = "object" // arn:aws:s3:::<bucket>/*
	policyTypeFull   policyType = "full"   // both the bucket and its objects
)
// grantPublicBucketPolicy attaches a bucket policy to bucket that allows any
// principal ("*") every S3 action ("s3:*") on the resources selected by tp:
// the bucket ARN (policyTypeBucket), its object wildcard ARN
// (policyTypeObject), or both (policyTypeFull).
//
// An unrecognized tp is reported as an error instead of sending an empty
// policy document to the server.
func grantPublicBucketPolicy(client *s3.Client, bucket string, tp policyType) error {
	var doc string
	switch tp {
	case policyTypeBucket:
		doc = genPolicyDoc("Allow", `"*"`, `"s3:*"`, fmt.Sprintf(`"arn:aws:s3:::%s"`, bucket))
	case policyTypeObject:
		doc = genPolicyDoc("Allow", `"*"`, `"s3:*"`, fmt.Sprintf(`"arn:aws:s3:::%s/*"`, bucket))
	case policyTypeFull:
		template := `{
"Statement": [
{
"Effect": "Allow",
"Principal": "*",
"Action": "s3:*",
"Resource": "arn:aws:s3:::%s"
},
{
"Effect": "Allow",
"Principal": "*",
"Action": "s3:*",
"Resource": "arn:aws:s3:::%s/*"
}
]
}
`
		doc = fmt.Sprintf(template, bucket, bucket)
	default:
		return fmt.Errorf("unsupported policy type %q", tp)
	}
	ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
	_, err := client.PutBucketPolicy(ctx, &s3.PutBucketPolicyInput{
		Bucket: &bucket,
		Policy: &doc,
	})
	cancel()
	return err
}
// getMalformedPolicyError builds an s3err.APIError with code MalformedPolicy,
// HTTP status 400 and msg as its Description, for comparison against errors
// returned by the gateway.
func getMalformedPolicyError(msg string) s3err.APIError {
	var apiErr s3err.APIError
	apiErr.Code = "MalformedPolicy"
	apiErr.Description = msg
	apiErr.HTTPStatusCode = http.StatusBadRequest
	return apiErr
}
// putBucketVersioningStatus sets bucket's versioning status to status and
// returns the PutBucketVersioning error, if any.
func putBucketVersioningStatus(client *s3.Client, bucket string, status types.BucketVersioningStatus) error {
	input := &s3.PutBucketVersioningInput{
		Bucket:                  &bucket,
		VersioningConfiguration: &types.VersioningConfiguration{Status: status},
	}
	ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
	defer cancel()
	_, err := client.PutBucketVersioning(ctx, input)
	return err
}
// checkWORMProtection verifies that object in bucket is write-protected. A
// PutObject, a DeleteObject and a DeleteObjects request targeting it are sent
// in that order, and each must fail with ErrObjectLocked. It returns the first
// mismatch reported by checkApiErr, or nil.
func checkWORMProtection(client *s3.Client, bucket, object string) error {
	attempts := []func(ctx context.Context) error{
		func(ctx context.Context) error {
			_, err := client.PutObject(ctx, &s3.PutObjectInput{
				Bucket: &bucket,
				Key:    &object,
			})
			return err
		},
		func(ctx context.Context) error {
			_, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
				Bucket: &bucket,
				Key:    &object,
			})
			return err
		},
		func(ctx context.Context) error {
			_, err := client.DeleteObjects(ctx, &s3.DeleteObjectsInput{
				Bucket: &bucket,
				Delete: &types.Delete{
					Objects: []types.ObjectIdentifier{{Key: &object}},
				},
			})
			return err
		},
	}
	for _, attempt := range attempts {
		ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
		err := attempt(ctx)
		cancel()
		if mismatch := checkApiErr(err, s3err.GetAPIError(s3err.ErrObjectLocked)); mismatch != nil {
			return mismatch
		}
	}
	return nil
}
// objStrings returns the keys of objs in order. Every Key must be non-nil.
func objStrings(objs []types.Object) []string {
	keys := make([]string, 0, len(objs))
	for _, o := range objs {
		keys = append(keys, *o.Key)
	}
	return keys
}
// pfxStrings returns the Prefix values of pfxs in order. Every Prefix must be
// non-nil.
func pfxStrings(pfxs []types.CommonPrefix) []string {
	prefixes := make([]string, 0, len(pfxs))
	for _, p := range pfxs {
		prefixes = append(prefixes, *p.Prefix)
	}
	return prefixes
}
// versCfg holds the optional settings accepted by createObjVersions.
type versCfg struct {
	// checksumAlgorithm is the checksum algorithm set through withChecksumAlgo.
	checksumAlgorithm types.ChecksumAlgorithm
}

// versOpt is a functional option that mutates a versCfg.
type versOpt func(*versCfg)
// withChecksumAlgo returns a versOpt that sets versCfg.checksumAlgorithm to
// algo.
func withChecksumAlgo(algo types.ChecksumAlgorithm) versOpt {
	return func(cfg *versCfg) {
		cfg.checksumAlgorithm = algo
	}
}
// createObjVersions writes count versions of object into bucket, each with a
// random length in [0, 100000) bytes. It returns the expected ObjectVersion
// entries newest first; only the most recently written version has IsLatest
// set.
//
// A checksum algorithm set with withChecksumAlgo is requested on every
// PutObject. Each version's ChecksumAlgorithm is derived from whichever
// checksum the response carries, checked in the order of the switch below.
func createObjVersions(client *s3.Client, bucket, object string, count int, opts ...versOpt) ([]types.ObjectVersion, error) {
	cfg := new(versCfg)
	for _, o := range opts {
		o(cfg)
	}
	versions := []types.ObjectVersion{}
	for i := range count {
		rNumber, err := rand.Int(rand.Reader, big.NewInt(100000))
		// Check the error before touching rNumber, which is nil on failure.
		if err != nil {
			return nil, err
		}
		dataLength := rNumber.Int64()
		r, err := putObjectWithData(dataLength, &s3.PutObjectInput{
			Bucket: &bucket,
			Key:    &object,
			// Apply the requested algorithm; an empty value leaves the
			// default behavior unchanged.
			ChecksumAlgorithm: cfg.checksumAlgorithm,
		}, client)
		if err != nil {
			return nil, err
		}
		isLatest := i == count-1
		version := types.ObjectVersion{
			ETag:         r.res.ETag,
			IsLatest:     &isLatest,
			Key:          &object,
			Size:         &dataLength,
			VersionId:    r.res.VersionId,
			StorageClass: types.ObjectVersionStorageClassStandard,
			ChecksumType: r.res.ChecksumType,
		}
		switch {
		case r.res.ChecksumCRC32 != nil:
			version.ChecksumAlgorithm = []types.ChecksumAlgorithm{
				types.ChecksumAlgorithmCrc32,
			}
		case r.res.ChecksumCRC32C != nil:
			version.ChecksumAlgorithm = []types.ChecksumAlgorithm{
				types.ChecksumAlgorithmCrc32c,
			}
		case r.res.ChecksumCRC64NVME != nil:
			version.ChecksumAlgorithm = []types.ChecksumAlgorithm{
				types.ChecksumAlgorithmCrc64nvme,
			}
		case r.res.ChecksumSHA1 != nil:
			version.ChecksumAlgorithm = []types.ChecksumAlgorithm{
				types.ChecksumAlgorithmSha1,
			}
		case r.res.ChecksumSHA256 != nil:
			version.ChecksumAlgorithm = []types.ChecksumAlgorithm{
				types.ChecksumAlgorithmSha256,
			}
		case r.res.ChecksumSHA512 != nil:
			version.ChecksumAlgorithm = []types.ChecksumAlgorithm{
				types.ChecksumAlgorithmSha512,
			}
		case r.res.ChecksumMD5 != nil:
			version.ChecksumAlgorithm = []types.ChecksumAlgorithm{
				types.ChecksumAlgorithmMd5,
			}
		case r.res.ChecksumXXHASH64 != nil:
			version.ChecksumAlgorithm = []types.ChecksumAlgorithm{
				types.ChecksumAlgorithmXxhash64,
			}
		case r.res.ChecksumXXHASH3 != nil:
			version.ChecksumAlgorithm = []types.ChecksumAlgorithm{
				types.ChecksumAlgorithmXxhash3,
			}
		case r.res.ChecksumXXHASH128 != nil:
			version.ChecksumAlgorithm = []types.ChecksumAlgorithm{
				types.ChecksumAlgorithmXxhash128,
			}
		}
		versions = append(versions, version)
	}
	// Versions were collected oldest first; return them newest first.
	versions = reverseSlice(versions)
	return versions, nil
}
// reverseSlice reverses s in place and returns it. The result shares s's
// backing array.
func reverseSlice[T any](s []T) []T {
	slices.Reverse(s)
	return s
}
// compareVersions reports whether v1 and v2 list the same object versions in
// the same order.
//
// Key, VersionId, IsLatest, Size and ETag must be set on both sides and equal.
// Two nil pointers count as a mismatch. StorageClass is always compared.
// ChecksumType and the first ChecksumAlgorithm are only compared when set on
// the v1 entry.
func compareVersions(v1, v2 []types.ObjectVersion) bool {
	if len(v1) != len(v2) {
		return false
	}
	for i := range v1 {
		a, b := v1[i], v2[i]
		switch {
		case a.Key == nil || b.Key == nil || *a.Key != *b.Key,
			a.VersionId == nil || b.VersionId == nil || *a.VersionId != *b.VersionId,
			a.IsLatest == nil || b.IsLatest == nil || *a.IsLatest != *b.IsLatest,
			a.Size == nil || b.Size == nil || *a.Size != *b.Size,
			a.ETag == nil || b.ETag == nil || *a.ETag != *b.ETag,
			a.StorageClass != b.StorageClass,
			a.ChecksumType != "" && a.ChecksumType != b.ChecksumType,
			len(a.ChecksumAlgorithm) != 0 &&
				(len(b.ChecksumAlgorithm) == 0 || a.ChecksumAlgorithm[0] != b.ChecksumAlgorithm[0]):
			return false
		}
	}
	return true
}
// compareDelMarkers reports whether d1 and d2 list the same delete markers in
// the same order. Key, IsLatest and VersionId must be set on both sides and
// equal; two nil pointers count as a mismatch.
func compareDelMarkers(d1, d2 []types.DeleteMarkerEntry) bool {
	if len(d1) != len(d2) {
		return false
	}
	for i := range d1 {
		a, b := d1[i], d2[i]
		switch {
		case a.Key == nil || b.Key == nil || *a.Key != *b.Key,
			a.IsLatest == nil || b.IsLatest == nil || *a.IsLatest != *b.IsLatest,
			a.VersionId == nil || b.VersionId == nil || *a.VersionId != *b.VersionId:
			return false
		}
	}
	return true
}
// ObjectMetaProps lists the expected object properties that
// checkObjectMetaProps verifies against a HeadObject response.
// ContentLength is always compared. Each string field is only checked when
// non-empty, and Metadata only when non-nil.
type ObjectMetaProps struct {
	ContentLength      int64
	ContentType        string
	ContentEncoding    string
	ContentDisposition string
	ContentLanguage    string
	CacheControl       string
	ExpiresString      string
	Metadata           map[string]string
}
// checkObjectMetaProps sends HeadObject for object and checks the response
// against o, in this order:
//   - Metadata, when non-nil
//   - Content-Length, always
//   - each non-empty string property
//   - a STANDARD storage class
//
// It returns the HeadObject error or an error describing the first mismatch.
func checkObjectMetaProps(client *s3.Client, bucket, object string, o ObjectMetaProps) error {
	ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
	defer cancel()
	out, err := client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: &bucket,
		Key:    &object,
	})
	if err != nil {
		return err
	}

	if o.Metadata != nil && !areMapsSame(out.Metadata, o.Metadata) {
		return fmt.Errorf("expected the object metadata to be %v, instead got %v", o.Metadata, out.Metadata)
	}

	switch {
	case out.ContentLength == nil:
		return fmt.Errorf("expected Content-Length %v, instead got nil", o.ContentLength)
	case *out.ContentLength != o.ContentLength:
		return fmt.Errorf("expected Content-Length %v, instead got %v", o.ContentLength, *out.ContentLength)
	}

	// Optional string properties: an empty expectation skips the check.
	stringProps := []struct {
		want   string
		got    *string
		format string
	}{
		{o.ContentType, out.ContentType, "expected Content-Type %v, instead got %v"},
		{o.ContentDisposition, out.ContentDisposition, "expected Content-Disposition %v, instead got %v"},
		{o.ContentEncoding, out.ContentEncoding, "expected Content-Encoding %v, instead got %v"},
		{o.ContentLanguage, out.ContentLanguage, "expected Content-Language %v, instead got %v"},
		{o.CacheControl, out.CacheControl, "expected Cache-Control %v, instead got %v"},
		{o.ExpiresString, out.ExpiresString, "expected Expires %v, instead got %v"},
	}
	for _, p := range stringProps {
		if p.want == "" {
			continue
		}
		if got := getString(p.got); got != p.want {
			return fmt.Errorf(p.format, p.want, got)
		}
	}

	if out.StorageClass != types.StorageClassStandard {
		return fmt.Errorf("expected the storage class to be %v, instead got %v", types.StorageClassStandard, out.StorageClass)
	}
	return nil
}
// getBoolPtr returns a pointer to a fresh copy of b.
func getBoolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}
// PublicBucketTestCase describes one request made in public-bucket tests.
// NOTE(review): field semantics below are inferred from names; confirm
// against the callers.
type PublicBucketTestCase struct {
	// Action presumably names the S3 action being exercised.
	Action string
	// Call performs the request using the supplied context.
	Call func(ctx context.Context) error
	// ExpectedErr is the error Call is expected to return (presumably nil
	// when the request should succeed).
	ExpectedErr error
}
// randomizeCase returns s with each rune independently converted to lower or
// upper case with equal probability.
func randomizeCase(s string) string {
	out := make([]rune, 0, len(s))
	for _, r := range s {
		toCase := unicode.ToUpper
		if rnd.Intn(2) == 0 {
			toCase = unicode.ToLower
		}
		out = append(out, toCase(r))
	}
	return string(out)
}
// headObject_zero_len_with_range_helper uploads a zero-length object named obj
// to the bucket provided by actionHandler. It then sends one HeadObject per
// Range value in the table below. Each request must either fail with a
// RequestedRangeNotSatisfiable API error (error code and message are both
// checked) or succeed with Accept-Ranges "bytes" and the listed Content-Length
// and Content-Range.
func headObject_zero_len_with_range_helper(testName, obj string, s *S3Conf) error {
	return actionHandler(s, testName, func(s3client *s3.Client, bucket string) error {
		objLength := int64(0)
		_, err := putObjectWithData(objLength, &s3.PutObjectInput{
			Bucket: &bucket,
			Key:    &obj,
		}, s3client)
		if err != nil {
			return err
		}
		// testRange sends a single HeadObject with Range rg. With expectErr it
		// requires a RequestedRangeNotSatisfiable API error. Otherwise it checks
		// Accept-Ranges, Content-Length (cLength) and Content-Range.
		testRange := func(rg, contentRange string, cLength int64, expectErr bool) error {
			ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
			res, err := s3client.HeadObject(ctx, &s3.HeadObjectInput{
				Bucket: &bucket,
				Key:    &obj,
				Range:  &rg,
			})
			cancel()
			if err == nil && expectErr {
				return fmt.Errorf("%v: expected err 'RequestedRangeNotSatisfiable' error, instead got nil", rg)
			}
			if err != nil {
				if !expectErr {
					return err
				}
				// The failure must be an API error with both the expected code
				// and the expected message; any other error type is rejected.
				var ae smithy.APIError
				if errors.As(err, &ae) {
					if ae.ErrorCode() != "RequestedRangeNotSatisfiable" {
						return fmt.Errorf("%v: expected RequestedRangeNotSatisfiable, instead got %v", rg, ae.ErrorCode())
					}
					if ae.ErrorMessage() != "Requested Range Not Satisfiable" {
						return fmt.Errorf("%v: expected the error message to be 'Requested Range Not Satisfiable', instead got %v", rg, ae.ErrorMessage())
					}
					return nil
				}
				return fmt.Errorf("%v: invalid error got %w", rg, err)
			}
			if getString(res.AcceptRanges) != "bytes" {
				return fmt.Errorf("%v: expected accept ranges to be 'bytes', instead got %v", rg, getString(res.AcceptRanges))
			}
			if res.ContentLength == nil {
				return fmt.Errorf("%v: expected non nil content-length", rg)
			}
			if *res.ContentLength != cLength {
				return fmt.Errorf("%v: expected content-length to be %v, instead got %v", rg, cLength, *res.ContentLength)
			}
			if getString(res.ContentRange) != contentRange {
				return fmt.Errorf("%v: expected content-range to be %v, instead got %v", rg, contentRange, getString(res.ContentRange))
			}
			return nil
		}
		// Reference server expectations for a 0-byte object.
		// Malformed or unsupported ranges succeed with an empty Content-Range;
		// the last four entries must be rejected as unsatisfiable.
		for _, el := range []struct {
			objRange      string
			contentRange  string
			contentLength int64
			expectedErr   bool
		}{
			{"bytes=abc", "", objLength, false},
			{"bytes=a-z", "", objLength, false},
			{"bytes=,", "", objLength, false},
			{"bytes=0-0,1-2", "", objLength, false},
			{"foo=0-1", "", objLength, false},
			{"bytes=--1", "", objLength, false},
			{"bytes=0--1", "", objLength, false},
			{"bytes= -1", "", objLength, false},
			{"bytes=0 -1", "", objLength, false},
			{"bytes=-1", "", objLength, false},    // reference server returns no error, empty Content-Range
			{"bytes=00-01", "", objLength, true},  // RequestedRangeNotSatisfiable
			{"bytes=-0", "", 0, true},
			{"bytes=0-0", "", 0, true},
			{"bytes=0-", "", 0, true},
		} {
			if err := testRange(el.objRange, el.contentRange, el.contentLength, el.expectedErr); err != nil {
				return err
			}
		}
		return nil
	})
}
// getObject_zero_len_with_range_helper uploads a zero-length object named obj
// to the bucket provided by actionHandler. It then sends one GetObject per
// Range value in the table below. Each request must either fail with the
// listed s3err.APIError (compared with checkApiErr) or succeed with the listed
// Content-Length, Accept-Ranges "bytes", Content-Range and body data.
func getObject_zero_len_with_range_helper(testName, obj string, s *S3Conf) error {
	return actionHandler(s, testName, func(s3client *s3.Client, bucket string) error {
		objLength := int64(0)
		res, err := putObjectWithData(objLength, &s3.PutObjectInput{
			Bucket: &bucket,
			Key:    &obj,
		}, s3client)
		if err != nil {
			return err
		}
		// testGetObjectRange sends a single GetObject with Range rng and checks
		// the result against expErr, or against the expected headers and data.
		testGetObjectRange := func(rng, contentRange string, cLength int64, expData []byte, expErr error) error {
			ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
			defer cancel()
			out, err := s3client.GetObject(ctx, &s3.GetObjectInput{
				Bucket: &bucket,
				Key:    &obj,
				Range:  &rng,
			})
			if err != nil {
				if expErr == nil {
					return err
				}
				parsedErr, ok := expErr.(s3err.APIError)
				if !ok {
					return fmt.Errorf("invalid error type provided, expected s3err.APIError")
				}
				return checkApiErr(err, parsedErr)
			}
			// Close the response body on every return path below, including
			// the unexpected-success and early validation failures, so the
			// connection is not leaked.
			defer out.Body.Close()
			if expErr != nil {
				return fmt.Errorf("%v: expected err %v, instead got nil", rng, expErr)
			}
			if out.ContentLength == nil {
				return fmt.Errorf("%v: expected non nil content-length", rng)
			}
			if *out.ContentLength != cLength {
				return fmt.Errorf("%v: expected content-length to be %v, instead got %v", rng, cLength, *out.ContentLength)
			}
			if getString(out.AcceptRanges) != "bytes" {
				return fmt.Errorf("%v: expected accept-ranges to be 'bytes', instead got %v", rng, getString(out.AcceptRanges))
			}
			if getString(out.ContentRange) != contentRange {
				return fmt.Errorf("%v: expected content-range to be %v, instead got %v", rng, contentRange, getString(out.ContentRange))
			}
			data, err := io.ReadAll(out.Body)
			if err != nil {
				return fmt.Errorf("%v: read object data: %w", rng, err)
			}
			if !isSameData(data, expData) {
				return fmt.Errorf("%v: incorrect data retrieved", rng)
			}
			return nil
		}
		for _, el := range []struct {
			rng          string
			contentRange string
			cLength      int64
			expData      []byte
			expErr       error
		}{
			{"bytes=abc", "", objLength, res.data, nil},
			{"bytes=a-z", "", objLength, res.data, nil},
			{"bytes=,", "", objLength, res.data, nil},
			{"bytes=0-0,1-2", "", objLength, res.data, nil},
			{"foo=0-1", "", objLength, res.data, nil},
			{"bytes=--1", "", objLength, res.data, nil},
			{"bytes=0--1", "", objLength, res.data, nil},
			{"bytes= -1", "", objLength, res.data, nil},
			{"bytes=0 -1", "", objLength, res.data, nil},
			{"bytes=-1", "", objLength, res.data, nil},
			// error (RequestedRangeNotSatisfiable)
			{"bytes=00-01", "", objLength, nil, s3err.GetAPIError(s3err.ErrInvalidRange)},
			{"bytes=-0", "", 0, nil, s3err.GetAPIError(s3err.ErrInvalidRange)},
			{"bytes=0-0", "", 0, nil, s3err.GetAPIError(s3err.ErrInvalidRange)},
			{"bytes=0-", "", 0, nil, s3err.GetAPIError(s3err.ErrInvalidRange)},
		} {
			if err := testGetObjectRange(el.rng, el.contentRange, el.cLength, el.expData, el.expErr); err != nil {
				return err
			}
		}
		return nil
	})
}
// getInt32 dereferences an optional int32, yielding 0 when ptr is nil.
func getInt32(ptr *int32) int32 {
	var value int32
	if ptr != nil {
		value = *ptr
	}
	return value
}
// putBucketCors sends a PutBucketCors request bounded by shortTimeout and
// returns only the call's error; the response payload is discarded.
func putBucketCors(client *s3.Client, input *s3.PutBucketCorsInput) error {
	reqCtx, release := context.WithTimeout(context.Background(), shortTimeout)
	defer release()

	_, callErr := client.PutBucketCors(reqCtx, input)
	return callErr
}
// compareCorsConfig checks that got contains the same CORS rules as expected,
// in the same order.
//
// Two nil slices are considered equal. For each rule it compares allowed
// origins, allowed methods, allowed headers, expose headers, max age seconds
// (nil treated as 0) and ID (nil treated as ""). It returns a descriptive
// error for the first mismatch, or nil when all rules match.
func compareCorsConfig(expected, got []types.CORSRule) error {
	if expected == nil && got == nil {
		return nil
	}
	if got == nil {
		return errors.New("nil CORS config")
	}
	if len(expected) != len(got) {
		return fmt.Errorf("expected CORS rules length to be %v, instead got %v", len(expected), len(got))
	}
	for i, r := range expected {
		rule := got[i]
		if !slices.Equal(r.AllowedOrigins, rule.AllowedOrigins) {
			return fmt.Errorf("expected the allowed origins to be %v, instead got %v", r.AllowedOrigins, rule.AllowedOrigins)
		}
		if !slices.Equal(r.AllowedMethods, rule.AllowedMethods) {
			return fmt.Errorf("expected the allowed methods to be %v, instead got %v", r.AllowedMethods, rule.AllowedMethods)
		}
		if !slices.Equal(r.AllowedHeaders, rule.AllowedHeaders) {
			return fmt.Errorf("expected the allowed headers to be %v, instead got %v", r.AllowedHeaders, rule.AllowedHeaders)
		}
		if !slices.Equal(r.ExposeHeaders, rule.ExposeHeaders) {
			// This message previously mislabelled the field as "allowed origins".
			return fmt.Errorf("expected the expose headers to be %v, instead got %v", r.ExposeHeaders, rule.ExposeHeaders)
		}
		if getInt32(r.MaxAgeSeconds) != getInt32(rule.MaxAgeSeconds) {
			return fmt.Errorf("expected the max age seconds to be %v, instead got %v", getInt32(r.MaxAgeSeconds), getInt32(rule.MaxAgeSeconds))
		}
		if getString(r.ID) != getString(rule.ID) {
			return fmt.Errorf("expected ID to be %v, instead got %v", getString(r.ID), getString(rule.ID))
		}
	}
	return nil
}
// PreflightResult holds the outcome of a CORS preflight (OPTIONS) request.
// Either the CORS-related response headers are populated, or err holds the
// decoded S3 error returned by the server (see extractCORSHeaders).
type PreflightResult struct {
	// Value of the Access-Control-Allow-Origin response header.
	Origin string
	// Value of the Access-Control-Allow-Methods response header.
	Methods string
	// Value of the Access-Control-Allow-Headers response header.
	AllowHeaders string
	// Value of the Access-Control-Expose-Headers response header.
	ExposeHeaders string
	// Value of the Access-Control-Max-Age response header.
	MaxAge string
	// Value of the Access-Control-Allow-Credentials response header.
	AllowCredentials string
	// Value of the Vary response header.
	Vary string
	// Decoded error body when the response status was >= 400; nil otherwise.
	err error
}
// extractCORSHeaders converts the response of a CORS preflight (OPTIONS)
// request into a PreflightResult and always closes resp.Body.
//
// When the status code is >= 400, the body is decoded as an S3 XML error and
// stored in the result's err field. Otherwise the Access-Control-* and Vary
// response headers are copied into the result. The returned error is non-nil
// only when the error body cannot be read or unmarshalled.
func extractCORSHeaders(resp *http.Response) (*PreflightResult, error) {
	// Close the body on every path so the underlying connection is released.
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("read response body: %w", err)
		}
		var errResp smithy.GenericAPIError
		if err := xml.Unmarshal(body, &errResp); err != nil {
			return nil, fmt.Errorf("unmarshal response body: %w", err)
		}
		return &PreflightResult{
			err: &errResp,
		}, nil
	}

	return &PreflightResult{
		Origin:           resp.Header.Get("Access-Control-Allow-Origin"),
		Methods:          resp.Header.Get("Access-Control-Allow-Methods"),
		ExposeHeaders:    resp.Header.Get("Access-Control-Expose-Headers"),
		MaxAge:           resp.Header.Get("Access-Control-Max-Age"),
		AllowHeaders:     resp.Header.Get("Access-Control-Allow-Headers"),
		AllowCredentials: resp.Header.Get("Access-Control-Allow-Credentials"),
		Vary:             resp.Header.Get("Vary"),
	}, nil
}
// makeOPTIONSRequest sends an unsigned CORS preflight request for
// "<endpoint>/<bucket>/object". It carries the given Origin,
// Access-Control-Request-Method and Access-Control-Request-Headers values,
// and the response is parsed with extractCORSHeaders.
func makeOPTIONSRequest(s *S3Conf, bucket, origin, method string, headers string) (*PreflightResult, error) {
	target := fmt.Sprintf("%s/%s/object", s.endpoint, bucket)
	preflight, reqErr := http.NewRequest(http.MethodOptions, target, nil)
	if reqErr != nil {
		return nil, fmt.Errorf("create request: %w", reqErr)
	}

	hdr := preflight.Header
	hdr.Add("Origin", origin)
	hdr.Add("Access-Control-Request-Method", method)
	hdr.Add("Access-Control-Request-Headers", headers)

	response, sendErr := s.httpClient.Do(preflight)
	if sendErr != nil {
		return nil, fmt.Errorf("send request: %w", sendErr)
	}
	return extractCORSHeaders(response)
}
// comparePreflightResult checks got against expected.
//
// If expected carries an error, got must carry an error that matches it
// according to checkApiErr; expected.err must be an s3err.APIError. Otherwise
// got must carry no error, and every CORS header field must match exactly.
// The first mismatch is returned as an error.
func comparePreflightResult(expected, got *PreflightResult) error {
	if expected == nil {
		return fmt.Errorf("nil expected preflight request result")
	}
	if got == nil {
		return fmt.Errorf("expected the preflight result to be %v, instead got nil", *expected)
	}
	if expected.err != nil {
		if got.err == nil {
			return fmt.Errorf("expected %w error, instead got nil", expected.err)
		}
		apiErr, ok := expected.err.(s3err.APIError)
		if !ok {
			return fmt.Errorf("expected s3err.APIError, instead got %w", expected.err)
		}
		return checkApiErr(got.err, apiErr)
	}
	if got.err != nil {
		return fmt.Errorf("expected no error, instead got %w", got.err)
	}

	// Compare the header fields in a fixed order so the first mismatch is reported.
	for _, f := range []struct {
		name      string
		want, got string
	}{
		{"origin", expected.Origin, got.Origin},
		{"allowed methods", expected.Methods, got.Methods},
		{"allow headers", expected.AllowHeaders, got.AllowHeaders},
		{"expose headers", expected.ExposeHeaders, got.ExposeHeaders},
		{"max age", expected.MaxAge, got.MaxAge},
		{"allow credentials", expected.AllowCredentials, got.AllowCredentials},
		{"Vary header", expected.Vary, got.Vary},
	} {
		if f.want != f.got {
			return fmt.Errorf("expected the %s to be %v, instead got %v", f.name, f.want, f.got)
		}
	}
	return nil
}
// testOPTIONSEdnpoint performs a CORS preflight request and compares its
// outcome with expected. A transport/parsing failure is returned as is;
// otherwise the result of comparePreflightResult is returned.
func testOPTIONSEdnpoint(s *S3Conf, bucket, origin, method string, headers string, expected *PreflightResult) error {
	actual, err := makeOPTIONSRequest(s, bucket, origin, method, headers)
	if err == nil {
		err = comparePreflightResult(expected, actual)
	}
	return err
}
// calculateEtag returns the double-quoted, lower-case hex MD5 digest of data,
// in the form S3 uses for single-part ETags. The error result is always nil,
// because computing an MD5 sum cannot fail.
func calculateEtag(data []byte) (string, error) {
	digest := md5.Sum(data)
	return `"` + hex.EncodeToString(digest[:]) + `"`, nil
}
// sprintBuckets renders the bucket names as a comma-separated list,
// or "" when there are no buckets. Every Name must be non-nil.
func sprintBuckets(buckets []types.Bucket) string {
	var sb strings.Builder
	for idx := range buckets {
		if idx > 0 {
			sb.WriteByte(',')
		}
		sb.WriteString(*buckets[idx].Name)
	}
	return sb.String()
}
// sprintPrefixes renders the common prefixes as a comma-separated list,
// or "" when there are none. Every Prefix must be non-nil.
func sprintPrefixes(cpfx []types.CommonPrefix) string {
	var sb strings.Builder
	for idx := range cpfx {
		if idx > 0 {
			sb.WriteByte(',')
		}
		sb.WriteString(*cpfx[idx].Prefix)
	}
	return sb.String()
}
// sprintVersions renders the object versions as a comma-separated list of
// "<key>/<versionId>" entries, or "" when objects is empty.
//
// Key and VersionId are optional string pointers. They are dereferenced with
// getString, as done elsewhere in this file. Formatting the VersionId pointer
// directly with %v would print its address rather than the version ID.
// NOTE(review): assumes getString yields "" for nil — confirm.
func sprintVersions(objects []types.ObjectVersion) string {
	names := make([]string, 0, len(objects))
	for _, obj := range objects {
		names = append(names, getString(obj.Key)+"/"+getString(obj.VersionId))
	}
	return strings.Join(names, ",")
}
// objToDelete describes an object (optionally a specific version) whose
// Object Lock protection must be lifted by cleanupLockedObjects before the
// object can be deleted.
type objToDelete struct {
	key                string // Object key (name) in the bucket
	versionId          string // Specific object version ID; passed to the lock APIs via getPtr
	removeLegalHold    bool   // Turn the legal hold off before applying the temporary retention
	removeOnlyLeglHold bool   // Only turn the legal hold off; no retention is applied afterwards
	isCompliance       bool   // Apply the temporary retention in Compliance mode instead of Governance
}
// Worker and timing configuration used when unlocking Object Lock protected
// objects (see cleanupLockedObjects).
const (
	maxDelObjWorkers int64         = 20              // Maximum number of objects processed concurrently (semaphore weight)
	maxRetryAttempts int           = 3               // NOTE(review): not referenced by cleanupLockedObjects in this chunk — confirm whether still used
	lockWaitTime     time.Duration = time.Second * 3 // Length of the temporary retention, and how long each worker sleeps for it to expire
)
// cleanupLockedObjects lifts Object Lock protection (legal hold and/or
// retention) from objects in bucket so they can be deleted afterwards.
//
// Objects are processed concurrently, at most maxDelObjWorkers at a time.
// For each object:
//   - if removeLegalHold or removeOnlyLeglHold is set, its legal hold is
//     turned off; with removeOnlyLeglHold nothing else is done;
//   - otherwise a retention expiring lockWaitTime from now is applied
//     (Governance, or Compliance when isCompliance is set), and the worker
//     sleeps lockWaitTime so that retention has expired on return.
//
// NoSuchKey errors are ignored, since the object is already gone. This
// function does not itself issue delete calls; NOTE(review): presumably the
// caller deletes the objects afterwards — confirm. The first worker error is
// returned.
func cleanupLockedObjects(client *s3.Client, bucket string, objs []objToDelete) error {
	eg, egCtx := errgroup.WithContext(context.Background())
	// Semaphore to limit the number of concurrent workers
	sem := semaphore.NewWeighted(maxDelObjWorkers)

	for _, obj := range objs {
		// Acquire a worker slot before processing an object. egCtx is canceled
		// as soon as a worker fails, which also unblocks this call.
		if err := sem.Acquire(egCtx, 1); err != nil {
			// Prefer the worker error, which is the root cause of the cancellation.
			if werr := eg.Wait(); werr != nil {
				return werr
			}
			return fmt.Errorf("failed to acquire worker space: %w", err)
		}

		eg.Go(func() error {
			// Release the slot when this worker finishes. A defer in the loop body
			// would only run at function return and deadlock once more than
			// maxDelObjWorkers objects are queued.
			defer sem.Release(1)

			// Remove legal hold if required
			if obj.removeLegalHold || obj.removeOnlyLeglHold {
				reqCtx, cancel := context.WithTimeout(context.Background(), shortTimeout)
				_, err := client.PutObjectLegalHold(reqCtx, &s3.PutObjectLegalHoldInput{
					Bucket:    &bucket,
					Key:       &obj.key,
					VersionId: getPtr(obj.versionId),
					LegalHold: &types.ObjectLockLegalHold{
						Status: types.ObjectLockLegalHoldStatusOff, // Disable legal hold
					},
				})
				cancel()
				// If object was already deleted, ignore the error
				if errors.Is(err, s3err.GetAPIError(s3err.ErrNoSuchKey)) {
					return nil
				}
				if err != nil {
					return err
				}
				// If only the legal hold needs to be removed, stop here
				if obj.removeOnlyLeglHold {
					return nil
				}
			}

			// Apply a short retention that expires lockWaitTime from now, replacing
			// any longer one so the object becomes deletable once it expires.
			retDate := time.Now().Add(lockWaitTime)
			mode := types.ObjectLockRetentionModeGovernance
			if obj.isCompliance {
				mode = types.ObjectLockRetentionModeCompliance
			}
			reqCtx, cancel := context.WithTimeout(context.Background(), shortTimeout)
			_, err := client.PutObjectRetention(reqCtx, &s3.PutObjectRetentionInput{
				Bucket:    &bucket,
				Key:       &obj.key,
				VersionId: getPtr(obj.versionId),
				Retention: &types.ObjectLockRetention{
					Mode:            mode,
					RetainUntilDate: &retDate,
				},
			})
			cancel()
			// If object was already deleted, ignore the error
			if errors.Is(err, s3err.GetAPIError(s3err.ErrNoSuchKey)) {
				return nil
			}
			if err != nil {
				return err
			}

			// Wait until the temporary retention expires
			time.Sleep(lockWaitTime)
			return nil
		})
	}

	// Wait for all goroutines to finish, return any error encountered
	return eg.Wait()
}
// objectLockMode selects which Object Lock protection lockObject applies.
type objectLockMode string

// Accepted values for lockObject's mode argument.
// NOTE(review): these are untyped string constants rather than objectLockMode
// values; they convert implicitly when passed to lockObject.
const (
	objectLockModeLegalHold  = "legal-hold" // turn the legal hold on
	objectLockModeGovernance = "governance" // Governance-mode retention (lockObject uses 3 hours)
	objectLockModeCompliance = "compliance" // Compliance-mode retention (lockObject uses 3 hours)
)
// lockObject protects bucket/object (optionally a specific versionId) with
// the requested Object Lock mode. objectLockModeLegalHold turns the legal hold
// on. The governance and compliance modes apply a retention in that mode that
// lasts three hours from now. Any other mode is rejected with an error.
func lockObject(client *s3.Client, mode objectLockMode, bucket, object, versionId string) error {
	reqCtx, release := context.WithTimeout(context.Background(), shortTimeout)
	defer release()

	if mode == objectLockModeLegalHold {
		_, holdErr := client.PutObjectLegalHold(reqCtx, &s3.PutObjectLegalHoldInput{
			Bucket:    &bucket,
			Key:       &object,
			VersionId: getPtr(versionId),
			LegalHold: &types.ObjectLockLegalHold{
				Status: types.ObjectLockLegalHoldStatusOn,
			},
		})
		return holdErr
	}

	var retention types.ObjectLockRetentionMode
	switch mode {
	case objectLockModeCompliance:
		retention = types.ObjectLockRetentionModeCompliance
	case objectLockModeGovernance:
		retention = types.ObjectLockRetentionModeGovernance
	default:
		return fmt.Errorf("invalid object lock mode: %s", mode)
	}

	retainUntil := time.Now().Add(3 * time.Hour)
	_, retErr := client.PutObjectRetention(reqCtx, &s3.PutObjectRetentionInput{
		Bucket:    &bucket,
		Key:       &object,
		VersionId: getPtr(versionId),
		Retention: &types.ObjectLockRetention{
			Mode:            retention,
			RetainUntilDate: &retainUntil,
		},
	})
	return retErr
}
// NewHasher returns a fresh hash.Hash implementing the given S3 checksum
// algorithm. CRC64NVME uses the bit-reversed NVMe polynomial. An unsupported
// algorithm yields an error.
func NewHasher(algo types.ChecksumAlgorithm) (hash.Hash, error) {
	switch algo {
	case types.ChecksumAlgorithmMd5:
		return md5.New(), nil
	case types.ChecksumAlgorithmSha256:
		return sha256.New(), nil
	case types.ChecksumAlgorithmSha512:
		return sha512.New(), nil
	case types.ChecksumAlgorithmSha1:
		return sha1.New(), nil
	case types.ChecksumAlgorithmCrc32:
		return crc32.NewIEEE(), nil
	case types.ChecksumAlgorithmCrc32c:
		return crc32.New(crc32.MakeTable(crc32.Castagnoli)), nil
	case types.ChecksumAlgorithmCrc64nvme:
		return crc64.New(crc64.MakeTable(bits.Reverse64(0xad93d23594c93659))), nil
	case types.ChecksumAlgorithmXxhash64:
		return xxhash.New(), nil
	case types.ChecksumAlgorithmXxhash3:
		return xxh3.New(), nil
	case types.ChecksumAlgorithmXxhash128:
		return xxh3.New128(), nil
	}
	return nil, fmt.Errorf("unsupported hash algorithm: %s", algo)
}
// wrongChecksumForAlgorithm produces a base64 checksum that has the correct
// length for algo but is (almost certainly) wrong: every byte is 0xff.
// Unsupported algorithms yield an error.
func wrongChecksumForAlgorithm(algo types.ChecksumAlgorithm) (string, error) {
	digestLen := map[types.ChecksumAlgorithm]int{
		types.ChecksumAlgorithmCrc32:      4,
		types.ChecksumAlgorithmCrc32c:     4,
		types.ChecksumAlgorithmCrc64nvme:  8,
		types.ChecksumAlgorithmXxhash64:   8,
		types.ChecksumAlgorithmXxhash3:    8,
		types.ChecksumAlgorithmMd5:        16,
		types.ChecksumAlgorithmXxhash128:  16,
		types.ChecksumAlgorithmSha1:       20,
		types.ChecksumAlgorithmSha256:     32,
		types.ChecksumAlgorithmSha512:     64,
	}
	n, known := digestLen[algo]
	if !known {
		return "", fmt.Errorf("unsupported hash algorithm: %s", algo)
	}
	filler := bytes.Repeat([]byte{0xff}, n)
	return base64.StdEncoding.EncodeToString(filler), nil
}
// processCompositeChecksum base64-decodes a part checksum and feeds the raw
// bytes into hasher. This is the step used to build a checksum-of-checksums.
// Decoding or write failures are wrapped and returned.
func processCompositeChecksum(hasher hash.Hash, checksum string) error {
	raw, decodeErr := base64.StdEncoding.DecodeString(checksum)
	if decodeErr != nil {
		return fmt.Errorf("base64 decode: %w", decodeErr)
	}
	if _, writeErr := hasher.Write(raw); writeErr != nil {
		return fmt.Errorf("hash write: %w", writeErr)
	}
	return nil
}
// mpinfo groups a multipart upload ID with the parts completed for it.
// NOTE(review): presumably used to drive CompleteMultipartUpload in tests — confirm with callers.
type mpinfo struct {
	uploadId *string
	parts    []types.CompletedPart
}
// putBucketPolicy attaches the JSON policy document to bucket, bounding the
// call with shortTimeout, and returns the call's error.
func putBucketPolicy(client *s3.Client, bucket, policy string) error {
	reqCtx, release := context.WithTimeout(context.Background(), shortTimeout)
	defer release()

	input := &s3.PutBucketPolicyInput{
		Bucket: &bucket,
		Policy: &policy,
	}
	_, callErr := client.PutBucketPolicy(reqCtx, input)
	return callErr
}
// sendSignedRequest signs req with SigV4, using
// "STREAMING-UNSIGNED-PAYLOAD-TRAILER" as the payload hash, sends it with
// s.httpClient, and interprets the response.
//
// cancel must be the CancelFunc of req's context. It is always invoked before
// this function returns, but only after the response body has been consumed,
// so the body read is not aborted by a canceled context.
//
// On a status code >= 300 the XML error body is decoded and returned as the
// second result. Otherwise the response headers are returned with lower-cased
// keys (first value of each). The error result reports signing, transport,
// read, and decode failures.
func sendSignedRequest(s *S3Conf, req *http.Request, cancel context.CancelFunc) (map[string]string, *APIErrorResponse, error) {
	defer cancel()

	signer := v4.NewSigner()
	signErr := signer.SignHTTP(req.Context(), aws.Credentials{AccessKeyID: s.awsID, SecretAccessKey: s.awsSecret}, req, "STREAMING-UNSIGNED-PAYLOAD-TRAILER", "s3", s.awsRegion, time.Now())
	if signErr != nil {
		return nil, nil, fmt.Errorf("failed to sign the request: %w", signErr)
	}

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to send the request: %w", err)
	}
	// Close on every path (the success path previously leaked the body).
	defer resp.Body.Close()

	if resp.StatusCode >= 300 {
		bodyBytes, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to read the response body: %w", err)
		}
		var errResp APIErrorResponse
		if err := xml.Unmarshal(bodyBytes, &errResp); err != nil {
			return nil, nil, fmt.Errorf("failed to unmarshal response body: %w", err)
		}
		return nil, &errResp, nil
	}

	headers := make(map[string]string, len(resp.Header))
	for key, val := range resp.Header {
		headers[strings.ToLower(key)] = val[0]
	}
	return headers, nil, nil
}
// testUnsignedStreamingPayloadTrailerObjectPut sends a PutObject request for
// bucket/object. The already-encoded body is sent as is, with
// x-amz-content-sha256 set to STREAMING-UNSIGNED-PAYLOAD-TRAILER plus any
// extra reqHeaders. The request is signed and sent via sendSignedRequest,
// which takes over ownership of the context's cancel function.
func testUnsignedStreamingPayloadTrailerObjectPut(s *S3Conf, bucket, object string, body []byte, reqHeaders map[string]string) (map[string]string, *APIErrorResponse, error) {
	target := strings.Join([]string{s.endpoint, bucket, object}, "/")
	reqCtx, release := context.WithTimeout(context.Background(), shortTimeout)

	putReq, buildErr := http.NewRequestWithContext(reqCtx, http.MethodPut, target, bytes.NewReader(body))
	if buildErr != nil {
		release()
		return nil, nil, fmt.Errorf("failed to create a request: %w", buildErr)
	}

	hdr := putReq.Header
	hdr.Add("x-amz-content-sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER")
	for name, value := range reqHeaders {
		hdr.Add(name, value)
	}
	return sendSignedRequest(s, putReq, release)
}
// testUnsignedStreamingPayloadTrailerUploadPart uploads part number 1 of the
// multipart upload uploadId, sending the already-encoded body with
// x-amz-content-sha256 set to STREAMING-UNSIGNED-PAYLOAD-TRAILER plus any
// extra reqHeaders. A nil uploadId is rejected. Signing and sending are
// delegated to sendSignedRequest, which takes over the cancel function.
func testUnsignedStreamingPayloadTrailerUploadPart(s *S3Conf, bucket, object string, uploadId *string, body []byte, reqHeaders map[string]string) (map[string]string, *APIErrorResponse, error) {
	if uploadId == nil {
		return nil, nil, fmt.Errorf("empty upload id")
	}

	target := fmt.Sprintf("%s/%s/%s?uploadId=%s&partNumber=1", s.endpoint, bucket, object, *uploadId)
	reqCtx, release := context.WithTimeout(context.Background(), shortTimeout)

	partReq, buildErr := http.NewRequestWithContext(reqCtx, http.MethodPut, target, bytes.NewReader(body))
	if buildErr != nil {
		release()
		return nil, nil, fmt.Errorf("failed to create a request: %w", buildErr)
	}

	hdr := partReq.Header
	hdr.Add("x-amz-content-sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER")
	for name, value := range reqHeaders {
		hdr.Add(name, value)
	}
	return sendSignedRequest(s, partReq, release)
}
// constructUnsignedPaylod builds an unsigned aws-chunked payload. Each chunk is
// "<hex size>\r\n" followed by that many 'a' bytes and "\r\n". No terminating
// zero-length chunk or trailer is appended. It returns the total decoded
// content length (the sum of chunkSizes) and the encoded payload.
func constructUnsignedPaylod(chunkSizes ...int64) (int64, []byte, error) {
	var decoded int64
	out := bytes.NewBuffer([]byte{})
	for _, size := range chunkSizes {
		decoded += size
		chunk := strings.Repeat("a", int(size))
		if _, err := fmt.Fprintf(out, "%x\r\n%s\r\n", size, chunk); err != nil {
			return 0, nil, err
		}
	}
	return decoded, out.Bytes(), nil
}
// signedReqCfg configures testSignedStreamingObjectPut.
type signedReqCfg struct {
	headers          map[string]string // extra request headers, set after the default ones
	chunkSize        int64             // size of each signed data chunk
	modifFrom        *int              // start (inclusive) of the encoded-body range to overwrite
	modifTo          *int              // end (exclusive) of the encoded-body range to overwrite
	modifPayload     []byte            // bytes that replace body[modifFrom:modifTo]
	trailingChecksum *string           // trailer line appended after the final chunk
	isTrailer        bool              // whether a trailer (and trailer signature) is sent
}

// signedReqOpt mutates a signedReqCfg; options are applied in order.
type signedReqOpt func(*signedReqCfg)

// withCustomHeaders sets additional request headers.
func withCustomHeaders(h map[string]string) signedReqOpt {
	return func(cfg *signedReqCfg) {
		cfg.headers = h
	}
}

// withChunkSize overrides the data chunk size.
func withChunkSize(s int64) signedReqOpt {
	return func(cfg *signedReqCfg) {
		cfg.chunkSize = s
	}
}

// withModifyPayload overwrites encoded body bytes [from, to) with p after
// signing, to exercise corrupted or tampered payloads.
func withModifyPayload(from int, to int, p []byte) signedReqOpt {
	return func(cfg *signedReqCfg) {
		cfg.modifFrom, cfg.modifTo = &from, &to
		cfg.modifPayload = p
	}
}

// withTrailingChecksum enables a signed trailer carrying the given line.
func withTrailingChecksum(checksum string) signedReqOpt {
	return func(cfg *signedReqCfg) {
		cfg.isTrailer = true
		cfg.trailingChecksum = &checksum
	}
}
// testSignedStreamingObjectPut uploads payload to bucket/object using a
// signed streaming request. The request uses
// STREAMING-AWS4-HMAC-SHA256-PAYLOAD, or ...-PAYLOAD-TRAILER when a trailing
// checksum is configured, and every chunk carries its own signature. Options
// can change the chunk size, add headers, overwrite a byte range of the
// encoded body after signing, or append a signed trailer.
//
// It returns the response headers (lower-cased keys, first value each) on
// success, or the decoded XML error when the status code is >= 300. The error
// result reports failures to build, sign, send, read, or decode.
func testSignedStreamingObjectPut(s *S3Conf, bucket, object string, payload []byte, opts ...signedReqOpt) (map[string]string, *APIErrorResponse, error) {
	cfg := &signedReqCfg{
		chunkSize: 8192, // minimal valid chunk size
	}
	for _, opt := range opts {
		opt(cfg)
	}

	ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
	// Cancel only when returning: canceling right after Do would abort the
	// response body read below.
	defer cancel()

	// create a request with no body; the body is attached after signing
	req, err := http.NewRequestWithContext(ctx, http.MethodPut, fmt.Sprintf("%s/%s/%s", s.endpoint, bucket, object), nil)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create a request: %w", err)
	}

	// any planned modification which is going to affect the
	// Content-Length header value
	var payloadOffset int64
	if cfg.modifFrom != nil && cfg.modifTo != nil {
		payloadOffset = int64(len(cfg.modifPayload) - *cfg.modifTo + *cfg.modifFrom)
	}
	var trailerLength int
	if cfg.isTrailer {
		trailerLength = len(*cfg.trailingChecksum)
	}

	// precalculate the Content-Length header to correctly sign the request
	req.ContentLength = calculateSignedReqContentLength(int64(len(payload)), cfg.chunkSize, payloadOffset, cfg.isTrailer, int64(trailerLength))

	sha256Header := "STREAMING-AWS4-HMAC-SHA256-PAYLOAD"
	if cfg.isTrailer {
		sha256Header = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER"
	}
	req.Header.Set("x-amz-decoded-content-length", fmt.Sprint(len(payload)))
	req.Header.Set("x-amz-content-sha256", sha256Header)
	// set custom request headers
	for key, val := range cfg.headers {
		req.Header.Set(key, val)
	}

	creds := aws.Credentials{AccessKeyID: s.awsID, SecretAccessKey: s.awsSecret}
	signer := v4.NewSigner()
	signingTime := time.Now().UTC()
	// sign the request headers; the seed signature chains the chunk signatures
	if err := signer.SignHTTP(ctx, creds, req, sha256Header, "s3", s.awsRegion, signingTime); err != nil {
		return nil, nil, fmt.Errorf("failed to sign the request: %w", err)
	}
	seedSignature, err := extractSignature(req)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to extract seed signature: %w", err)
	}

	// create the signed, chunk-encoded payload
	streamSigner := v4.NewStreamSigner(creds, "s3", s.awsRegion, seedSignature)
	body, err := constructSignedStreamingPayload(ctx, streamSigner, signingTime, payload, cfg.chunkSize, cfg.trailingChecksum, s.awsRegion, s.awsSecret)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to encode req body: %w", err)
	}

	// overwrite body bytes by configuration
	if cfg.modifFrom != nil && cfg.modifTo != nil {
		body, err = replaceRange(body, cfg.modifPayload, *cfg.modifFrom, *cfg.modifTo)
		if err != nil {
			return nil, nil, fmt.Errorf("failed replace body bytes: %w", err)
		}
	}

	// assign req.Body and req.GetBody so the http client can (re)send it
	req.Body = io.NopCloser(bytes.NewReader(body))
	req.GetBody = func() (io.ReadCloser, error) {
		return io.NopCloser(bytes.NewReader(body)), nil
	}

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to send the request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 300 {
		bodyBytes, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to read the response body: %w", err)
		}
		var errResp APIErrorResponse
		if err := xml.Unmarshal(bodyBytes, &errResp); err != nil {
			return nil, nil, fmt.Errorf("failed to unmarshal response body: %w", err)
		}
		return nil, &errResp, nil
	}

	headers := make(map[string]string, len(resp.Header))
	for key, val := range resp.Header {
		headers[strings.ToLower(key)] = val[0]
	}
	return headers, nil, nil
}
// cancelAndError invokes cancel and hands err back unchanged, letting an
// early-return path release its context in a single expression.
func cancelAndError(err error, cancel context.CancelFunc) error {
	defer cancel()
	return err
}
// Byte lengths of fixed parts of an aws-chunked (signed streaming) body.
const (
	// len(";chunk-signature=") (17) + 64 hex characters of signature.
	chunkSigHdrLength int64 = 81
	// len("x-amz-trailer-signature:") (24) + 64 hex characters of signature.
	trailerSigLength int64 = 88
)

// calculateSignedReqContentLength calculates the value of the `Content-Length`
// header for a signed streaming body carrying decPayloadSize decoded bytes,
// split into chunks of chunkSize bytes.
//
// Framing costs:
//   - each data chunk: hex(size) + chunkSigHdrLength + 4 bytes (a "\r\n" after
//     the header and one after the data);
//   - the terminating zero-length chunk: chunkSigHdrLength + 5 bytes;
//   - withTrailer: trailerLength bytes of trailer, plus the trailer-signature
//     line and the surrounding line breaks.
//
// sizeOffset accounts for planned changes to the body that alter its size.
// A chunkSize that is non-positive or larger than the payload is treated as a
// single data chunk, which also guarantees the loop terminates. An empty
// payload consists of the terminating chunk only.
func calculateSignedReqContentLength(decPayloadSize int64, chunkSize int64, sizeOffset int64, withTrailer bool, trailerLength int64) int64 {
	var framing int64
	if withTrailer {
		framing += trailerLength + 4 + trailerSigLength
	}
	if chunkSize <= 0 || chunkSize > decPayloadSize {
		chunkSize = decPayloadSize
	}

	for remaining := decPayloadSize; ; {
		if remaining == 0 {
			// terminating "0;chunk-signature=<sig>\r\n\r\n"
			framing += chunkSigHdrLength + 5
			break
		}
		if remaining < chunkSize {
			// last partial data chunk plus the terminating chunk
			framing += 2*chunkSigHdrLength + 9 + int64(len(fmt.Sprintf("%x", remaining)))
			break
		}
		framing += int64(len(fmt.Sprintf("%x", chunkSize))) + chunkSigHdrLength + 4
		remaining -= chunkSize
	}
	return framing + decPayloadSize + sizeOffset
}
// constructSignedStreamingPayload creates an aws-chunked payload in which every
// chunk carries a chunk-signature produced by signer.
//
// The payload is split into chunks of at most chunkSize bytes. Each chunk is
// written as "<hex size>;chunk-signature=<sig>\r\n<data>\r\n" and is followed
// by a final zero-length chunk. When trailer is non-nil, the final chunk is
// followed by the trailer line and an x-amz-trailer-signature computed from
// secret, region and signingTime.
//
// An error is returned when chunkSize is not positive while payload is
// non-empty (the chunking loop could otherwise never advance), or when any
// signing or buffer write fails.
func constructSignedStreamingPayload(ctx context.Context, signer *v4.StreamSigner, signingTime time.Time, payload []byte, chunkSize int64, trailer *string, region, secret string) ([]byte, error) {
	payloadLen := int64(len(payload))
	if payloadLen > 0 && chunkSize <= 0 {
		return nil, fmt.Errorf("invalid chunk size: %d", chunkSize)
	}

	buf := bytes.NewBuffer(nil)
	for off := int64(0); off < payloadLen; {
		n := min(chunkSize, payloadLen-off)
		chunk := payload[off : off+n]
		sig, err := signer.GetSignature(ctx, nil, chunk, signingTime)
		if err != nil {
			return nil, err
		}
		if _, err := fmt.Fprintf(buf, "%x;chunk-signature=%x\r\n%s\r\n", n, sig, chunk); err != nil {
			return nil, err
		}
		off += n
	}

	// signature of the terminating zero-length chunk
	sig, err := signer.GetSignature(ctx, nil, nil, signingTime)
	if err != nil {
		return nil, err
	}

	if trailer == nil {
		if _, err := fmt.Fprintf(buf, "0;chunk-signature=%x\r\n\r\n", sig); err != nil {
			return nil, err
		}
		return buf.Bytes(), nil
	}

	if _, err := fmt.Fprintf(buf, "0;chunk-signature=%x\r\n", sig); err != nil {
		return nil, err
	}
	// The trailer signature chains off the last chunk signature.
	sigKey := getSigningKey(secret, signingTime.Format("20060102"), region)
	trailerSig, err := getAWS4StreamingTrailer(sigKey, sig, signingTime, region, *trailer)
	if err != nil {
		return nil, err
	}
	if _, err := fmt.Fprintf(buf, "%s\r\nx-amz-trailer-signature:%s\r\n\r\n", *trailer, trailerSig); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// extractSignature hex-decodes the text that follows the first "Signature="
// in req's Authorization header. It fails when the marker is absent or the
// remainder is not valid hex.
func extractSignature(req *http.Request) ([]byte, error) {
	const marker = "Signature="
	auth := req.Header.Get("Authorization")
	pos := strings.Index(auth, marker)
	if pos < 0 {
		return nil, errors.New("signature not found")
	}
	return hex.DecodeString(auth[pos+len(marker):])
}
// replaceRange replaces dst[start:end] with src and returns the resulting
// slice. It is used for custom overwrites of request payload bytes.
//
// If dst has enough capacity the result reuses (and overwrites) dst's backing
// array; otherwise a new slice is allocated. It returns an error when
// start/end do not describe a valid range of dst.
func replaceRange(dst, src []byte, start, end int) ([]byte, error) {
	if start < 0 || end < start || end > len(dst) {
		return nil, fmt.Errorf("invalid start/end indexes")
	}
	oldLen := len(dst)
	newLen := oldLen - (end - start) + len(src)

	// Fast path: reuse dst capacity if possible
	if cap(dst) >= newLen {
		// Take the tail using the ORIGINAL length before reslicing. When the
		// result is shorter, dst[:newLen] would otherwise cut part of the tail off.
		tail := dst[end:oldLen]
		dst = dst[:newLen]
		// copy is overlap-safe (memmove semantics).
		copy(dst[start+len(src):], tail)
		copy(dst[start:], src)
		return dst, nil
	}

	// Fallback: allocate new slice
	out := make([]byte, newLen)
	copy(out, dst[:start])
	copy(out[start:], src)
	copy(out[start+len(src):], dst[end:])
	return out, nil
}
func getAWS4StreamingTrailer(
signingKey,
lastSignature []byte,
signingTime time.Time,
awsRegion,
trailer string,
) (string, error) {
// yyyyMMdd
yearMonthDay := signingTime.Format("20060102")
// ISO8601 basic format: yyyyMMdd'T'HHmmss'Z'
currentDateTime := signingTime.Format("20060102T150405Z")
// <date>/<region>/<service>/aws4_request
serviceString := fmt.Sprintf(
"%s/%s/s3/aws4_request",
yearMonthDay,
awsRegion,
)
// Trailer must be newline-terminated for hashing/signing
trailerWithNL := trailer + "\n"
// Hash of trailer
trailerHash := sha256.Sum256([]byte(trailerWithNL))
trailerHashHex := hex.EncodeToString(trailerHash[:])
// String-to-sign prefix
stringToSignPrefix := fmt.Sprintf(
"%s\n%s\n%s",
"AWS4-HMAC-SHA256-TRAILER",
currentDateTime,
serviceString,
)
// Full string-to-sign
stringToSign := fmt.Sprintf(
"%s\n%x\n%s",
stringToSignPrefix,
lastSignature,
trailerHashHex,
)
// Final trailer signature
finalSignature := hex.EncodeToString(
hmacSHA256(signingKey, stringToSign),
)
return finalSignature, nil
}
// hmacSHA256 returns HMAC-SHA256(key, data).
func hmacSHA256(key []byte, data string) []byte {
	mac := hmac.New(sha256.New, key)
	// hash.Hash.Write never returns an error.
	_, _ = mac.Write([]byte(data))
	return mac.Sum(nil)
}
// getSigningKey derives the SigV4 signing key for the s3 service. It starts
// from "AWS4"+secret and successively HMACs the date (yyyyMMdd), the region,
// "s3" and "aws4_request".
func getSigningKey(secret, yearMonthDay, region string) []byte {
	key := []byte("AWS4" + secret)
	for _, component := range []string{yearMonthDay, region, "s3", "aws4_request"} {
		key = hmacSHA256(key, component)
	}
	return key
}
// buildPostObjectBody builds a multipart/form-data body for a POST object request.
//
// Every entry of fields is written as a form field, with two exceptions for
// keys that are also present in extraFields:
//   - the extraFields value is used instead;
//   - an empty extraFields value drops that field entirely.
//
// The remaining extraFields entries (keys not in fields) are written next.
// Finally comes the file part (field "file", filename "upload.bin")
// containing fileContent. Within each group, fields are written in sorted key
// order so the body is deterministic. Neither input map is modified.
//
// Returns (body bytes, boundary string, error).
func buildPostObjectBody(fields, extraFields map[string]string, fileContent []byte) ([]byte, string, error) {
	var buf bytes.Buffer
	w := multipart.NewWriter(&buf)

	sortedKeys := func(m map[string]string, skip map[string]string) []string {
		keys := make([]string, 0, len(m))
		for k := range m {
			if _, found := skip[k]; !found {
				keys = append(keys, k)
			}
		}
		slices.Sort(keys)
		return keys
	}

	for _, key := range sortedKeys(fields, nil) {
		value := fields[key]
		if override, ok := extraFields[key]; ok {
			if override == "" {
				continue // an empty override removes the field
			}
			value = override
		}
		if err := w.WriteField(key, value); err != nil {
			return nil, "", err
		}
	}

	// Extra fields whose keys were not already handled above.
	for _, key := range sortedKeys(extraFields, fields) {
		if err := w.WriteField(key, extraFields[key]); err != nil {
			return nil, "", err
		}
	}

	fw, err := w.CreateFormFile("file", "upload.bin")
	if err != nil {
		return nil, "", err
	}
	if _, err = fw.Write(fileContent); err != nil {
		return nil, "", err
	}
	if err := w.Close(); err != nil {
		return nil, "", err
	}
	return buf.Bytes(), w.Boundary(), nil
}
// encodePostPolicy builds a POST policy JSON document and returns it
// base64-encoded.
//
// The policy's "conditions" are the given conditions followed by one
// {key: value} condition per entry of fields, in sorted key order, skipping
// keys present in omitConditions. The caller's conditions slice is never
// written to. A zero expiration defaults to 15 minutes from now. The
// expiration is formatted as RFC 3339 in UTC.
func encodePostPolicy(conditions []any, expiration time.Time, fields map[string]string, omitConditions map[string]struct{}) (string, error) {
	keys := make([]string, 0, len(fields))
	for key := range fields {
		if _, omit := omitConditions[key]; !omit {
			keys = append(keys, key)
		}
	}
	slices.Sort(keys)

	all := conditions
	if len(keys) > 0 {
		// Copy into a fresh slice: appending to conditions directly could write
		// into spare capacity of the caller's backing array.
		all = make([]any, 0, len(conditions)+len(keys))
		all = append(all, conditions...)
		for _, key := range keys {
			all = append(all, map[string]string{
				key: fields[key],
			})
		}
	}

	if expiration.IsZero() {
		expiration = time.Now().UTC().Add(15 * time.Minute)
	}
	policy := map[string]any{
		"expiration": expiration.UTC().Format(time.RFC3339),
		"conditions": all,
	}
	b, err := json.Marshal(policy)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(b), nil
}
// signPostPolicy computes the AWS SigV4 HMAC-SHA256 signature over a base64-encoded
// POST policy document.
func signPostPolicy(policyB64, dateShort, region, secret string) string {
signingKey := getSigningKey(secret, dateShort, region)
sig := hmacSHA256(signingKey, policyB64)
return hex.EncodeToString(sig)
}
// buildSignedPostFields returns the base form fields for a signed POST object
// request:
//   - bucket
//   - key
//   - x-amz-algorithm (always AWS4-HMAC-SHA256)
//   - x-amz-credential (<access>/<yyyyMMdd>/<region>/s3/aws4_request)
//   - x-amz-date (date formatted with iso8601Format)
//
// The policy and signature fields are not included.
func buildSignedPostFields(bucket, key, access, region string, date time.Time) map[string]string {
	scope := strings.Join([]string{access, date.Format("20060102"), region, "s3", "aws4_request"}, "/")

	out := make(map[string]string, 5)
	out["bucket"] = bucket
	out["key"] = key
	out["x-amz-algorithm"] = "AWS4-HMAC-SHA256"
	out["x-amz-credential"] = scope
	out["x-amz-date"] = date.Format(iso8601Format)
	return out
}
// PostRequestConfig describes a browser-style POST object request built by
// newPostObjectRequest. Zero-valued date, policyExpiration, access, secret and
// region are filled in with defaults: the current time, one day from now, and
// the s3Conf credentials and region, respectively.
type PostRequestConfig struct {
	bucket               string              // target bucket; also sent as the "bucket" form field
	key                  string              // object key form field
	access               string              // access key ID used in x-amz-credential
	secret               string              // secret key used to sign the policy
	region               string              // signing region
	extraFields          map[string]string   // additional/overriding form fields; an empty value drops a default field
	fileContent          []byte              // contents of the "file" part
	rawPolicy            *string             // when set, used verbatim as the policy instead of an encoded one
	date                 time.Time           // signing date (x-amz-date and credential scope)
	policyExpiration     time.Time           // policy "expiration"
	policyConditions     []any               // extra policy conditions, placed before the per-field conditions
	omitPolicyConditions map[string]struct{} // form fields to leave out of the generated policy conditions
	s3Conf               *S3Conf             // endpoint, HTTP client, default credentials and addressing style
}
// sendPostObject builds a POST multipart/form-data request for the bucket
// (see newPostObjectRequest) and sends it with the configured HTTP client.
// The raw *http.Response is returned for flexible per-test assertions; the
// caller owns and must close its body.
func sendPostObject(input PostRequestConfig) (*http.Response, error) {
	postReq, _, buildErr := newPostObjectRequest(input)
	if buildErr != nil {
		return nil, buildErr
	}
	client := input.s3Conf.httpClient
	return client.Do(postReq)
}
// newPostObjectRequest builds, but does not send, a signed POST object
// request. Unset date, expiration, credentials and region are defaulted first
// (see PostRequestConfig). The policy is either rawPolicy verbatim or one
// encoded from the signed fields, and it is signed with the secret. The
// multipart body is built from those fields plus extraFields and fileContent.
// The target is path-style ("<endpoint>/<bucket>") or, when s3Conf.hostStyle
// is set, virtual-host style ("<bucket>.<host>").
// It returns the request and the form fields that were signed.
func newPostObjectRequest(input PostRequestConfig) (*http.Request, map[string]string, error) {
	conf := input.s3Conf

	// Default anything the caller left unset.
	if input.date.IsZero() {
		input.date = time.Now().UTC()
	}
	if input.policyExpiration.IsZero() {
		input.policyExpiration = time.Now().UTC().AddDate(0, 0, 1)
	}
	if input.access == "" {
		input.access = conf.awsID
	}
	if input.secret == "" {
		input.secret = conf.awsSecret
	}
	if input.region == "" {
		input.region = conf.awsRegion
	}

	fields := buildSignedPostFields(input.bucket, input.key, input.access, input.region, input.date)

	var policy string
	if input.rawPolicy != nil {
		policy = *input.rawPolicy
	} else {
		encoded, encErr := encodePostPolicy(input.policyConditions, input.policyExpiration, fields, input.omitPolicyConditions)
		if encErr != nil {
			return nil, nil, encErr
		}
		policy = encoded
	}
	fields["policy"] = policy
	fields["x-amz-signature"] = signPostPolicy(policy, input.date.Format("20060102"), input.region, input.secret)

	payload, boundary, bodyErr := buildPostObjectBody(fields, input.extraFields, input.fileContent)
	if bodyErr != nil {
		return nil, nil, bodyErr
	}

	target := conf.endpoint + "/" + input.bucket
	if conf.hostStyle {
		parsed, parseErr := url.Parse(conf.endpoint)
		if parseErr != nil {
			return nil, nil, parseErr
		}
		parsed.Host = input.bucket + "." + parsed.Host
		target = parsed.String()
	}

	postReq, reqErr := http.NewRequest(http.MethodPost, target, bytes.NewReader(payload))
	if reqErr != nil {
		return nil, nil, reqErr
	}
	postReq.Header.Set("Content-Type", "multipart/form-data; boundary="+boundary)
	postReq.ContentLength = int64(len(payload))
	return postReq, fields, nil
}
// getEtagBytes strips every double quote from etag and hex-decodes the rest,
// e.g. `"0aff"` -> {0x0a, 0xff}.
func getEtagBytes(etag string) ([]byte, error) {
	unquoted := strings.ReplaceAll(etag, `"`, "")
	return hex.DecodeString(unquoted)
}
// md5String returns the lower-case hex MD5 digest of data.
func md5String(data []byte) string {
	return fmt.Sprintf("%x", md5.Sum(data))
}
// base64ToHexString converts a standard-base64 string to lower-case hex,
// returning "" when s is not valid base64.
func base64ToHexString(s string) string {
	if raw, err := base64.StdEncoding.DecodeString(s); err == nil {
		return hex.EncodeToString(raw)
	}
	return ""
}
// hexBytes renders each byte of s as two lower-case hex digits, separated by
// single spaces, e.g. "ab" -> "61 62".
func hexBytes(s string) string {
	var sb strings.Builder
	for i := 0; i < len(s); i++ {
		if i > 0 {
			sb.WriteByte(' ')
		}
		fmt.Fprintf(&sb, "%02x", s[i])
	}
	return sb.String()
}