diff --git a/backend/posix/posix.go b/backend/posix/posix.go index f7c0b37..115ec67 100644 --- a/backend/posix/posix.go +++ b/backend/posix/posix.go @@ -1746,7 +1746,7 @@ func (p *Posix) CompleteMultipartUploadWithCopy(ctx context.Context, input *s3.C var sum string switch checksums.Type { case types.ChecksumTypeComposite: - sum = compositeChecksumRdr.Sum() + sum = fmt.Sprintf("%s-%v", compositeChecksumRdr.Sum(), len(parts)) case types.ChecksumTypeFullObject: if !composableCRC { sum = hashRdr.Sum() @@ -1755,38 +1755,45 @@ func (p *Posix) CompleteMultipartUploadWithCopy(ctx context.Context, input *s3.C } } + var gotSum *string + switch checksumAlgorithm { case types.ChecksumAlgorithmCrc32: - if input.ChecksumCRC32 != nil && *input.ChecksumCRC32 != sum { - return res, "", s3err.GetChecksumBadDigestErr(checksumAlgorithm) - } + gotSum = input.ChecksumCRC32 checksum.CRC32 = &sum crc32 = &sum case types.ChecksumAlgorithmCrc32c: - if input.ChecksumCRC32C != nil && *input.ChecksumCRC32C != sum { - return res, "", s3err.GetChecksumBadDigestErr(checksumAlgorithm) - } + gotSum = input.ChecksumCRC32C checksum.CRC32C = &sum crc32c = &sum case types.ChecksumAlgorithmSha1: - if input.ChecksumSHA1 != nil && *input.ChecksumSHA1 != sum { - return res, "", s3err.GetChecksumBadDigestErr(checksumAlgorithm) - } + gotSum = input.ChecksumSHA1 checksum.SHA1 = &sum sha1 = &sum case types.ChecksumAlgorithmSha256: - if input.ChecksumSHA256 != nil && *input.ChecksumSHA256 != sum { - return res, "", s3err.GetChecksumBadDigestErr(checksumAlgorithm) - } + gotSum = input.ChecksumSHA256 checksum.SHA256 = &sum sha256 = &sum case types.ChecksumAlgorithmCrc64nvme: - if input.ChecksumCRC64NVME != nil && *input.ChecksumCRC64NVME != sum { - return res, "", s3err.GetChecksumBadDigestErr(checksumAlgorithm) - } + gotSum = input.ChecksumCRC64NVME checksum.CRC64NVME = &sum crc64nvme = &sum } + + // Check if the provided checksum and the calculated one are the same + if gotSum != nil { + s := *gotSum + if checksums.Type == types.ChecksumTypeComposite && !strings.Contains(s, "-") { + // if number of parts is not specified in the final checksum + // make sure to add, to not fail in the final comparison + s = fmt.Sprintf("%s-%v", s, len(parts)) + } + + if s != sum { + return res, "", s3err.GetChecksumBadDigestErr(checksumAlgorithm) + } + } + err := p.storeChecksums(f.File(), bucket, object, checksum) if err != nil { return res, "", fmt.Errorf("store object checksum: %w", err) diff --git a/s3api/controllers/object-post.go b/s3api/controllers/object-post.go index cdb6801..be4e363 100644 --- a/s3api/controllers/object-post.go +++ b/s3api/controllers/object-post.go @@ -305,7 +305,7 @@ func (c S3ApiController) CompleteMultipartUpload(ctx *fiber.Ctx) (*Response, err mpuObjectSize = &val } - checksums, err := utils.ParseChecksumHeaders(ctx) + checksums, err := utils.ParseCompleteMpChecksumHeaders(ctx) if err != nil { return &Response{ MetaOpts: &MetaOptions{ diff --git a/s3api/utils/utils.go b/s3api/utils/utils.go index 9b2072d..dba6da8 100644 --- a/s3api/utils/utils.go +++ b/s3api/utils/utils.go @@ -461,6 +461,41 @@ func ParseCalculatedChecksumHeaders(ctx *fiber.Ctx) (ChecksumValues, error) { return checksums, nil } +// ParseCompleteMpChecksumHeaders parses and validates +// the 'CompleteMultipartUpload' x-amz-checksum-x headers +// by supporting both 'checksum' and 'checksum-' formats +func ParseCompleteMpChecksumHeaders(ctx *fiber.Ctx) (ChecksumValues, error) { + // first parse/validate 'x-amz-checksum-x' headers + checksums, err := ParseCalculatedChecksumHeaders(ctx) + if err != nil { + return checksums, err + } + + for al, val := range checksums { + algo := strings.ToLower(string(al)) + if al != types.ChecksumAlgorithmCrc64nvme { + chParts := strings.Split(val, "-") + if len(chParts) > 2 { + debuglogger.Logf("invalid checksum header: x-amz-checksum-%s: %s", algo, val) + return checksums, s3err.GetInvalidChecksumHeaderErr(fmt.Sprintf("x-amz-checksum-%v", algo)) + } + if len(chParts) == 2 { + _, err := strconv.ParseInt(chParts[1], 10, 32) + if err != nil { + debuglogger.Logf("invalid checksum header: x-amz-checksum-%s: %s", algo, val) + return checksums, s3err.GetInvalidChecksumHeaderErr(fmt.Sprintf("x-amz-checksum-%v", algo)) + } + val = chParts[0] + } + } + if !IsValidChecksum(val, al) { + return checksums, s3err.GetInvalidChecksumHeaderErr(fmt.Sprintf("x-amz-checksum-%v", algo)) + } + } + + return checksums, nil +} + // ParseChecksumHeaders parses/validates x-amz-checksum-x headers key/values func ParseChecksumHeaders(ctx *fiber.Ctx) (ChecksumValues, error) { // first parse/validate 'x-amz-checksum-x' headers diff --git a/tests/integration/group-tests.go b/tests/integration/group-tests.go index afde517..54e08f5 100644 --- a/tests/integration/group-tests.go +++ b/tests/integration/group-tests.go @@ -480,6 +480,8 @@ func TestCompleteMultipartUpload(ts *TestState) { ts.Run(CompleteMultipartUpload_incorrect_final_checksums) ts.Run(CompleteMultipartUpload_should_calculate_the_final_checksum_full_object) ts.Run(CompleteMultipartUpload_should_verify_the_final_checksum) + ts.Run(CompleteMultipartUpload_should_verify_final_composite_checksum) + ts.Run(CompleteMultipartUpload_invalid_final_composite_checksum) ts.Run(CompleteMultipartUpload_checksum_type_mismatch) ts.Run(CompleteMultipartUpload_should_ignore_the_final_checksum) ts.Run(CompleteMultipartUpload_should_succeed_without_final_checksum_type) @@ -1374,6 +1376,8 @@ func GetIntTests() IntTests { "CompleteMultipartUpload_incorrect_final_checksums": CompleteMultipartUpload_incorrect_final_checksums, "CompleteMultipartUpload_should_calculate_the_final_checksum_full_object": CompleteMultipartUpload_should_calculate_the_final_checksum_full_object, "CompleteMultipartUpload_should_verify_the_final_checksum": CompleteMultipartUpload_should_verify_the_final_checksum, + "CompleteMultipartUpload_should_verify_final_composite_checksum": CompleteMultipartUpload_should_verify_final_composite_checksum, + "CompleteMultipartUpload_invalid_final_composite_checksum": CompleteMultipartUpload_invalid_final_composite_checksum, "CompleteMultipartUpload_checksum_type_mismatch": CompleteMultipartUpload_checksum_type_mismatch, "CompleteMultipartUpload_should_ignore_the_final_checksum": CompleteMultipartUpload_should_ignore_the_final_checksum, "CompleteMultipartUpload_should_succeed_without_final_checksum_type": CompleteMultipartUpload_should_succeed_without_final_checksum_type, diff --git a/tests/integration/tests.go b/tests/integration/tests.go index 2054375..bf9f040 100644 --- a/tests/integration/tests.go +++ b/tests/integration/tests.go @@ -12895,6 +12895,175 @@ func CompleteMultipartUpload_should_verify_the_final_checksum(s *S3Conf) error { }) } +func CompleteMultipartUpload_should_verify_final_composite_checksum(s *S3Conf) error { + testName := "CompleteMultipartUpload_should_verify_final_composite_checksum" + return actionHandler(s, testName, func(s3client *s3.Client, bucket string) error { + obj := "my-obj" + for i, algo := range []types.ChecksumAlgorithm{ + types.ChecksumAlgorithmCrc32, + types.ChecksumAlgorithmCrc32c, + types.ChecksumAlgorithmSha1, + types.ChecksumAlgorithmSha256, + } { + mp, err := createMp(s3client, bucket, obj, withChecksumType(types.ChecksumTypeComposite), withChecksum(algo)) + if err != nil { + return fmt.Errorf("test %v failed: %s", i, err) + } + + parts, _, err := uploadParts(s3client, 25*1024*1024, 5, bucket, obj, *mp.UploadId, withChecksum(algo)) + if err != nil { + return fmt.Errorf("test %v failed: %s", i, err) + } + + hasher, err := NewHasher(algo) + if err != nil { + return fmt.Errorf("test %v failed: %s", i, err) + } + + completeParts := make([]types.CompletedPart, 0, len(parts)) + + for _, part := range parts { + switch algo { + case types.ChecksumAlgorithmCrc32: + err = processCompositeChecksum(hasher, getString(part.ChecksumCRC32)) + case types.ChecksumAlgorithmCrc32c: + err = processCompositeChecksum(hasher, getString(part.ChecksumCRC32C)) + case types.ChecksumAlgorithmSha1: + err = processCompositeChecksum(hasher, getString(part.ChecksumSHA1)) + case types.ChecksumAlgorithmSha256: + err = processCompositeChecksum(hasher, getString(part.ChecksumSHA256)) + } + + if err != nil { + return fmt.Errorf("test %v failed: %s", i, err) + } + + completeParts = append(completeParts, types.CompletedPart{ + ETag: part.ETag, + PartNumber: part.PartNumber, + ChecksumCRC32: part.ChecksumCRC32, + ChecksumCRC32C: part.ChecksumCRC32C, + ChecksumSHA1: part.ChecksumSHA1, + ChecksumSHA256: part.ChecksumSHA256, + }) + } + + checksum := fmt.Sprintf("%s-%v", base64.StdEncoding.EncodeToString(hasher.Sum(nil)), len(parts)) + + completeMpInput := &s3.CompleteMultipartUploadInput{ + Bucket: &bucket, + Key: &obj, + MultipartUpload: &types.CompletedMultipartUpload{ + Parts: completeParts, + }, + UploadId: mp.UploadId, + } + + switch algo { + case types.ChecksumAlgorithmCrc32: + completeMpInput.ChecksumCRC32 = &checksum + case types.ChecksumAlgorithmCrc32c: + completeMpInput.ChecksumCRC32C = &checksum + case types.ChecksumAlgorithmSha1: + completeMpInput.ChecksumSHA1 = &checksum + case types.ChecksumAlgorithmSha256: + completeMpInput.ChecksumSHA256 = &checksum + } + + ctx, cancel := context.WithTimeout(context.Background(), shortTimeout) + res, err := s3client.CompleteMultipartUpload(ctx, completeMpInput) + cancel() + if err != nil { + return fmt.Errorf("test %v failed: %s", i, err) + } + + var gotSum string + switch algo { + case types.ChecksumAlgorithmCrc32: + gotSum = getString(res.ChecksumCRC32) + case types.ChecksumAlgorithmCrc32c: + gotSum = getString(res.ChecksumCRC32C) + case types.ChecksumAlgorithmSha1: + gotSum = getString(res.ChecksumSHA1) + case types.ChecksumAlgorithmSha256: + gotSum = getString(res.ChecksumSHA256) + } + + if gotSum != checksum { + return fmt.Errorf("test %v failed: expected the final checksum to be %s, instead got %s", i, checksum, gotSum) + } + } + + return nil + }) +} + +func CompleteMultipartUpload_invalid_final_composite_checksum(s *S3Conf) error { + testName := "CompleteMultipartUpload_invalid_final_composite_checksum" + return actionHandler(s, testName, func(s3client *s3.Client, bucket string) error { + obj := "my-obj" + for i, test := range []struct { + algo types.ChecksumAlgorithm + crc32 *string + crc32c *string + sha1 *string + sha256 *string + }{ + {types.ChecksumAlgorithmCrc32, getPtr("invalid_checksum"), nil, nil, nil}, + {types.ChecksumAlgorithmCrc32, getPtr("ImIEBA==-smth"), nil, nil, nil}, + {types.ChecksumAlgorithmCrc32c, nil, getPtr("invalid_checksum"), nil, nil}, + {types.ChecksumAlgorithmCrc32c, nil, getPtr("AQIDBA==-12a"), nil, nil}, + {types.ChecksumAlgorithmSha1, nil, nil, getPtr("invalid_checksum"), nil}, + {types.ChecksumAlgorithmSha1, nil, nil, getPtr("2jmj7l5rSw0yVb/vlWAYkK/YBwk=-10-20"), nil}, + {types.ChecksumAlgorithmSha256, nil, nil, nil, getPtr("invalid_checksum")}, + {types.ChecksumAlgorithmSha256, nil, nil, nil, getPtr("47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=--3")}, + } { + mp, err := createMp(s3client, bucket, obj, withChecksum(test.algo), withChecksumType(types.ChecksumTypeComposite)) + if err != nil { + return fmt.Errorf("test %v failed: %w", i, err) + } + + parts, _, err := uploadParts(s3client, 5*1024*1024, 1, bucket, obj, *mp.UploadId, withChecksum(test.algo)) + if err != nil { + return fmt.Errorf("test %v failed: %w", i, err) + } + + completeParts := make([]types.CompletedPart, 0, len(parts)) + + for _, part := range parts { + completeParts = append(completeParts, types.CompletedPart{ + ETag: part.ETag, + PartNumber: part.PartNumber, + ChecksumCRC32: part.ChecksumCRC32, + ChecksumCRC32C: part.ChecksumCRC32C, + ChecksumSHA1: part.ChecksumSHA1, + ChecksumSHA256: part.ChecksumSHA256, + }) + } + + ctx, cancel := context.WithTimeout(context.Background(), shortTimeout) + _, err = s3client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{ + Bucket: &bucket, + Key: &obj, + UploadId: mp.UploadId, + MultipartUpload: &types.CompletedMultipartUpload{ + Parts: completeParts, + }, + ChecksumCRC32: test.crc32, + ChecksumCRC32C: test.crc32c, + ChecksumSHA1: test.sha1, + ChecksumSHA256: test.sha256, + }) + cancel() + if err := checkApiErr(err, s3err.GetInvalidChecksumHeaderErr(fmt.Sprintf("x-amz-checksum-%v", strings.ToLower(string(test.algo))))); err != nil { + return fmt.Errorf("test %v failed: %w", i, err) + } + } + + return nil + }) +} + func CompleteMultipartUpload_checksum_type_mismatch(s *S3Conf) error { testName := "CompleteMultipartUpload_checksum_type_mismatch" return actionHandler(s, testName, func(s3client *s3.Client, bucket string) error { diff --git a/tests/integration/utils.go b/tests/integration/utils.go index 0010c09..807c7c7 100644 --- a/tests/integration/utils.go +++ b/tests/integration/utils.go @@ -1957,3 +1957,35 @@ func lockObject(client *s3.Client, mode objectLockMode, bucket, object, versionI }) return err } + +func NewHasher(algo types.ChecksumAlgorithm) (hash.Hash, error) { + var hasher hash.Hash + switch algo { + case types.ChecksumAlgorithmSha256: + hasher = sha256.New() + case types.ChecksumAlgorithmSha1: + hasher = sha1.New() + case types.ChecksumAlgorithmCrc32: + hasher = crc32.NewIEEE() + case types.ChecksumAlgorithmCrc32c: + hasher = crc32.New(crc32.MakeTable(crc32.Castagnoli)) + default: + return nil, fmt.Errorf("unsupported hash algorithm: %s", algo) + } + + return hasher, nil +} + +func processCompositeChecksum(hasher hash.Hash, checksum string) error { + data, err := base64.StdEncoding.DecodeString(checksum) + if err != nil { + return fmt.Errorf("base64 decode: %w", err) + } + + _, err = hasher.Write(data) + if err != nil { + return fmt.Errorf("hash write: %w", err) + } + + return nil +}