From 173518278e5d4f5f09014834edc10da67a1b0dbd Mon Sep 17 00:00:00 2001 From: niksis02 Date: Wed, 19 Feb 2025 23:59:34 +0400 Subject: [PATCH] fix: refactoring the checksum implementation by avoiding many if conditions and making the code more readable --- backend/posix/posix.go | 135 ++++++++++++++------------------------ s3api/controllers/base.go | 83 +++++++++++------------ 2 files changed, 85 insertions(+), 133 deletions(-) diff --git a/backend/posix/posix.go b/backend/posix/posix.go index 5b3ac88..6e4d2d4 100644 --- a/backend/posix/posix.go +++ b/backend/posix/posix.go @@ -1405,7 +1405,7 @@ func (p *Posix) CompleteMultipartUpload(ctx context.Context, input *s3.CompleteM objdir := filepath.Join(metaTmpMultipartDir, fmt.Sprintf("%x", sum)) - checksums, err := p.retreiveChecksums(nil, bucket, filepath.Join(objdir, uploadID)) + checksums, err := p.retrieveChecksums(nil, bucket, filepath.Join(objdir, uploadID)) if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { return nil, fmt.Errorf("get mp checksums: %w", err) } @@ -1466,7 +1466,7 @@ func (p *Posix) CompleteMultipartUpload(ctx context.Context, input *s3.CompleteM return nil, s3err.GetAPIError(s3err.ErrInvalidPart) } - partChecksum, err := p.retreiveChecksums(nil, bucket, partObjPath) + partChecksum, err := p.retrieveChecksums(nil, bucket, partObjPath) if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { return nil, fmt.Errorf("get part checksum: %w", err) } @@ -1992,7 +1992,7 @@ func (p *Posix) ListMultipartUploads(_ context.Context, mpu *s3.ListMultipartUpl keyMarkerInd = len(uploads) } - checksum, err := p.retreiveChecksums(nil, bucket, filepath.Join(metaTmpMultipartDir, obj.Name(), uploadID)) + checksum, err := p.retrieveChecksums(nil, bucket, filepath.Join(metaTmpMultipartDir, obj.Name(), uploadID)) if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { return lmu, fmt.Errorf("get mp checksum: %w", err) } @@ -2122,7 +2122,7 @@ func (p *Posix) ListParts(_ context.Context, input *s3.ListPartsInput) (s3respon return lpr, fmt.Errorf("readdir upload: %w", err) } - checksum, err := p.retreiveChecksums(nil, tmpdir, uploadID) + checksum, err := p.retrieveChecksums(nil, tmpdir, uploadID) if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { return lpr, fmt.Errorf("get mp checksum: %w", err) } @@ -2145,7 +2145,7 @@ func (p *Posix) ListParts(_ context.Context, input *s3.ListPartsInput) (s3respon etag = "" } - checksum, err := p.retreiveChecksums(nil, bucket, partPath) + checksum, err := p.retrieveChecksums(nil, bucket, partPath) if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { continue } @@ -2201,6 +2201,11 @@ func (p *Posix) ListParts(_ context.Context, input *s3.ListPartsInput) (s3respon }, nil } +type hashConfig struct { + value *string + hashType utils.HashType +} + func (p *Posix) UploadPart(ctx context.Context, input *s3.UploadPartInput) (*s3.UploadPartOutput, error) { acct, ok := ctx.Value("account").(auth.Account) if !ok { @@ -2259,46 +2264,24 @@ func (p *Posix) UploadPart(ctx context.Context, input *s3.UploadPartInput) (*s3. hash := md5.New() tr := io.TeeReader(r, hash) + hashConfigs := []hashConfig{ + {input.ChecksumCRC32, utils.HashTypeCRC32}, + {input.ChecksumCRC32C, utils.HashTypeCRC32C}, + {input.ChecksumSHA1, utils.HashTypeSha1}, + {input.ChecksumSHA256, utils.HashTypeSha256}, + {input.ChecksumCRC64NVME, utils.HashTypeCRC64NVME}, + } + var hashRdr *utils.HashReader - if input.ChecksumCRC32 != nil { - hashRdr, err = utils.NewHashReader(tr, *input.ChecksumCRC32, utils.HashTypeCRC32) - if err != nil { - return nil, fmt.Errorf("initialize hash reader: %w", err) - } + for _, config := range hashConfigs { + if config.value != nil { + hashRdr, err = utils.NewHashReader(tr, *config.value, config.hashType) + if err != nil { + return nil, fmt.Errorf("initialize hash reader: %w", err) + } - tr = hashRdr - } - if input.ChecksumCRC32C != nil { - hashRdr, err = utils.NewHashReader(tr, *input.ChecksumCRC32C, utils.HashTypeCRC32C) - if err != nil { - return nil, fmt.Errorf("initialize hash reader: %w", err) + tr = hashRdr } - - tr = hashRdr - } - if input.ChecksumSHA1 != nil { - hashRdr, err = utils.NewHashReader(tr, *input.ChecksumSHA1, utils.HashTypeSha1) - if err != nil { - return nil, fmt.Errorf("initialize hash reader: %w", err) - } - - tr = hashRdr - } - if input.ChecksumSHA256 != nil { - hashRdr, err = utils.NewHashReader(tr, *input.ChecksumSHA256, utils.HashTypeSha256) - if err != nil { - return nil, fmt.Errorf("initialize hash reader: %w", err) - } - - tr = hashRdr - } - if input.ChecksumCRC64NVME != nil { - hashRdr, err = utils.NewHashReader(tr, *input.ChecksumCRC64NVME, utils.HashTypeCRC64NVME) - if err != nil { - return nil, fmt.Errorf("initialize hash reader: %w", err) - } - - tr = hashRdr } // If only the checksum algorithm is provided register @@ -2312,7 +2295,7 @@ func (p *Posix) UploadPart(ctx context.Context, input *s3.UploadPartInput) (*s3. tr = hashRdr } - checksums, chErr := p.retreiveChecksums(nil, bucket, mpPath) + checksums, chErr := p.retrieveChecksums(nil, bucket, mpPath) if chErr != nil && !errors.Is(chErr, meta.ErrNoSuchKey) { return nil, fmt.Errorf("retreive mp checksum: %w", chErr) } @@ -2519,12 +2502,12 @@ func (p *Posix) UploadPartCopy(ctx context.Context, upi *s3.UploadPartCopyInput) hash := md5.New() tr := io.TeeReader(rdr, hash) - mpChecksums, err := p.retreiveChecksums(nil, *upi.Bucket, filepath.Join(objdir, *upi.UploadId)) + mpChecksums, err := p.retrieveChecksums(nil, *upi.Bucket, filepath.Join(objdir, *upi.UploadId)) if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { return s3response.CopyPartResult{}, fmt.Errorf("retreive mp checksums: %w", err) } - checksums, err := p.retreiveChecksums(nil, objPath, "") + checksums, err := p.retrieveChecksums(nil, objPath, "") if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { return s3response.CopyPartResult{}, fmt.Errorf("retreive object part checksums: %w", err) } @@ -2752,46 +2735,24 @@ func (p *Posix) PutObject(ctx context.Context, po *s3.PutObjectInput) (s3respons hash := md5.New() rdr := io.TeeReader(po.Body, hash) + hashConfigs := []hashConfig{ + {po.ChecksumCRC32, utils.HashTypeCRC32}, + {po.ChecksumCRC32C, utils.HashTypeCRC32C}, + {po.ChecksumSHA1, utils.HashTypeSha1}, + {po.ChecksumSHA256, utils.HashTypeSha256}, + {po.ChecksumCRC64NVME, utils.HashTypeCRC64NVME}, + } var hashRdr *utils.HashReader - if po.ChecksumCRC32 != nil { - hashRdr, err = utils.NewHashReader(rdr, *po.ChecksumCRC32, utils.HashTypeCRC32) - if err != nil { - return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err) - } - rdr = hashRdr - } - if po.ChecksumCRC32C != nil { - hashRdr, err = utils.NewHashReader(rdr, *po.ChecksumCRC32C, utils.HashTypeCRC32C) - if err != nil { - return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err) - } + for _, config := range hashConfigs { + if config.value != nil { + hashRdr, err = utils.NewHashReader(rdr, *config.value, config.hashType) + if err != nil { + return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err) + } - rdr = hashRdr - } - if po.ChecksumSHA1 != nil { - hashRdr, err = utils.NewHashReader(rdr, *po.ChecksumSHA1, utils.HashTypeSha1) - if err != nil { - return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err) + rdr = hashRdr } - - rdr = hashRdr - } - if po.ChecksumSHA256 != nil { - hashRdr, err = utils.NewHashReader(rdr, *po.ChecksumSHA256, utils.HashTypeSha256) - if err != nil { - return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err) - } - - rdr = hashRdr - } - if po.ChecksumCRC64NVME != nil { - hashRdr, err = utils.NewHashReader(rdr, *po.ChecksumCRC64NVME, utils.HashTypeCRC64NVME) - if err != nil { - return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err) - } - - rdr = hashRdr } // If only the checksum algorithm is provided register @@ -3508,7 +3469,7 @@ func (p *Posix) GetObject(_ context.Context, input *s3.GetObjectInput) (*s3.GetO var cType types.ChecksumType // Skip the checksums retreival if object isn't requested fully if input.ChecksumMode == types.ChecksumModeEnabled && length-startOffset == objSize { - checksums, err = p.retreiveChecksums(f, bucket, object) + checksums, err = p.retrieveChecksums(f, bucket, object) if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { return nil, fmt.Errorf("get object checksums: %w", err) } @@ -3719,7 +3680,7 @@ func (p *Posix) HeadObject(ctx context.Context, input *s3.HeadObjectInput) (*s3. var checksums s3response.Checksum var cType types.ChecksumType if input.ChecksumMode == types.ChecksumModeEnabled { - checksums, err = p.retreiveChecksums(nil, bucket, object) + checksums, err = p.retrieveChecksums(nil, bucket, object) if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { return nil, fmt.Errorf("get object checksums: %w", err) } @@ -3905,7 +3866,7 @@ func (p *Posix) CopyObject(ctx context.Context, input *s3.CopyObjectInput) (*s3. } } - checksums, err := p.retreiveChecksums(nil, dstBucket, dstObject) + checksums, err := p.retrieveChecksums(nil, dstBucket, dstObject) if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { return nil, fmt.Errorf("get obj checksums: %w", err) } @@ -3968,7 +3929,7 @@ func (p *Posix) CopyObject(ctx context.Context, input *s3.CopyObjectInput) (*s3. } else { contentLength := fi.Size() - checksums, err := p.retreiveChecksums(f, srcBucket, srcObject) + checksums, err := p.retrieveChecksums(f, srcBucket, srcObject) if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { return nil, fmt.Errorf("get obj checksum: %w", err) } @@ -4111,7 +4072,7 @@ func (p *Posix) fileToObj(bucket string) backend.GetObjFunc { } // Retreive the object checksum algorithm - checksums, err := p.retreiveChecksums(nil, bucket, path) + checksums, err := p.retrieveChecksums(nil, bucket, path) if err != nil && !errors.Is(err, meta.ErrNoSuchKey) { return s3response.Object{}, backend.ErrSkipObj } @@ -4787,7 +4748,7 @@ func (p *Posix) storeChecksums(f *os.File, bucket, object string, chs s3response return p.meta.StoreAttribute(f, bucket, object, checksumsKey, checksums) } -func (p *Posix) retreiveChecksums(f *os.File, bucket, object string) (checksums s3response.Checksum, err error) { +func (p *Posix) retrieveChecksums(f *os.File, bucket, object string) (checksums s3response.Checksum, err error) { checksumsAtr, err := p.meta.RetrieveAttribute(f, bucket, object, checksumsKey) if err != nil { return checksums, err diff --git a/s3api/controllers/base.go b/s3api/controllers/base.go index 9a496fb..556be04 100644 --- a/s3api/controllers/base.go +++ b/s3api/controllers/base.go @@ -2101,36 +2101,33 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { Value: *res.ETag, }) } - if res.ChecksumCRC32 != nil { + switch { + case res.ChecksumCRC32 != nil: headers = append(headers, utils.CustomHeader{ Key: "x-amz-checksum-crc32", Value: *res.ChecksumCRC32, }) - } - if res.ChecksumCRC32C != nil { + case res.ChecksumCRC32C != nil: headers = append(headers, utils.CustomHeader{ Key: "x-amz-checksum-crc32c", Value: *res.ChecksumCRC32C, }) - } - if res.ChecksumSHA1 != nil { + case res.ChecksumCRC64NVME != nil: + headers = append(headers, utils.CustomHeader{ + Key: "x-amz-checksum-crc64nvme", + Value: *res.ChecksumCRC64NVME, + }) + case res.ChecksumSHA1 != nil: headers = append(headers, utils.CustomHeader{ Key: "x-amz-checksum-sha1", Value: *res.ChecksumSHA1, }) - } - if res.ChecksumSHA256 != nil { + case res.ChecksumSHA256 != nil: headers = append(headers, utils.CustomHeader{ Key: "x-amz-checksum-sha256", Value: *res.ChecksumSHA256, }) } - if res.ChecksumCRC64NVME != nil { - headers = append(headers, utils.CustomHeader{ - Key: "x-amz-checksum-crc64nvme", - Value: *res.ChecksumCRC64NVME, - }) - } utils.SetResponseHeaders(ctx, headers) } @@ -2558,34 +2555,31 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { Value: res.VersionID, }) } - if getstring(res.ChecksumCRC32) != "" { + switch { + case res.ChecksumCRC32 != nil: hdrs = append(hdrs, utils.CustomHeader{ Key: "x-amz-checksum-crc32", - Value: getstring(res.ChecksumCRC32), + Value: *res.ChecksumCRC32, }) - } - if getstring(res.ChecksumCRC32C) != "" { + case res.ChecksumCRC32C != nil: hdrs = append(hdrs, utils.CustomHeader{ Key: "x-amz-checksum-crc32c", - Value: getstring(res.ChecksumCRC32C), + Value: *res.ChecksumCRC32C, }) - } - if getstring(res.ChecksumSHA1) != "" { - hdrs = append(hdrs, utils.CustomHeader{ - Key: "x-amz-checksum-sha1", - Value: getstring(res.ChecksumSHA1), - }) - } - if getstring(res.ChecksumSHA256) != "" { - hdrs = append(hdrs, utils.CustomHeader{ - Key: "x-amz-checksum-sha256", - Value: getstring(res.ChecksumSHA256), - }) - } - if getstring(res.ChecksumCRC64NVME) != "" { + case res.ChecksumCRC64NVME != nil: hdrs = append(hdrs, utils.CustomHeader{ Key: "x-amz-checksum-crc64nvme", - Value: getstring(res.ChecksumCRC64NVME), + Value: *res.ChecksumCRC64NVME, + }) + case res.ChecksumSHA1 != nil: + hdrs = append(hdrs, utils.CustomHeader{ + Key: "x-amz-checksum-sha1", + Value: *res.ChecksumSHA1, + }) + case res.ChecksumSHA256 != nil: + hdrs = append(hdrs, utils.CustomHeader{ + Key: "x-amz-checksum-sha256", + Value: *res.ChecksumSHA256, }) } if res.ChecksumType != "" { @@ -3214,36 +3208,33 @@ func (c S3ApiController) HeadObject(ctx *fiber.Ctx) error { Value: string(res.StorageClass), }) } - if res.ChecksumCRC32 != nil { + switch { + case res.ChecksumCRC32 != nil: headers = append(headers, utils.CustomHeader{ Key: "x-amz-checksum-crc32", Value: *res.ChecksumCRC32, }) - } - if res.ChecksumCRC32C != nil { + case res.ChecksumCRC32C != nil: headers = append(headers, utils.CustomHeader{ Key: "x-amz-checksum-crc32c", Value: *res.ChecksumCRC32C, }) - } - if res.ChecksumSHA1 != nil { + case res.ChecksumCRC64NVME != nil: + headers = append(headers, utils.CustomHeader{ + Key: "x-amz-checksum-crc64nvme", + Value: *res.ChecksumCRC64NVME, + }) + case res.ChecksumSHA1 != nil: headers = append(headers, utils.CustomHeader{ Key: "x-amz-checksum-sha1", Value: *res.ChecksumSHA1, }) - } - if res.ChecksumSHA256 != nil { + case res.ChecksumSHA256 != nil: headers = append(headers, utils.CustomHeader{ Key: "x-amz-checksum-sha256", Value: *res.ChecksumSHA256, }) } - if res.ChecksumCRC64NVME != nil { - headers = append(headers, utils.CustomHeader{ - Key: "x-amz-checksum-crc64nvme", - Value: *res.ChecksumCRC64NVME, - }) - } if res.ChecksumType != "" { headers = append(headers, utils.CustomHeader{ Key: "x-amz-checksum-type",