diff --git a/backend/azure/azure.go b/backend/azure/azure.go index 282525b..c8a7781 100644 --- a/backend/azure/azure.go +++ b/backend/azure/azure.go @@ -1041,10 +1041,10 @@ func (az *Azure) CreateMultipartUpload(ctx context.Context, input s3response.Cre for _, prt := range tagParts { p := strings.Split(prt, "=") if len(p) != 2 { - return s3response.InitiateMultipartUploadResult{}, s3err.GetAPIError(s3err.ErrInvalidTag) + return s3response.InitiateMultipartUploadResult{}, s3err.GetAPIError(s3err.ErrInvalidTagValue) } if len(p[0]) > 128 || len(p[1]) > 256 { - return s3response.InitiateMultipartUploadResult{}, s3err.GetAPIError(s3err.ErrInvalidTag) + return s3response.InitiateMultipartUploadResult{}, s3err.GetAPIError(s3err.ErrInvalidTagValue) } tags[p[0]] = p[1] } @@ -1827,7 +1827,7 @@ func parseTags(tagstr *string) (map[string]string, error) { for _, prt := range tagParts { p := strings.Split(prt, "=") if len(p) != 2 { - return nil, s3err.GetAPIError(s3err.ErrInvalidTag) + return nil, s3err.GetAPIError(s3err.ErrInvalidTagValue) } tags[p[0]] = p[1] } diff --git a/backend/azure/err.go b/backend/azure/err.go index 0300702..932209e 100644 --- a/backend/azure/err.go +++ b/backend/azure/err.go @@ -40,7 +40,7 @@ func azErrToS3err(azErr *azcore.ResponseError) s3err.APIError { case "BlobNotFound": return s3err.GetAPIError(s3err.ErrNoSuchKey) case "TagsTooLarge": - return s3err.GetAPIError(s3err.ErrInvalidTag) + return s3err.GetAPIError(s3err.ErrInvalidTagValue) case "Requested Range Not Satisfiable": return s3err.GetAPIError(s3err.ErrInvalidRange) } diff --git a/backend/common.go b/backend/common.go index 8a461a4..36f4fc7 100644 --- a/backend/common.go +++ b/backend/common.go @@ -227,10 +227,10 @@ func ParseObjectTags(t string) (map[string]string, error) { for _, prt := range tagParts { p := strings.Split(prt, "=") if len(p) != 2 { - return nil, s3err.GetAPIError(s3err.ErrInvalidTag) + return nil, s3err.GetAPIError(s3err.ErrInvalidTagValue) } if len(p[0]) > 128 || len(p[1]) > 256 { - return nil, s3err.GetAPIError(s3err.ErrInvalidTag) + return nil, s3err.GetAPIError(s3err.ErrInvalidTagValue) } tagging[p[0]] = p[1] } diff --git a/s3api/controllers/base.go b/s3api/controllers/base.go index 0666f28..8f3ba30 100644 --- a/s3api/controllers/base.go +++ b/s3api/controllers/base.go @@ -1211,13 +1211,9 @@ func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error { if ctx.Request().URI().QueryArgs().Has("tagging") { parsedAcl := ctx.Locals("parsedAcl").(auth.ACL) - var bucketTagging s3response.TaggingInput - err := xml.Unmarshal(ctx.Body(), &bucketTagging) + tagging, err := utils.ParseTagging(ctx.Body(), utils.TagLimitBucket, c.debug) if err != nil { - if c.debug { - debuglogger.Logf("error unmarshalling bucket tagging: %v", err) - } - return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidRequest), + return SendResponse(ctx, err, &MetaOpts{ Logger: c.logger, MetricsMng: c.mm, @@ -1226,37 +1222,6 @@ func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error { }) } - if len(bucketTagging.TagSet.Tags) > 50 { - if c.debug { - debuglogger.Logf("bucket tagging length exceeds 50: %v", len(bucketTagging.TagSet.Tags)) - } - return SendResponse(ctx, s3err.GetAPIError(s3err.ErrBucketTaggingLimited), - &MetaOpts{ - Logger: c.logger, - MetricsMng: c.mm, - Action: metrics.ActionPutBucketTagging, - BucketOwner: parsedAcl.Owner, - }) - } - - tags := make(map[string]string, len(bucketTagging.TagSet.Tags)) - - for _, tag := range bucketTagging.TagSet.Tags { - if len(tag.Key) > 128 || len(tag.Value) > 256 { - if c.debug { - debuglogger.Logf("invalid long bucket tagging key/value") - } - return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidTag), - &MetaOpts{ - Logger: c.logger, - MetricsMng: c.mm, - Action: metrics.ActionPutBucketTagging, - BucketOwner: parsedAcl.Owner, - }) - } - tags[tag.Key] = tag.Value - } - err = auth.VerifyAccess(ctx.Context(), c.be, auth.AccessOptions{ Readonly: c.readonly, Acl: parsedAcl, @@ -1276,7 +1241,7 @@ func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error { }) } - err = c.be.PutBucketTagging(ctx.Context(), bucket, tags) + err = c.be.PutBucketTagging(ctx.Context(), bucket, tagging) return SendResponse(ctx, err, &MetaOpts{ Logger: c.logger, @@ -1873,13 +1838,9 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { } if ctx.Request().URI().QueryArgs().Has("tagging") { - var objTagging s3response.TaggingInput - err := xml.Unmarshal(ctx.Body(), &objTagging) + tagging, err := utils.ParseTagging(ctx.Body(), utils.TagLimitObject, c.debug) if err != nil { - if c.debug { - debuglogger.Logf("error unmarshalling object tagging: %v", err) - } - return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidRequest), + return SendResponse(ctx, err, &MetaOpts{ Logger: c.logger, MetricsMng: c.mm, @@ -1888,38 +1849,6 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { }) } - if len(objTagging.TagSet.Tags) > 10 { - if c.debug { - debuglogger.Logf("bucket tagging length exceeds 10: %v", len(objTagging.TagSet.Tags)) - } - return SendResponse(ctx, s3err.GetAPIError(s3err.ErrObjectTaggingLimited), - &MetaOpts{ - Logger: c.logger, - MetricsMng: c.mm, - Action: metrics.ActionPutObjectTagging, - BucketOwner: parsedAcl.Owner, - }) - } - - tags := make(map[string]string, len(objTagging.TagSet.Tags)) - - for _, tag := range objTagging.TagSet.Tags { - if len(tag.Key) > 128 || len(tag.Value) > 256 { - if c.debug { - debuglogger.Logf("invalid tag key/value len: %q %q", - tag.Key, tag.Value) - } - return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidTag), - &MetaOpts{ - Logger: c.logger, - MetricsMng: c.mm, - Action: metrics.ActionPutObjectTagging, - BucketOwner: parsedAcl.Owner, - }) - } - tags[tag.Key] = tag.Value - } - err = auth.VerifyAccess(ctx.Context(), c.be, auth.AccessOptions{ Readonly: c.readonly, Acl: parsedAcl, @@ -1940,7 +1869,7 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { }) } - err = c.be.PutObjectTagging(ctx.Context(), bucket, keyStart, tags) + err = c.be.PutObjectTagging(ctx.Context(), bucket, keyStart, tagging) return SendResponse(ctx, err, &MetaOpts{ Logger: c.logger, diff --git a/s3api/utils/utils.go b/s3api/utils/utils.go index a95d035..86e8856 100644 --- a/s3api/utils/utils.go +++ b/s3api/utils/utils.go @@ -17,6 +17,7 @@ package utils import ( "bytes" "encoding/base64" + "encoding/xml" "errors" "fmt" "io" @@ -719,3 +720,74 @@ func ParseCreateMpChecksumHeaders(ctx *fiber.Ctx, debug bool) (types.ChecksumAlg return algo, chType, nil } + +// TagLimit specifies the allowed tag count in a tag set +type TagLimit int + +const ( + // Tag limit for bucket tagging + TagLimitBucket TagLimit = 50 + // Tag limit for object tagging + TagLimitObject TagLimit = 10 +) + +// Parses and validates tagging +func ParseTagging(data []byte, limit TagLimit, debug bool) (map[string]string, error) { + var tagging s3response.TaggingInput + err := xml.Unmarshal(data, &tagging) + if err != nil { + if debug { + debuglogger.Logf("invalid taggging: %s", data) + } + return nil, s3err.GetAPIError(s3err.ErrMalformedXML) + } + + tLen := len(tagging.TagSet.Tags) + if tLen > int(limit) { + switch limit { + case TagLimitObject: + if debug { + debuglogger.Logf("bucket tagging length exceeds %v: %v", limit, tLen) + } + return nil, s3err.GetAPIError(s3err.ErrObjectTaggingLimited) + case TagLimitBucket: + if debug { + debuglogger.Logf("object tagging length exceeds %v: %v", limit, tLen) + } + return nil, s3err.GetAPIError(s3err.ErrBucketTaggingLimited) + } + } + + tagSet := make(map[string]string, tLen) + + for _, tag := range tagging.TagSet.Tags { + // validate tag key + if len(tag.Key) == 0 || len(tag.Key) > 128 { + if debug { + debuglogger.Logf("tag key should 0 < tag.Key <= 128, key: %v", tag.Key) + } + return nil, s3err.GetAPIError(s3err.ErrInvalidTagKey) + } + + // validate tag value + if len(tag.Value) > 256 { + if debug { + debuglogger.Logf("invalid long tag value: (length): %v, (value): %v", len(tag.Value), tag.Value) + } + return nil, s3err.GetAPIError(s3err.ErrInvalidTagValue) + } + + // make sure there are no duplicate keys + _, ok := tagSet[tag.Key] + if ok { + if debug { + debuglogger.Logf("duplicate tag key: %v", tag.Key) + } + return nil, s3err.GetAPIError(s3err.ErrDuplicateTagKey) + } + + tagSet[tag.Key] = tag.Value + } + + return tagSet, nil +} diff --git a/s3api/utils/utils_test.go b/s3api/utils/utils_test.go index 205115d..9db30f1 100644 --- a/s3api/utils/utils_test.go +++ b/s3api/utils/utils_test.go @@ -16,6 +16,9 @@ package utils import ( "bytes" + "encoding/xml" + "errors" + "math/rand" "net/http" "reflect" "testing" @@ -25,6 +28,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/valyala/fasthttp" "github.com/versity/versitygw/backend" + "github.com/versity/versitygw/s3err" "github.com/versity/versitygw/s3response" ) @@ -857,3 +861,163 @@ func Test_checkChecksumTypeAndAlgo(t *testing.T) { }) } } + +func TestParseTagging(t *testing.T) { + genRandStr := func(lgth int) string { + b := make([]byte, lgth) + for i := range b { + b[i] = byte(rand.Intn(95) + 32) // 126 - 32 + 1 = 95 printable characters + } + return string(b) + } + getTagSet := func(lgth int) s3response.TaggingInput { + res := s3response.TaggingInput{ + TagSet: s3response.TagSet{ + Tags: []s3response.Tag{}, + }, + } + + for i := 0; i < lgth; i++ { + res.TagSet.Tags = append(res.TagSet.Tags, s3response.Tag{ + Key: genRandStr(10), + Value: genRandStr(20), + }) + } + + return res + } + type args struct { + data s3response.TaggingInput + overrideXML []byte + limit TagLimit + debug bool + } + tests := []struct { + name string + args args + want map[string]string + wantErr error + }{ + { + name: "valid tags within limit", + args: args{ + data: s3response.TaggingInput{ + TagSet: s3response.TagSet{ + Tags: []s3response.Tag{ + {Key: "key1", Value: "value1"}, + {Key: "key2", Value: "value2"}, + }, + }, + }, + limit: TagLimitObject, + }, + want: map[string]string{"key1": "value1", "key2": "value2"}, + wantErr: nil, + }, + { + name: "malformed XML", + args: args{ + overrideXML: []byte("invalid xml"), + limit: TagLimitObject, + }, + want: nil, + wantErr: s3err.GetAPIError(s3err.ErrMalformedXML), + }, + { + name: "exceeds bucket tag limit", + args: args{ + data: getTagSet(51), + limit: TagLimitBucket, + }, + want: nil, + wantErr: s3err.GetAPIError(s3err.ErrBucketTaggingLimited), + }, + { + name: "exceeds object tag limit", + args: args{ + data: getTagSet(11), + limit: TagLimitObject, + }, + want: nil, + wantErr: s3err.GetAPIError(s3err.ErrObjectTaggingLimited), + }, + { + name: "invalid 0 length tag key", + args: args{ + data: s3response.TaggingInput{ + TagSet: s3response.TagSet{ + Tags: []s3response.Tag{{Key: "", Value: "value1"}}, + }, + }, + limit: TagLimitObject, + }, + want: nil, + wantErr: s3err.GetAPIError(s3err.ErrInvalidTagKey), + }, + { + name: "invalid long tag key", + args: args{ + data: s3response.TaggingInput{ + TagSet: s3response.TagSet{ + Tags: []s3response.Tag{{Key: genRandStr(130), Value: "value1"}}, + }, + }, + limit: TagLimitObject, + }, + want: nil, + wantErr: s3err.GetAPIError(s3err.ErrInvalidTagKey), + }, + { + name: "invalid long tag value", + args: args{ + data: s3response.TaggingInput{ + TagSet: s3response.TagSet{ + Tags: []s3response.Tag{{Key: "key", Value: genRandStr(257)}}, + }, + }, + limit: TagLimitBucket, + }, + want: nil, + wantErr: s3err.GetAPIError(s3err.ErrInvalidTagValue), + }, + { + name: "duplicate tag key", + args: args{ + data: s3response.TaggingInput{ + TagSet: s3response.TagSet{ + Tags: []s3response.Tag{ + {Key: "key", Value: "value1"}, + {Key: "key", Value: "value2"}, + }, + }, + }, + limit: TagLimitObject, + }, + want: nil, + wantErr: s3err.GetAPIError(s3err.ErrDuplicateTagKey), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var data []byte + if tt.args.overrideXML != nil { + data = tt.args.overrideXML + } else { + var err error + data, err = xml.Marshal(tt.args.data) + if err != nil { + t.Fatalf("error marshalling input: %v", err) + } + } + got, err := ParseTagging(data, tt.args.limit, tt.args.debug) + + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error %v, got %v", tt.wantErr, err) + } + if err == nil && !reflect.DeepEqual(got, tt.want) { + t.Errorf("expected result %v, got %v", tt.want, got) + } + }) + } +} diff --git a/s3err/s3err.go b/s3err/s3err.go index 82f8b99..a895285 100644 --- a/s3err/s3err.go +++ b/s3err/s3err.go @@ -84,7 +84,9 @@ const ( ErrInvalidCopyDest ErrInvalidCopySource ErrInvalidCopySourceRange - ErrInvalidTag + ErrInvalidTagKey + ErrInvalidTagValue + ErrDuplicateTagKey ErrBucketTaggingLimited ErrObjectTaggingLimited ErrAuthHeaderEmpty @@ -308,11 +310,21 @@ var errorCodeResponse = map[ErrorCode]APIError{ Description: "The x-amz-copy-source-range value must be of the form bytes=first-last where first and last are the zero-based offsets of the first and last bytes to copy", HTTPStatusCode: http.StatusBadRequest, }, - ErrInvalidTag: { + ErrInvalidTagKey: { + Code: "InvalidTag", + Description: "The TagKey you have provided is invalid", + HTTPStatusCode: http.StatusBadRequest, + }, + ErrInvalidTagValue: { Code: "InvalidTag", Description: "The TagValue you have provided is invalid", HTTPStatusCode: http.StatusBadRequest, }, + ErrDuplicateTagKey: { + Code: "InvalidTag", + Description: "Cannot provide multiple Tags with the same key", + HTTPStatusCode: http.StatusBadRequest, + }, ErrBucketTaggingLimited: { Code: "BadRequest", Description: "Bucket tag count cannot be greater than 50", diff --git a/tests/integration/group-tests.go b/tests/integration/group-tests.go index 1f46b57..93af571 100644 --- a/tests/integration/group-tests.go +++ b/tests/integration/group-tests.go @@ -118,6 +118,7 @@ func TestDeleteBucketOwnershipControls(s *S3Conf) { func TestPutBucketTagging(s *S3Conf) { PutBucketTagging_non_existing_bucket(s) PutBucketTagging_long_tags(s) + PutBucketTagging_duplicate_keys(s) PutBucketTagging_tag_count_limit(s) PutBucketTagging_success(s) PutBucketTagging_success_status(s) @@ -298,6 +299,7 @@ func TestCopyObject(s *S3Conf) { func TestPutObjectTagging(s *S3Conf) { PutObjectTagging_non_existing_object(s) PutObjectTagging_long_tags(s) + PutObjectTagging_duplicate_keys(s) PutObjectTagging_tag_count_limit(s) PutObjectTagging_success(s) } @@ -852,6 +854,7 @@ func GetIntTests() IntTests { "DeleteBucketOwnershipControls_success": DeleteBucketOwnershipControls_success, "PutBucketTagging_non_existing_bucket": PutBucketTagging_non_existing_bucket, "PutBucketTagging_long_tags": PutBucketTagging_long_tags, + "PutBucketTagging_duplicate_keys": PutBucketTagging_duplicate_keys, "PutBucketTagging_tag_count_limit": PutBucketTagging_tag_count_limit, "PutBucketTagging_success": PutBucketTagging_success, "PutBucketTagging_success_status": PutBucketTagging_success_status, @@ -959,6 +962,7 @@ func GetIntTests() IntTests { "CopyObject_success": CopyObject_success, "PutObjectTagging_non_existing_object": PutObjectTagging_non_existing_object, "PutObjectTagging_long_tags": PutObjectTagging_long_tags, + "PutObjectTagging_duplicate_keys": PutObjectTagging_duplicate_keys, "PutObjectTagging_tag_count_limit": PutObjectTagging_tag_count_limit, "PutObjectTagging_success": PutObjectTagging_success, "GetObjectTagging_non_existing_object": GetObjectTagging_non_existing_object, diff --git a/tests/integration/tests.go b/tests/integration/tests.go index 6f634eb..914fda8 100644 --- a/tests/integration/tests.go +++ b/tests/integration/tests.go @@ -2500,7 +2500,7 @@ func PutBucketTagging_long_tags(s *S3Conf) error { Bucket: &bucket, Tagging: &tagging}) cancel() - if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTag)); err != nil { + if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTagKey)); err != nil { return err } @@ -2511,7 +2511,32 @@ func PutBucketTagging_long_tags(s *S3Conf) error { Bucket: &bucket, Tagging: &tagging}) cancel() - if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTag)); err != nil { + if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTagValue)); err != nil { + return err + } + + return nil + }) +} + +func PutBucketTagging_duplicate_keys(s *S3Conf) error { + testName := "PutBucketTagging_duplicate_keys" + return actionHandler(s, testName, func(s3client *s3.Client, bucket string) error { + tagging := types.Tagging{ + TagSet: []types.Tag{ + {Key: getPtr("key"), Value: getPtr("value")}, + {Key: getPtr("key"), Value: getPtr("value-1")}, + {Key: getPtr("key-1"), Value: getPtr("value-2")}, + {Key: getPtr("key-2"), Value: getPtr("value-3")}, + }, + } + ctx, cancel := context.WithTimeout(context.Background(), shortTimeout) + _, err := s3client.PutBucketTagging(ctx, &s3.PutBucketTaggingInput{ + Bucket: &bucket, + Tagging: &tagging, + }) + cancel() + if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrDuplicateTagKey)); err != nil { return err } @@ -2813,7 +2838,7 @@ func PutObject_invalid_long_tags(s *S3Conf) error { Tagging: &tagging, }) cancel() - if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTag)); err != nil { + if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTagValue)); err != nil { return err } @@ -2827,7 +2852,7 @@ func PutObject_invalid_long_tags(s *S3Conf) error { }) cancel() - if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTag)); err != nil { + if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTagValue)); err != nil { return err } @@ -6991,7 +7016,7 @@ func PutObjectTagging_long_tags(s *S3Conf) error { Key: &obj, Tagging: &tagging}) cancel() - if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTag)); err != nil { + if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTagKey)); err != nil { return err } @@ -7003,7 +7028,39 @@ func PutObjectTagging_long_tags(s *S3Conf) error { Key: &obj, Tagging: &tagging}) cancel() - if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTag)); err != nil { + if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTagValue)); err != nil { + return err + } + + return nil + }) +} + +func PutObjectTagging_duplicate_keys(s *S3Conf) error { + testName := "PutObjectTagging_duplicate_keys" + return actionHandler(s, testName, func(s3client *s3.Client, bucket string) error { + obj := "my-obj" + _, err := putObjects(s3client, []string{obj}, bucket) + if err != nil { + return err + } + + tagging := types.Tagging{ + TagSet: []types.Tag{ + {Key: getPtr("key-1"), Value: getPtr("value-1")}, + {Key: getPtr("key-2"), Value: getPtr("value-2")}, + {Key: getPtr("same-key"), Value: getPtr("value-3")}, + {Key: getPtr("same-key"), Value: getPtr("value-4")}, + }, + } + ctx, cancel := context.WithTimeout(context.Background(), shortTimeout) + _, err = s3client.PutObjectTagging(ctx, &s3.PutObjectTaggingInput{ + Bucket: &bucket, + Key: &obj, + Tagging: &tagging, + }) + cancel() + if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrDuplicateTagKey)); err != nil { return err } @@ -7630,7 +7687,7 @@ func CreateMultipartUpload_with_invalid_tagging(s *S3Conf) error { Tagging: getPtr("invalid_tag"), }) cancel() - if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTag)); err != nil { + if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTagValue)); err != nil { return err }