diff --git a/backend/azure/azure.go b/backend/azure/azure.go index c8a77813..e64ebd35 100644 --- a/backend/azure/azure.go +++ b/backend/azure/azure.go @@ -295,7 +295,7 @@ func (az *Azure) DeleteBucketOwnershipControls(ctx context.Context, bucket strin } func (az *Azure) PutObject(ctx context.Context, po s3response.PutObjectInput) (s3response.PutObjectOutput, error) { - tags, err := parseTags(po.Tagging) + tags, err := backend.ParseObjectTags(getString(po.Tagging)) if err != nil { return s3response.PutObjectOutput{}, err } @@ -872,7 +872,7 @@ func (az *Azure) CopyObject(ctx context.Context, input s3response.CopyObjectInpu // Set object Tagging, if tagging directive is "REPLACE" if input.TaggingDirective == types.TaggingDirectiveReplace { - tags, err := parseTags(input.Tagging) + tags, err := backend.ParseObjectTags(getString(input.Tagging)) if err != nil { return nil, err } @@ -1034,20 +1034,9 @@ func (az *Azure) CreateMultipartUpload(ctx context.Context, input s3response.Cre } // parse object tags - tagsStr := getString(input.Tagging) - tags := map[string]string{} - if tagsStr != "" { - tagParts := strings.Split(tagsStr, "&") - for _, prt := range tagParts { - p := strings.Split(prt, "=") - if len(p) != 2 { - return s3response.InitiateMultipartUploadResult{}, s3err.GetAPIError(s3err.ErrInvalidTagValue) - } - if len(p[0]) > 128 || len(p[1]) > 256 { - return s3response.InitiateMultipartUploadResult{}, s3err.GetAPIError(s3err.ErrInvalidTagValue) - } - tags[p[0]] = p[1] - } + tags, err := backend.ParseObjectTags(getString(input.Tagging)) + if err != nil { + return s3response.InitiateMultipartUploadResult{}, err } // set blob legal hold status in metadata @@ -1087,7 +1076,7 @@ func (az *Azure) CreateMultipartUpload(ctx context.Context, input s3response.Cre // Create and empty blob in .sgwtmp/multipart// // The blob indicates multipart upload initialization and holds the mp metadata // e.g tagging, content-type, metadata, object lock status ... - _, err := az.client.UploadBuffer(ctx, *input.Bucket, tmpPath, []byte{}, opts) + _, err = az.client.UploadBuffer(ctx, *input.Bucket, tmpPath, []byte{}, opts) if err != nil { return s3response.InitiateMultipartUploadResult{}, azureErrToS3Err(err) } @@ -1818,24 +1807,6 @@ func parseAzMetadata(m map[string]*string) map[string]string { return meta } -func parseTags(tagstr *string) (map[string]string, error) { - tagsStr := getString(tagstr) - tags := make(map[string]string) - - if tagsStr != "" { - tagParts := strings.Split(tagsStr, "&") - for _, prt := range tagParts { - p := strings.Split(prt, "=") - if len(p) != 2 { - return nil, s3err.GetAPIError(s3err.ErrInvalidTagValue) - } - tags[p[0]] = p[1] - } - } - - return tags, nil -} - func parseAzTags(tagSet []*blob.Tags) map[string]string { tags := map[string]string{} for _, tag := range tagSet { diff --git a/backend/common.go b/backend/common.go index 36f4fc75..741bb480 100644 --- a/backend/common.go +++ b/backend/common.go @@ -21,7 +21,9 @@ import ( "fmt" "io" "io/fs" + "net/url" "os" + "regexp" "strconv" "strings" "syscall" @@ -215,27 +217,81 @@ func ParseCopySource(copySourceHeader string) (string, string, string, error) { } // ParseObjectTags parses the url encoded input string into -// map[string]string key-value tag set -func ParseObjectTags(t string) (map[string]string, error) { - if t == "" { +// map[string]string with unescaped key/value pair +func ParseObjectTags(tagging string) (map[string]string, error) { + if tagging == "" { return nil, nil } - tagging := make(map[string]string) + tagSet := make(map[string]string) - tagParts := strings.Split(t, "&") - for _, prt := range tagParts { - p := strings.Split(prt, "=") - if len(p) != 2 { + for tagging != "" { + var tag string + tag, tagging, _ = strings.Cut(tagging, "&") + // if 'tag' before the first appearance of '&' is empty continue + if tag == "" { + continue + } + + key, value, found := strings.Cut(tag, "=") + // if key is empty, but "=" is present, return invalid url ecnoding err + if found && key == "" { + return nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging) + } + + // return invalid tag key, if the key is longer than 128 + if len(key) > 128 { + return nil, s3err.GetAPIError(s3err.ErrInvalidTagKey) + } + + // return invalid tag value, if tag value is longer than 256 + if len(value) > 256 { return nil, s3err.GetAPIError(s3err.ErrInvalidTagValue) } - if len(p[0]) > 128 || len(p[1]) > 256 { + + // query unescape tag key + key, err := url.QueryUnescape(key) + if err != nil { + return nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging) + } + + // query unescape tag value + value, err = url.QueryUnescape(value) + if err != nil { + return nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging) + } + + // check tag key to be valid + if !isValidTagComponent(key) { + return nil, s3err.GetAPIError(s3err.ErrInvalidTagKey) + } + + // check tag value to be valid + if !isValidTagComponent(value) { return nil, s3err.GetAPIError(s3err.ErrInvalidTagValue) } - tagging[p[0]] = p[1] + + // duplicate keys are not allowed: return invalid url encoding err + _, ok := tagSet[key] + if ok { + return nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging) + } + + tagSet[key] = value } - return tagging, nil + return tagSet, nil +} + +var validTagComponent = regexp.MustCompile(`^[a-zA-Z0-9:/_.\-+ ]+$`) + +// isValidTagComponent matches strings which contain letters, decimal digits, +// and special chars: '/', '_', '-', '+', '.', ' ' (space) +func isValidTagComponent(str string) bool { + if str == "" { + return true + } + return validTagComponent.Match([]byte(str)) } func GetMultipartMD5(parts []types.CompletedPart) string { diff --git a/s3err/s3err.go b/s3err/s3err.go index a8952856..2d646058 100644 --- a/s3err/s3err.go +++ b/s3err/s3err.go @@ -89,6 +89,7 @@ const ( ErrDuplicateTagKey ErrBucketTaggingLimited ErrObjectTaggingLimited + ErrInvalidURLEncodedTagging ErrAuthHeaderEmpty ErrSignatureVersionNotSupported ErrMalformedPOSTRequest @@ -335,6 +336,11 @@ var errorCodeResponse = map[ErrorCode]APIError{ Description: "Object tags cannot be greater than 10", HTTPStatusCode: http.StatusBadRequest, }, + ErrInvalidURLEncodedTagging: { + Code: "InvalidArgument", + Description: "The header 'x-amz-tagging' shall be encoded as UTF-8 then URLEncoded URL query parameters without tag name duplicates.", + HTTPStatusCode: http.StatusBadRequest, + }, ErrMalformedXML: { Code: "MalformedXML", Description: "The XML you provided was not well-formed or did not validate against our published schema.", diff --git a/tests/integration/group-tests.go b/tests/integration/group-tests.go index 93af5711..73888b15 100644 --- a/tests/integration/group-tests.go +++ b/tests/integration/group-tests.go @@ -139,7 +139,7 @@ func TestDeleteBucketTagging(s *S3Conf) { func TestPutObject(s *S3Conf) { PutObject_non_existing_bucket(s) PutObject_special_chars(s) - PutObject_invalid_long_tags(s) + PutObject_tagging(s) PutObject_missing_object_lock_retention_config(s) PutObject_with_object_lock(s) PutObject_invalid_legal_hold(s) @@ -274,6 +274,7 @@ func TestCopyObject(s *S3Conf) { CopyObject_not_owned_source_bucket(s) CopyObject_copy_to_itself(s) CopyObject_copy_to_itself_invalid_directive(s) + CopyObject_should_replace_tagging(s) CopyObject_should_copy_tagging(s) CopyObject_invalid_tagging_directive(s) CopyObject_to_itself_with_new_metadata(s) @@ -320,7 +321,6 @@ func TestDeleteObjectTagging(s *S3Conf) { func TestCreateMultipartUpload(s *S3Conf) { CreateMultipartUpload_non_existing_bucket(s) CreateMultipartUpload_with_metadata(s) - CreateMultipartUpload_with_invalid_tagging(s) CreateMultipartUpload_with_tagging(s) CreateMultipartUpload_with_object_lock(s) CreateMultipartUpload_with_object_lock_not_enabled(s) @@ -866,7 +866,7 @@ func GetIntTests() IntTests { "DeleteBucketTagging_success": DeleteBucketTagging_success, "PutObject_non_existing_bucket": PutObject_non_existing_bucket, "PutObject_special_chars": PutObject_special_chars, - "PutObject_invalid_long_tags": PutObject_invalid_long_tags, + "PutObject_tagging": PutObject_tagging, "PutObject_success": PutObject_success, "PutObject_racey_success": PutObject_racey_success, "HeadObject_non_existing_object": HeadObject_non_existing_object, @@ -943,6 +943,7 @@ func GetIntTests() IntTests { "CopyObject_not_owned_source_bucket": CopyObject_not_owned_source_bucket, "CopyObject_copy_to_itself": CopyObject_copy_to_itself, "CopyObject_copy_to_itself_invalid_directive": CopyObject_copy_to_itself_invalid_directive, + "CopyObject_should_replace_tagging": CopyObject_should_replace_tagging, "CopyObject_should_copy_tagging": CopyObject_should_copy_tagging, "CopyObject_invalid_tagging_directive": CopyObject_invalid_tagging_directive, "CopyObject_to_itself_with_new_metadata": CopyObject_to_itself_with_new_metadata, @@ -974,7 +975,6 @@ func GetIntTests() IntTests { "DeleteObjectTagging_success": DeleteObjectTagging_success, "CreateMultipartUpload_non_existing_bucket": CreateMultipartUpload_non_existing_bucket, "CreateMultipartUpload_with_metadata": CreateMultipartUpload_with_metadata, - "CreateMultipartUpload_with_invalid_tagging": CreateMultipartUpload_with_invalid_tagging, "CreateMultipartUpload_with_tagging": CreateMultipartUpload_with_tagging, "CreateMultipartUpload_with_object_lock": CreateMultipartUpload_with_object_lock, "CreateMultipartUpload_with_object_lock_not_enabled": CreateMultipartUpload_with_object_lock_not_enabled, diff --git a/tests/integration/tests.go b/tests/integration/tests.go index 914fda88..4004c70a 100644 --- a/tests/integration/tests.go +++ b/tests/integration/tests.go @@ -2825,37 +2825,111 @@ func PutObject_special_chars(s *S3Conf) error { }) } -func PutObject_invalid_long_tags(s *S3Conf) error { - testName := "PutObject_invalid_long_tags" +func PutObject_tagging(s *S3Conf) error { + testName := "PutObject_tagging" return actionHandler(s, testName, func(s3client *s3.Client, bucket string) error { - key := "my-obj" - tagging := fmt.Sprintf("%v=val", genRandString(200)) + obj := "my-obj" + testTagging := func(taggging string, result map[string]string, expectedErr error) error { + ctx, cancel := context.WithTimeout(context.Background(), shortTimeout) - ctx, cancel := context.WithTimeout(context.Background(), shortTimeout) - _, err := s3client.PutObject(ctx, &s3.PutObjectInput{ - Bucket: &bucket, - Key: &key, - Tagging: &tagging, - }) - cancel() - if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTagValue)); err != nil { - return err + _, err := s3client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: &bucket, + Key: &obj, + Tagging: &taggging, + }) + cancel() + if err == nil && expectedErr != nil { + return fmt.Errorf("expected err %w, instead got nil", expectedErr) + } + if err != nil { + if expectedErr == nil { + return err + } + switch eErr := expectedErr.(type) { + case s3err.APIError: + return checkApiErr(err, eErr) + default: + return fmt.Errorf("invalid err provided: %w", expectedErr) + } + } + + ctx, cancel = context.WithTimeout(context.Background(), shortTimeout) + res, err := s3client.GetObjectTagging(ctx, &s3.GetObjectTaggingInput{ + Bucket: &bucket, + Key: &obj, + }) + cancel() + if err != nil { + return err + } + + if len(res.TagSet) != len(result) { + return fmt.Errorf("tag lengths are not equal: (expected): %v, (got): %v", len(result), len(res.TagSet)) + } + + for _, tag := range res.TagSet { + val, ok := result[getString(tag.Key)] + if !ok { + return fmt.Errorf("tag key not found: %v", getString(tag.Key)) + } + + if val != getString(tag.Value) { + return fmt.Errorf("expected the %v tag value to be %v, instead got %v", getString(tag.Key), val, getString(tag.Value)) + } + } + + return nil } - tagging = fmt.Sprintf("key=%v", genRandString(300)) - - ctx, cancel = context.WithTimeout(context.Background(), shortTimeout) - _, err = s3client.PutObject(ctx, &s3.PutObjectInput{ - Bucket: &bucket, - Key: &key, - Tagging: &tagging, - }) - cancel() - - if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTagValue)); err != nil { - return err + for _, el := range []struct { + tagging string + result map[string]string + expectedErr error + }{ + // success cases + {"&", map[string]string{}, nil}, + {"&&&", map[string]string{}, nil}, + {"key", map[string]string{"key": ""}, nil}, + {"key&", map[string]string{"key": ""}, nil}, + {"key=&", map[string]string{"key": ""}, nil}, + {"key=val&", map[string]string{"key": "val"}, nil}, + {"key1&key2", map[string]string{"key1": "", "key2": ""}, nil}, + {"key1=val1&key2=val2", map[string]string{"key1": "val1", "key2": "val2"}, nil}, + // invalid url-encoded + {"=", nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging)}, + {"key%", nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging)}, + // duplicate keys + {"key=val&key=val", nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging)}, + // invalid tag keys + {"key?=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key(=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key*=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key$=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key#=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key@=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key!=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + // invalid tag values + {"key=val?", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val(", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val*", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val$", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val#", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val@", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val!", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + // success special chars + {"key-key_key.key/key=value-value_value.value/value", map[string]string{"key-key_key.key/key": "value-value_value.value/value"}, nil}, + // should handle supported encoded characters + {"key%2E=value%2F", map[string]string{"key.": "value/"}, nil}, + {"key%2D=value%2B", map[string]string{"key-": "value+"}, nil}, + {"key++key=value++value", map[string]string{"key key": "value value"}, nil}, + {"key%20key=value%20value", map[string]string{"key key": "value value"}, nil}, + {"key%5Fkey=value%5Fvalue", map[string]string{"key_key": "value_value"}, nil}, + } { + err := testTagging(el.tagging, el.result, el.expectedErr) + if err != nil { + return err + } } - return nil }) } @@ -6201,55 +6275,121 @@ func CopyObject_should_copy_tagging(s *S3Conf) error { }) } -func CopyObject_should_reaplace_tagging(s *S3Conf) error { - testName := "CopyObject_should_reaplace_tagging" +func CopyObject_should_replace_tagging(s *S3Conf) error { + testName := "CopyObject_should_replace_tagging" return actionHandler(s, testName, func(s3client *s3.Client, bucket string) error { - srcObj, dstObj := "source-object", "dest-object" - tagging := "foo=bar&baz=quxx" - - _, err := putObjectWithData(100, &s3.PutObjectInput{ + obj := "my-obj" + _, err := putObjectWithData(10, &s3.PutObjectInput{ Bucket: &bucket, - Key: &srcObj, - Tagging: &tagging, + Key: &obj, + Tagging: getPtr("key=value&key1=value1"), }, s3client) if err != nil { return err } + testTagging := func(taggging string, result map[string]string, expectedErr error) error { + dstObj := "destination-object" + ctx, cancel := context.WithTimeout(context.Background(), shortTimeout) + _, err := s3client.CopyObject(ctx, &s3.CopyObjectInput{ + Bucket: &bucket, + Key: &dstObj, + Tagging: &taggging, + CopySource: getPtr(fmt.Sprintf("%v/%v", bucket, obj)), + TaggingDirective: types.TaggingDirectiveReplace, + }) + cancel() + if err == nil && expectedErr != nil { + return fmt.Errorf("expected err %w, instead got nil", expectedErr) + } + if err != nil { + if expectedErr == nil { + return err + } + switch eErr := expectedErr.(type) { + case s3err.APIError: + return checkApiErr(err, eErr) + default: + return fmt.Errorf("invalid err provided: %w", expectedErr) + } + } - copyTagging := "key1=val1&key2=val2" + ctx, cancel = context.WithTimeout(context.Background(), shortTimeout) + res, err := s3client.GetObjectTagging(ctx, &s3.GetObjectTaggingInput{ + Bucket: &bucket, + Key: &dstObj, + }) + cancel() + if err != nil { + return err + } - ctx, cancel := context.WithTimeout(context.Background(), shortTimeout) - _, err = s3client.CopyObject(ctx, &s3.CopyObjectInput{ - Bucket: &bucket, - Key: &dstObj, - CopySource: getPtr(fmt.Sprintf("%v/%v", bucket, srcObj)), - TaggingDirective: types.TaggingDirectiveReplace, - Tagging: ©Tagging, - }) - cancel() - if err != nil { - return err + if len(res.TagSet) != len(result) { + return fmt.Errorf("tag lengths are not equal: (expected): %v, (got): %v", len(result), len(res.TagSet)) + } + + for _, tag := range res.TagSet { + val, ok := result[getString(tag.Key)] + if !ok { + return fmt.Errorf("tag key not found: %v", getString(tag.Key)) + } + + if val != getString(tag.Value) { + return fmt.Errorf("expected the %v tag value to be %v, instead got %v", getString(tag.Key), val, getString(tag.Value)) + } + } + + return nil } - ctx, cancel = context.WithTimeout(context.Background(), shortTimeout) - res, err := s3client.GetObjectTagging(ctx, &s3.GetObjectTaggingInput{ - Bucket: &bucket, - Key: &dstObj, - }) - cancel() - if err != nil { - return err + for _, el := range []struct { + tagging string + result map[string]string + expectedErr error + }{ + // success cases + {"&", map[string]string{}, nil}, + {"&&&", map[string]string{}, nil}, + {"key", map[string]string{"key": ""}, nil}, + {"key&", map[string]string{"key": ""}, nil}, + {"key=&", map[string]string{"key": ""}, nil}, + {"key=val&", map[string]string{"key": "val"}, nil}, + {"key1&key2", map[string]string{"key1": "", "key2": ""}, nil}, + {"key1=val1&key2=val2", map[string]string{"key1": "val1", "key2": "val2"}, nil}, + // invalid url-encoded + {"=", nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging)}, + {"key%", nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging)}, + // duplicate keys + {"key=val&key=val", nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging)}, + // invalid tag keys + {"key?=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key(=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key*=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key$=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key#=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key@=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key!=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + // invalid tag values + {"key=val?", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val(", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val*", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val$", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val#", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val@", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val!", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + // success special chars + {"key-key_key.key/key=value-value_value.value/value", map[string]string{"key-key_key.key/key": "value-value_value.value/value"}, nil}, + // should handle supported encoded characters + {"key%2E=value%2F", map[string]string{"key.": "value/"}, nil}, + {"key%2D=value%2B", map[string]string{"key-": "value+"}, nil}, + {"key++key=value++value", map[string]string{"key key": "value value"}, nil}, + {"key%20key=value%20value", map[string]string{"key key": "value value"}, nil}, + {"key%5Fkey=value%5Fvalue", map[string]string{"key_key": "value_value"}, nil}, + } { + err := testTagging(el.tagging, el.result, el.expectedErr) + if err != nil { + return err + } } - - expectedTagSet := []types.Tag{ - {Key: getPtr("key1"), Value: getPtr("val1")}, - {Key: getPtr("key2"), Value: getPtr("val2")}, - } - - if !areTagsSame(res.TagSet, expectedTagSet) { - return fmt.Errorf("expected the tag set to be %v, instead got %v", expectedTagSet, res.TagSet) - } - return nil }) } @@ -7676,93 +7816,138 @@ func CreateMultipartUpload_valid_checksum_algorithm(s *S3Conf) error { }) } -func CreateMultipartUpload_with_invalid_tagging(s *S3Conf) error { - testName := "CreateMultipartUpload_with_invalid_tagging" - return actionHandler(s, testName, func(s3client *s3.Client, bucket string) error { - obj := "my-obj" - ctx, cancel := context.WithTimeout(context.Background(), shortTimeout) - _, err := s3client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{ - Bucket: &bucket, - Key: &obj, - Tagging: getPtr("invalid_tag"), - }) - cancel() - if err := checkApiErr(err, s3err.GetAPIError(s3err.ErrInvalidTagValue)); err != nil { - return err - } - - return nil - }) -} - func CreateMultipartUpload_with_tagging(s *S3Conf) error { testName := "CreateMultipartUpload_with_tagging" return actionHandler(s, testName, func(s3client *s3.Client, bucket string) error { obj := "my-obj" - tagging := "key1=val1&key2=val2" - ctx, cancel := context.WithTimeout(context.Background(), shortTimeout) - out, err := s3client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{ - Bucket: &bucket, - Key: &obj, - Tagging: &tagging, - }) - cancel() - if err != nil { - return err - } - - parts, _, err := uploadParts(s3client, 100, 1, bucket, obj, *out.UploadId) - if err != nil { - return err - } - - compParts := []types.CompletedPart{} - for _, el := range parts { - compParts = append(compParts, types.CompletedPart{ - ETag: el.ETag, - PartNumber: el.PartNumber, + testTagging := func(tagging string, result map[string]string, expectedErr error) error { + ctx, cancel := context.WithTimeout(context.Background(), shortTimeout) + mp, err := s3client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{ + Bucket: &bucket, + Key: &obj, + Tagging: &tagging, }) + cancel() + if err == nil && expectedErr != nil { + return fmt.Errorf("expected err %w, instead got nil", expectedErr) + } + if err != nil { + if expectedErr == nil { + return err + } + switch eErr := expectedErr.(type) { + case s3err.APIError: + return checkApiErr(err, eErr) + default: + return fmt.Errorf("invalid err provided: %w", expectedErr) + } + } + + parts, _, err := uploadParts(s3client, 5*1024*1024, 1, bucket, obj, *mp.UploadId) + if err != nil { + return err + } + + cParts := []types.CompletedPart{ + { + ETag: parts[0].ETag, + PartNumber: parts[0].PartNumber, + ChecksumCRC32: parts[0].ChecksumCRC32, + }, + } + + ctx, cancel = context.WithTimeout(context.Background(), shortTimeout) + _, err = s3client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{ + Bucket: &bucket, + Key: &obj, + UploadId: mp.UploadId, + MultipartUpload: &types.CompletedMultipartUpload{ + Parts: cParts, + }, + }) + cancel() + if err != nil { + return err + } + + ctx, cancel = context.WithTimeout(context.Background(), shortTimeout) + res, err := s3client.GetObjectTagging(ctx, &s3.GetObjectTaggingInput{ + Bucket: &bucket, + Key: &obj, + }) + cancel() + if err != nil { + return err + } + + if len(res.TagSet) != len(result) { + return fmt.Errorf("tag lengths are not equal: (expected): %v, (got): %v", len(result), len(res.TagSet)) + } + + for _, tag := range res.TagSet { + val, ok := result[getString(tag.Key)] + if !ok { + return fmt.Errorf("tag key not found: %v", getString(tag.Key)) + } + + if val != getString(tag.Value) { + return fmt.Errorf("expected the %v tag value to be %v, instead got %v", getString(tag.Key), val, getString(tag.Value)) + } + } + + return nil } - ctx, cancel = context.WithTimeout(context.Background(), shortTimeout) - _, err = s3client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{ - Bucket: &bucket, - Key: &obj, - UploadId: out.UploadId, - MultipartUpload: &types.CompletedMultipartUpload{ - Parts: compParts, - }, - }) - cancel() - if err != nil { - return err + for _, el := range []struct { + tagging string + result map[string]string + expectedErr error + }{ + // success cases + {"&", map[string]string{}, nil}, + {"&&&", map[string]string{}, nil}, + {"key", map[string]string{"key": ""}, nil}, + {"key&", map[string]string{"key": ""}, nil}, + {"key=&", map[string]string{"key": ""}, nil}, + {"key=val&", map[string]string{"key": "val"}, nil}, + {"key1&key2", map[string]string{"key1": "", "key2": ""}, nil}, + {"key1=val1&key2=val2", map[string]string{"key1": "val1", "key2": "val2"}, nil}, + // invalid url-encoded + {"=", nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging)}, + {"key%", nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging)}, + // duplicate keys + {"key=val&key=val", nil, s3err.GetAPIError(s3err.ErrInvalidURLEncodedTagging)}, + // invalid tag keys + {"key?=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key(=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key*=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key$=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key#=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key@=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + {"key!=val", nil, s3err.GetAPIError(s3err.ErrInvalidTagKey)}, + // invalid tag values + {"key=val?", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val(", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val*", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val$", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val#", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val@", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + {"key=val!", nil, s3err.GetAPIError(s3err.ErrInvalidTagValue)}, + // success special chars + {"key-key_key.key/key=value-value_value.value/value", map[string]string{"key-key_key.key/key": "value-value_value.value/value"}, nil}, + // should handle supported encoded characters + {"key%2E=value%2F", map[string]string{"key.": "value/"}, nil}, + {"key%2D=value%2B", map[string]string{"key-": "value+"}, nil}, + {"key++key=value++value", map[string]string{"key key": "value value"}, nil}, + {"key%20key=value%20value", map[string]string{"key key": "value value"}, nil}, + {"key%5Fkey=value%5Fvalue", map[string]string{"key_key": "value_value"}, nil}, + } { + err := testTagging(el.tagging, el.result, el.expectedErr) + if err != nil { + fmt.Println("failing for: ", el.tagging) + return err + } } - - ctx, cancel = context.WithTimeout(context.Background(), shortTimeout) - resp, err := s3client.GetObjectTagging(ctx, &s3.GetObjectTaggingInput{ - Bucket: &bucket, - Key: &obj, - }) - cancel() - if err != nil { - return err - } - - expectedOutput := []types.Tag{ - { - Key: getPtr("key1"), - Value: getPtr("val1"), - }, - { - Key: getPtr("key2"), - Value: getPtr("val2"), - }, - } - - if !areTagsSame(resp.TagSet, expectedOutput) { - return fmt.Errorf("expected object tagging to be %v, instead got %v", expectedOutput, resp.TagSet) - } - return nil }) }