Files
versitygw/s3api/utils/utils_test.go
niksis02 ebdda06633 fix: adds BadDigest error for incorrect Content-Md5 s
Closes #1525

* Adds validation for the `Content-MD5` header.
  * If the header value is invalid, the gateway now returns an `InvalidDigest` error.
  * If the value is valid but does not match the payload, it returns a `BadDigest` error.
* Adds integration test cases for `PutBucketCors` with `Content-MD5`.
2025-09-19 19:51:23 +04:00

950 lines
21 KiB
Go

// Copyright 2023 Versity Software
// This file is licensed under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package utils
import (
"bytes"
"encoding/xml"
"errors"
"math/rand"
"net/http"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp"
"github.com/versity/versitygw/backend"
"github.com/versity/versitygw/s3err"
"github.com/versity/versitygw/s3response"
)
func TestCreateHttpRequestFromCtx(t *testing.T) {
type args struct {
ctx *fiber.Ctx
}
app := fiber.New()
// Expected output, Case 1
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
req := ctx.Request()
request, _ := http.NewRequest(string(req.Header.Method()), req.URI().String(), bytes.NewReader(req.Body()))
// Case 2
ctx2 := app.AcquireCtx(&fasthttp.RequestCtx{})
req2 := ctx2.Request()
req2.Header.Add("X-Amz-Mfa", "Some valid Mfa")
request2, _ := http.NewRequest(string(req2.Header.Method()), req2.URI().String(), bytes.NewReader(req2.Body()))
request2.Header.Add("X-Amz-Mfa", "Some valid Mfa")
tests := []struct {
name string
args args
want *http.Request
wantErr bool
hdrs []string
}{
{
name: "Success-response",
args: args{
ctx: ctx,
},
want: request,
wantErr: false,
hdrs: []string{},
},
{
name: "Success-response-With-Headers",
args: args{
ctx: ctx2,
},
want: request2,
wantErr: false,
hdrs: []string{"X-Amz-Mfa"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := createHttpRequestFromCtx(tt.args.ctx, tt.hdrs, 0, true)
if (err != nil) != tt.wantErr {
t.Errorf("CreateHttpRequestFromCtx() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got.Header, tt.want.Header) {
t.Errorf("CreateHttpRequestFromCtx() got = %v, want %v", got, tt.want)
}
})
}
}
func TestGetUserMetaData(t *testing.T) {
type args struct {
headers *fasthttp.RequestHeader
}
app := fiber.New()
// Case 1
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
req := ctx.Request()
tests := []struct {
name string
args args
wantMetadata map[string]string
}{
{
name: "Success-empty-response",
args: args{
headers: &req.Header,
},
wantMetadata: map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotMetadata := GetUserMetaData(tt.args.headers); !reflect.DeepEqual(gotMetadata, tt.wantMetadata) {
t.Errorf("GetUserMetaData() = %v, want %v", gotMetadata, tt.wantMetadata)
}
})
}
}
func Test_includeHeader(t *testing.T) {
type args struct {
hdr string
signedHdrs []string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "include-header-falsy-case",
args: args{
hdr: "Content-Type",
signedHdrs: []string{"X-Amz-Acl", "Content-Encoding"},
},
want: false,
},
{
name: "include-header-falsy-case",
args: args{
hdr: "Content-Type",
signedHdrs: []string{"X-Amz-Acl", "Content-Type"},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := includeHeader(tt.args.hdr, tt.args.signedHdrs); got != tt.want {
t.Errorf("includeHeader() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsValidBucketName(t *testing.T) {
type args struct {
bucket string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "IsValidBucketName-short-name",
args: args{
bucket: "a",
},
want: false,
},
{
name: "IsValidBucketName-start-with-hyphen",
args: args{
bucket: "-bucket",
},
want: false,
},
{
name: "IsValidBucketName-start-with-dot",
args: args{
bucket: ".bucket",
},
want: false,
},
{
name: "IsValidBucketName-contain-invalid-character",
args: args{
bucket: "my@bucket",
},
want: false,
},
{
name: "IsValidBucketName-end-with-hyphen",
args: args{
bucket: "bucket-",
},
want: false,
},
{
name: "IsValidBucketName-end-with-dot",
args: args{
bucket: "bucket.",
},
want: false,
},
{
name: "IsValidBucketName-valid-bucket-name",
args: args{
bucket: "my-bucket",
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsValidBucketName(tt.args.bucket); got != tt.want {
t.Errorf("IsValidBucketName() = %v, want %v", got, tt.want)
}
})
}
}
func TestParseUint(t *testing.T) {
type args struct {
str string
}
tests := []struct {
name string
args args
want int32
wantErr bool
}{
{
name: "Parse-uint-empty-string",
args: args{
str: "",
},
want: 1000,
wantErr: false,
},
{
name: "Parse-uint-invalid-number-string",
args: args{
str: "bla",
},
want: 1000,
wantErr: true,
},
{
name: "Parse-uint-invalid-negative-number",
args: args{
str: "-5",
},
want: 1000,
wantErr: true,
},
{
name: "Parse-uint-success",
args: args{
str: "23",
},
want: 23,
wantErr: false,
},
{
name: "Parse-uint-greater-than-1000",
args: args{
str: "25000000",
},
want: 1000,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseUint(tt.args.str)
if (err != nil) != tt.wantErr {
t.Errorf("ParseMaxKeys() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ParseMaxKeys() = %v, want %v", got, tt.want)
}
})
}
}
func TestFilterObjectAttributes(t *testing.T) {
type args struct {
attrs map[s3response.ObjectAttributes]struct{}
output s3response.GetObjectAttributesResponse
}
etag, objSize := "etag", int64(3222)
delMarker := true
tests := []struct {
name string
args args
want s3response.GetObjectAttributesResponse
}{
{
name: "keep only ETag",
args: args{
attrs: map[s3response.ObjectAttributes]struct{}{
s3response.ObjectAttributesEtag: {},
},
output: s3response.GetObjectAttributesResponse{
ObjectSize: &objSize,
ETag: &etag,
},
},
want: s3response.GetObjectAttributesResponse{ETag: &etag},
},
{
name: "keep multiple props",
args: args{
attrs: map[s3response.ObjectAttributes]struct{}{
s3response.ObjectAttributesEtag: {},
s3response.ObjectAttributesObjectSize: {},
s3response.ObjectAttributesStorageClass: {},
},
output: s3response.GetObjectAttributesResponse{
ObjectSize: &objSize,
ETag: &etag,
ObjectParts: &s3response.ObjectParts{},
VersionId: &etag,
},
},
want: s3response.GetObjectAttributesResponse{
ETag: &etag,
ObjectSize: &objSize,
},
},
{
name: "make sure LastModified, DeleteMarker and VersionId are removed",
args: args{
attrs: map[s3response.ObjectAttributes]struct{}{
s3response.ObjectAttributesEtag: {},
},
output: s3response.GetObjectAttributesResponse{
ObjectSize: &objSize,
ETag: &etag,
ObjectParts: &s3response.ObjectParts{},
VersionId: &etag,
LastModified: backend.GetTimePtr(time.Now()),
DeleteMarker: &delMarker,
},
},
want: s3response.GetObjectAttributesResponse{
ETag: &etag,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := FilterObjectAttributes(tt.args.attrs, tt.args.output); !reflect.DeepEqual(got, tt.want) {
t.Errorf("FilterObjectAttributes() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsValidOwnership(t *testing.T) {
type args struct {
val types.ObjectOwnership
}
tests := []struct {
name string
args args
want bool
}{
{
name: "valid-BucketOwnerEnforced",
args: args{
val: types.ObjectOwnershipBucketOwnerEnforced,
},
want: true,
},
{
name: "valid-BucketOwnerPreferred",
args: args{
val: types.ObjectOwnershipBucketOwnerPreferred,
},
want: true,
},
{
name: "valid-ObjectWriter",
args: args{
val: types.ObjectOwnershipObjectWriter,
},
want: true,
},
{
name: "invalid_value",
args: args{
val: types.ObjectOwnership("invalid_value"),
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsValidOwnership(tt.args.val); got != tt.want {
t.Errorf("IsValidOwnership() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsChecksumAlgorithmValid(t *testing.T) {
type args struct {
alg types.ChecksumAlgorithm
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "empty",
args: args{
alg: "",
},
wantErr: false,
},
{
name: "crc32",
args: args{
alg: types.ChecksumAlgorithmCrc32,
},
wantErr: false,
},
{
name: "crc32c",
args: args{
alg: types.ChecksumAlgorithmCrc32c,
},
wantErr: false,
},
{
name: "sha1",
args: args{
alg: types.ChecksumAlgorithmSha1,
},
wantErr: false,
},
{
name: "sha256",
args: args{
alg: types.ChecksumAlgorithmSha256,
},
wantErr: false,
},
{
name: "crc64nvme",
args: args{
alg: types.ChecksumAlgorithmCrc64nvme,
},
wantErr: false,
},
{
name: "invalid",
args: args{
alg: types.ChecksumAlgorithm("invalid_checksum_algorithm"),
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := IsChecksumAlgorithmValid(tt.args.alg); (err != nil) != tt.wantErr {
t.Errorf("IsChecksumAlgorithmValid() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestIsValidChecksum(t *testing.T) {
type args struct {
checksum string
algorithm types.ChecksumAlgorithm
}
tests := []struct {
name string
args args
want bool
}{
{
name: "invalid-base64",
args: args{
checksum: "invalid_base64_string",
algorithm: types.ChecksumAlgorithmCrc32,
},
want: false,
},
{
name: "invalid-crc32",
args: args{
checksum: "YXNkZmFzZGZhc2Rm",
algorithm: types.ChecksumAlgorithmCrc32,
},
want: false,
},
{
name: "valid-crc32",
args: args{
checksum: "ww2FVQ==",
algorithm: types.ChecksumAlgorithmCrc32,
},
want: true,
},
{
name: "invalid-crc32c",
args: args{
checksum: "Zmdoa2doZmtnZmhr",
algorithm: types.ChecksumAlgorithmCrc32c,
},
want: false,
},
{
name: "valid-crc32c",
args: args{
checksum: "DOsb4w==",
algorithm: types.ChecksumAlgorithmCrc32c,
},
want: true,
},
{
name: "invalid-sha1",
args: args{
checksum: "YXNkZmFzZGZhc2RmYXNkZnNhZGZzYWRm",
algorithm: types.ChecksumAlgorithmSha1,
},
want: false,
},
{
name: "valid-sha1",
args: args{
checksum: "L4q6V59Zcwn12wyLIytoE2c1ugk=",
algorithm: types.ChecksumAlgorithmSha1,
},
want: true,
},
{
name: "invalid-sha256",
args: args{
checksum: "Zmdoa2doZmtnZmhrYXNkZmFzZGZhc2RmZHNmYXNkZg==",
algorithm: types.ChecksumAlgorithmSha256,
},
want: false,
},
{
name: "valid-sha256",
args: args{
checksum: "d1SPCd/kZ2rAzbbLUC0n/bEaOSx70FNbXbIqoIxKuPY=",
algorithm: types.ChecksumAlgorithmSha256,
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsValidChecksum(tt.args.checksum, tt.args.algorithm); got != tt.want {
t.Errorf("IsValidChecksum() = %v, want %v", got, tt.want)
}
})
}
}
func TestIsChecksumTypeValid(t *testing.T) {
type args struct {
t types.ChecksumType
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "valid_FULL_OBJECT",
args: args{
t: types.ChecksumTypeFullObject,
},
wantErr: false,
},
{
name: "valid_COMPOSITE",
args: args{
t: types.ChecksumTypeComposite,
},
wantErr: false,
},
{
name: "invalid",
args: args{
t: types.ChecksumType("invalid_checksum_type"),
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := IsChecksumTypeValid(tt.args.t); (err != nil) != tt.wantErr {
t.Errorf("IsChecksumTypeValid() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_checkChecksumTypeAndAlgo(t *testing.T) {
type args struct {
algo types.ChecksumAlgorithm
t types.ChecksumType
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "full_object-crc32",
args: args{
algo: types.ChecksumAlgorithmCrc32,
t: types.ChecksumTypeFullObject,
},
wantErr: false,
},
{
name: "full_object-crc32c",
args: args{
algo: types.ChecksumAlgorithmCrc32c,
t: types.ChecksumTypeFullObject,
},
wantErr: false,
},
{
name: "full_object-sha1",
args: args{
algo: types.ChecksumAlgorithmSha1,
t: types.ChecksumTypeFullObject,
},
wantErr: true,
},
{
name: "full_object-sha256",
args: args{
algo: types.ChecksumAlgorithmSha1,
t: types.ChecksumTypeFullObject,
},
wantErr: true,
},
{
name: "full_object-crc64nvme",
args: args{
algo: types.ChecksumAlgorithmCrc64nvme,
t: types.ChecksumTypeFullObject,
},
wantErr: false,
},
{
name: "full_object-crc32",
args: args{
algo: types.ChecksumAlgorithmCrc32,
t: types.ChecksumTypeFullObject,
},
wantErr: false,
},
{
name: "composite-crc32",
args: args{
algo: types.ChecksumAlgorithmCrc32,
t: types.ChecksumTypeComposite,
},
wantErr: false,
},
{
name: "composite-crc32c",
args: args{
algo: types.ChecksumAlgorithmCrc32c,
t: types.ChecksumTypeComposite,
},
wantErr: false,
},
{
name: "composite-sha1",
args: args{
algo: types.ChecksumAlgorithmSha1,
t: types.ChecksumTypeComposite,
},
wantErr: false,
},
{
name: "composite-sha256",
args: args{
algo: types.ChecksumAlgorithmSha256,
t: types.ChecksumTypeComposite,
},
wantErr: false,
},
{
name: "composite-crc64nvme",
args: args{
algo: types.ChecksumAlgorithmCrc64nvme,
t: types.ChecksumTypeComposite,
},
wantErr: true,
},
{
name: "composite-empty",
args: args{
t: types.ChecksumTypeComposite,
},
wantErr: true,
},
{
name: "full_object-empty",
args: args{
t: types.ChecksumTypeFullObject,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := checkChecksumTypeAndAlgo(tt.args.algo, tt.args.t); (err != nil) != tt.wantErr {
t.Errorf("checkChecksumTypeAndAlgo() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestParseTagging(t *testing.T) {
genRandStr := func(lgth int) string {
b := make([]byte, lgth)
for i := range b {
b[i] = byte(rand.Intn(95) + 32) // 126 - 32 + 1 = 95 printable characters
}
return string(b)
}
getTagSet := func(lgth int) s3response.TaggingInput {
res := s3response.TaggingInput{
TagSet: s3response.TagSet{
Tags: []s3response.Tag{},
},
}
for i := 0; i < lgth; i++ {
res.TagSet.Tags = append(res.TagSet.Tags, s3response.Tag{
Key: genRandStr(10),
Value: genRandStr(20),
})
}
return res
}
type args struct {
data s3response.TaggingInput
overrideXML []byte
limit TagLimit
}
tests := []struct {
name string
args args
want map[string]string
wantErr error
}{
{
name: "valid tags within limit",
args: args{
data: s3response.TaggingInput{
TagSet: s3response.TagSet{
Tags: []s3response.Tag{
{Key: "key1", Value: "value1"},
{Key: "key2", Value: "value2"},
},
},
},
limit: TagLimitObject,
},
want: map[string]string{"key1": "value1", "key2": "value2"},
wantErr: nil,
},
{
name: "malformed XML",
args: args{
overrideXML: []byte("invalid xml"),
limit: TagLimitObject,
},
want: nil,
wantErr: s3err.GetAPIError(s3err.ErrMalformedXML),
},
{
name: "exceeds bucket tag limit",
args: args{
data: getTagSet(51),
limit: TagLimitBucket,
},
want: nil,
wantErr: s3err.GetAPIError(s3err.ErrBucketTaggingLimited),
},
{
name: "exceeds object tag limit",
args: args{
data: getTagSet(11),
limit: TagLimitObject,
},
want: nil,
wantErr: s3err.GetAPIError(s3err.ErrObjectTaggingLimited),
},
{
name: "invalid 0 length tag key",
args: args{
data: s3response.TaggingInput{
TagSet: s3response.TagSet{
Tags: []s3response.Tag{{Key: "", Value: "value1"}},
},
},
limit: TagLimitObject,
},
want: nil,
wantErr: s3err.GetAPIError(s3err.ErrInvalidTagKey),
},
{
name: "invalid long tag key",
args: args{
data: s3response.TaggingInput{
TagSet: s3response.TagSet{
Tags: []s3response.Tag{{Key: genRandStr(130), Value: "value1"}},
},
},
limit: TagLimitObject,
},
want: nil,
wantErr: s3err.GetAPIError(s3err.ErrInvalidTagKey),
},
{
name: "invalid long tag value",
args: args{
data: s3response.TaggingInput{
TagSet: s3response.TagSet{
Tags: []s3response.Tag{{Key: "key", Value: genRandStr(257)}},
},
},
limit: TagLimitBucket,
},
want: nil,
wantErr: s3err.GetAPIError(s3err.ErrInvalidTagValue),
},
{
name: "duplicate tag key",
args: args{
data: s3response.TaggingInput{
TagSet: s3response.TagSet{
Tags: []s3response.Tag{
{Key: "key", Value: "value1"},
{Key: "key", Value: "value2"},
},
},
},
limit: TagLimitObject,
},
want: nil,
wantErr: s3err.GetAPIError(s3err.ErrDuplicateTagKey),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var data []byte
if tt.args.overrideXML != nil {
data = tt.args.overrideXML
} else {
var err error
data, err = xml.Marshal(tt.args.data)
if err != nil {
t.Fatalf("error marshalling input: %v", err)
}
}
got, err := ParseTagging(data, tt.args.limit)
if !errors.Is(err, tt.wantErr) {
t.Errorf("expected error %v, got %v", tt.wantErr, err)
}
if err == nil && !reflect.DeepEqual(got, tt.want) {
t.Errorf("expected result %v, got %v", tt.want, got)
}
})
}
}
func TestValidateCopySource(t *testing.T) {
tests := []struct {
name string
copysource string
err error
}{
// invalid encoding
{"invalid encoding 1", "%", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 2", "%2", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 3", "%G1", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 4", "%1Z", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 5", "%0H", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 6", "%XY", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 7", "%E", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 8", "hello%", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 9", "%%", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 10", "%2Gmore", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 11", "100%%sure", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 12", "%#00", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 13", "%0%0", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
{"invalid encoding 14", "%?versionId=id", s3err.GetAPIError(s3err.ErrInvalidCopySourceEncoding)},
// invalid bucket name
{"invalid bucket name 1", "168.200.1.255/obj/foo", s3err.GetAPIError(s3err.ErrInvalidCopySourceBucket)},
{"invalid bucket name 2", "/0000:0db8:85a3:0000:0000:8a2e:0370:7224/smth", s3err.GetAPIError(s3err.ErrInvalidCopySourceBucket)},
{"invalid bucket name 3", "", s3err.GetAPIError(s3err.ErrInvalidCopySourceBucket)},
{"invalid bucket name 4", "//obj/foo", s3err.GetAPIError(s3err.ErrInvalidCopySourceBucket)},
{"invalid bucket name 5", "//obj/foo?versionId=id", s3err.GetAPIError(s3err.ErrInvalidCopySourceBucket)},
// invalid object name
{"invalid object name 1", "bucket/../foo", s3err.GetAPIError(s3err.ErrInvalidCopySourceObject)},
{"invalid object name 2", "bucket/", s3err.GetAPIError(s3err.ErrInvalidCopySourceObject)},
{"invalid object name 3", "bucket", s3err.GetAPIError(s3err.ErrInvalidCopySourceObject)},
{"invalid object name 4", "bucket/../foo/dir/../../../", s3err.GetAPIError(s3err.ErrInvalidCopySourceObject)},
{"invalid object name 5", "bucket/.?versionId=smth", s3err.GetAPIError(s3err.ErrInvalidCopySourceObject)},
// success
{"no error 1", "bucket/object", nil},
{"no error 2", "bucket/object/key", nil},
{"no error 3", "bucket/4*&(*&(89765))", nil},
{"no error 4", "bucket/foo/../bar", nil},
{"no error 5", "bucket/foo/bar/baz?versionId=id", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateCopySource(tt.copysource)
assert.Equal(t, tt.err, err)
})
}
}