fix: refactoring the checksum implementation by avoiding many if conditions and making the code more readable

This commit is contained in:
niksis02
2025-02-19 23:59:34 +04:00
parent ff0cf29d0a
commit 173518278e
2 changed files with 85 additions and 133 deletions

View File

@@ -1405,7 +1405,7 @@ func (p *Posix) CompleteMultipartUpload(ctx context.Context, input *s3.CompleteM
objdir := filepath.Join(metaTmpMultipartDir, fmt.Sprintf("%x", sum))
checksums, err := p.retreiveChecksums(nil, bucket, filepath.Join(objdir, uploadID))
checksums, err := p.retrieveChecksums(nil, bucket, filepath.Join(objdir, uploadID))
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
return nil, fmt.Errorf("get mp checksums: %w", err)
}
@@ -1466,7 +1466,7 @@ func (p *Posix) CompleteMultipartUpload(ctx context.Context, input *s3.CompleteM
return nil, s3err.GetAPIError(s3err.ErrInvalidPart)
}
partChecksum, err := p.retreiveChecksums(nil, bucket, partObjPath)
partChecksum, err := p.retrieveChecksums(nil, bucket, partObjPath)
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
return nil, fmt.Errorf("get part checksum: %w", err)
}
@@ -1992,7 +1992,7 @@ func (p *Posix) ListMultipartUploads(_ context.Context, mpu *s3.ListMultipartUpl
keyMarkerInd = len(uploads)
}
checksum, err := p.retreiveChecksums(nil, bucket, filepath.Join(metaTmpMultipartDir, obj.Name(), uploadID))
checksum, err := p.retrieveChecksums(nil, bucket, filepath.Join(metaTmpMultipartDir, obj.Name(), uploadID))
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
return lmu, fmt.Errorf("get mp checksum: %w", err)
}
@@ -2122,7 +2122,7 @@ func (p *Posix) ListParts(_ context.Context, input *s3.ListPartsInput) (s3respon
return lpr, fmt.Errorf("readdir upload: %w", err)
}
checksum, err := p.retreiveChecksums(nil, tmpdir, uploadID)
checksum, err := p.retrieveChecksums(nil, tmpdir, uploadID)
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
return lpr, fmt.Errorf("get mp checksum: %w", err)
}
@@ -2145,7 +2145,7 @@ func (p *Posix) ListParts(_ context.Context, input *s3.ListPartsInput) (s3respon
etag = ""
}
checksum, err := p.retreiveChecksums(nil, bucket, partPath)
checksum, err := p.retrieveChecksums(nil, bucket, partPath)
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
continue
}
@@ -2201,6 +2201,11 @@ func (p *Posix) ListParts(_ context.Context, input *s3.ListPartsInput) (s3respon
}, nil
}
type hashConfig struct {
value *string
hashType utils.HashType
}
func (p *Posix) UploadPart(ctx context.Context, input *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
acct, ok := ctx.Value("account").(auth.Account)
if !ok {
@@ -2259,46 +2264,24 @@ func (p *Posix) UploadPart(ctx context.Context, input *s3.UploadPartInput) (*s3.
hash := md5.New()
tr := io.TeeReader(r, hash)
hashConfigs := []hashConfig{
{input.ChecksumCRC32, utils.HashTypeCRC32},
{input.ChecksumCRC32C, utils.HashTypeCRC32C},
{input.ChecksumSHA1, utils.HashTypeSha1},
{input.ChecksumSHA256, utils.HashTypeSha256},
{input.ChecksumCRC64NVME, utils.HashTypeCRC64NVME},
}
var hashRdr *utils.HashReader
if input.ChecksumCRC32 != nil {
hashRdr, err = utils.NewHashReader(tr, *input.ChecksumCRC32, utils.HashTypeCRC32)
if err != nil {
return nil, fmt.Errorf("initialize hash reader: %w", err)
}
for _, config := range hashConfigs {
if config.value != nil {
hashRdr, err = utils.NewHashReader(tr, *config.value, config.hashType)
if err != nil {
return nil, fmt.Errorf("initialize hash reader: %w", err)
}
tr = hashRdr
}
if input.ChecksumCRC32C != nil {
hashRdr, err = utils.NewHashReader(tr, *input.ChecksumCRC32C, utils.HashTypeCRC32C)
if err != nil {
return nil, fmt.Errorf("initialize hash reader: %w", err)
tr = hashRdr
}
tr = hashRdr
}
if input.ChecksumSHA1 != nil {
hashRdr, err = utils.NewHashReader(tr, *input.ChecksumSHA1, utils.HashTypeSha1)
if err != nil {
return nil, fmt.Errorf("initialize hash reader: %w", err)
}
tr = hashRdr
}
if input.ChecksumSHA256 != nil {
hashRdr, err = utils.NewHashReader(tr, *input.ChecksumSHA256, utils.HashTypeSha256)
if err != nil {
return nil, fmt.Errorf("initialize hash reader: %w", err)
}
tr = hashRdr
}
if input.ChecksumCRC64NVME != nil {
hashRdr, err = utils.NewHashReader(tr, *input.ChecksumCRC64NVME, utils.HashTypeCRC64NVME)
if err != nil {
return nil, fmt.Errorf("initialize hash reader: %w", err)
}
tr = hashRdr
}
// If only the checksum algorithm is provided register
@@ -2312,7 +2295,7 @@ func (p *Posix) UploadPart(ctx context.Context, input *s3.UploadPartInput) (*s3.
tr = hashRdr
}
checksums, chErr := p.retreiveChecksums(nil, bucket, mpPath)
checksums, chErr := p.retrieveChecksums(nil, bucket, mpPath)
if chErr != nil && !errors.Is(chErr, meta.ErrNoSuchKey) {
return nil, fmt.Errorf("retreive mp checksum: %w", chErr)
}
@@ -2519,12 +2502,12 @@ func (p *Posix) UploadPartCopy(ctx context.Context, upi *s3.UploadPartCopyInput)
hash := md5.New()
tr := io.TeeReader(rdr, hash)
mpChecksums, err := p.retreiveChecksums(nil, *upi.Bucket, filepath.Join(objdir, *upi.UploadId))
mpChecksums, err := p.retrieveChecksums(nil, *upi.Bucket, filepath.Join(objdir, *upi.UploadId))
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
return s3response.CopyPartResult{}, fmt.Errorf("retreive mp checksums: %w", err)
}
checksums, err := p.retreiveChecksums(nil, objPath, "")
checksums, err := p.retrieveChecksums(nil, objPath, "")
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
return s3response.CopyPartResult{}, fmt.Errorf("retreive object part checksums: %w", err)
}
@@ -2752,46 +2735,24 @@ func (p *Posix) PutObject(ctx context.Context, po *s3.PutObjectInput) (s3respons
hash := md5.New()
rdr := io.TeeReader(po.Body, hash)
hashConfigs := []hashConfig{
{po.ChecksumCRC32, utils.HashTypeCRC32},
{po.ChecksumCRC32C, utils.HashTypeCRC32C},
{po.ChecksumSHA1, utils.HashTypeSha1},
{po.ChecksumSHA256, utils.HashTypeSha256},
{po.ChecksumCRC64NVME, utils.HashTypeCRC64NVME},
}
var hashRdr *utils.HashReader
if po.ChecksumCRC32 != nil {
hashRdr, err = utils.NewHashReader(rdr, *po.ChecksumCRC32, utils.HashTypeCRC32)
if err != nil {
return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err)
}
rdr = hashRdr
}
if po.ChecksumCRC32C != nil {
hashRdr, err = utils.NewHashReader(rdr, *po.ChecksumCRC32C, utils.HashTypeCRC32C)
if err != nil {
return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err)
}
for _, config := range hashConfigs {
if config.value != nil {
hashRdr, err = utils.NewHashReader(rdr, *config.value, config.hashType)
if err != nil {
return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err)
}
rdr = hashRdr
}
if po.ChecksumSHA1 != nil {
hashRdr, err = utils.NewHashReader(rdr, *po.ChecksumSHA1, utils.HashTypeSha1)
if err != nil {
return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err)
rdr = hashRdr
}
rdr = hashRdr
}
if po.ChecksumSHA256 != nil {
hashRdr, err = utils.NewHashReader(rdr, *po.ChecksumSHA256, utils.HashTypeSha256)
if err != nil {
return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err)
}
rdr = hashRdr
}
if po.ChecksumCRC64NVME != nil {
hashRdr, err = utils.NewHashReader(rdr, *po.ChecksumCRC64NVME, utils.HashTypeCRC64NVME)
if err != nil {
return s3response.PutObjectOutput{}, fmt.Errorf("initialize hash reader: %w", err)
}
rdr = hashRdr
}
// If only the checksum algorithm is provided register
@@ -3508,7 +3469,7 @@ func (p *Posix) GetObject(_ context.Context, input *s3.GetObjectInput) (*s3.GetO
var cType types.ChecksumType
// Skip the checksums retreival if object isn't requested fully
if input.ChecksumMode == types.ChecksumModeEnabled && length-startOffset == objSize {
checksums, err = p.retreiveChecksums(f, bucket, object)
checksums, err = p.retrieveChecksums(f, bucket, object)
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
return nil, fmt.Errorf("get object checksums: %w", err)
}
@@ -3719,7 +3680,7 @@ func (p *Posix) HeadObject(ctx context.Context, input *s3.HeadObjectInput) (*s3.
var checksums s3response.Checksum
var cType types.ChecksumType
if input.ChecksumMode == types.ChecksumModeEnabled {
checksums, err = p.retreiveChecksums(nil, bucket, object)
checksums, err = p.retrieveChecksums(nil, bucket, object)
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
return nil, fmt.Errorf("get object checksums: %w", err)
}
@@ -3905,7 +3866,7 @@ func (p *Posix) CopyObject(ctx context.Context, input *s3.CopyObjectInput) (*s3.
}
}
checksums, err := p.retreiveChecksums(nil, dstBucket, dstObject)
checksums, err := p.retrieveChecksums(nil, dstBucket, dstObject)
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
return nil, fmt.Errorf("get obj checksums: %w", err)
}
@@ -3968,7 +3929,7 @@ func (p *Posix) CopyObject(ctx context.Context, input *s3.CopyObjectInput) (*s3.
} else {
contentLength := fi.Size()
checksums, err := p.retreiveChecksums(f, srcBucket, srcObject)
checksums, err := p.retrieveChecksums(f, srcBucket, srcObject)
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
return nil, fmt.Errorf("get obj checksum: %w", err)
}
@@ -4111,7 +4072,7 @@ func (p *Posix) fileToObj(bucket string) backend.GetObjFunc {
}
// Retreive the object checksum algorithm
checksums, err := p.retreiveChecksums(nil, bucket, path)
checksums, err := p.retrieveChecksums(nil, bucket, path)
if err != nil && !errors.Is(err, meta.ErrNoSuchKey) {
return s3response.Object{}, backend.ErrSkipObj
}
@@ -4787,7 +4748,7 @@ func (p *Posix) storeChecksums(f *os.File, bucket, object string, chs s3response
return p.meta.StoreAttribute(f, bucket, object, checksumsKey, checksums)
}
func (p *Posix) retreiveChecksums(f *os.File, bucket, object string) (checksums s3response.Checksum, err error) {
func (p *Posix) retrieveChecksums(f *os.File, bucket, object string) (checksums s3response.Checksum, err error) {
checksumsAtr, err := p.meta.RetrieveAttribute(f, bucket, object, checksumsKey)
if err != nil {
return checksums, err

View File

@@ -2101,36 +2101,33 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
Value: *res.ETag,
})
}
if res.ChecksumCRC32 != nil {
switch {
case res.ChecksumCRC32 != nil:
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-crc32",
Value: *res.ChecksumCRC32,
})
}
if res.ChecksumCRC32C != nil {
case res.ChecksumCRC32C != nil:
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-crc32c",
Value: *res.ChecksumCRC32C,
})
}
if res.ChecksumSHA1 != nil {
case res.ChecksumCRC64NVME != nil:
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-crc64nvme",
Value: *res.ChecksumCRC64NVME,
})
case res.ChecksumSHA1 != nil:
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-sha1",
Value: *res.ChecksumSHA1,
})
}
if res.ChecksumSHA256 != nil {
case res.ChecksumSHA256 != nil:
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-sha256",
Value: *res.ChecksumSHA256,
})
}
if res.ChecksumCRC64NVME != nil {
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-crc64nvme",
Value: *res.ChecksumCRC64NVME,
})
}
utils.SetResponseHeaders(ctx, headers)
}
@@ -2558,34 +2555,31 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
Value: res.VersionID,
})
}
if getstring(res.ChecksumCRC32) != "" {
switch {
case res.ChecksumCRC32 != nil:
hdrs = append(hdrs, utils.CustomHeader{
Key: "x-amz-checksum-crc32",
Value: getstring(res.ChecksumCRC32),
Value: *res.ChecksumCRC32,
})
}
if getstring(res.ChecksumCRC32C) != "" {
case res.ChecksumCRC32C != nil:
hdrs = append(hdrs, utils.CustomHeader{
Key: "x-amz-checksum-crc32c",
Value: getstring(res.ChecksumCRC32C),
Value: *res.ChecksumCRC32C,
})
}
if getstring(res.ChecksumSHA1) != "" {
hdrs = append(hdrs, utils.CustomHeader{
Key: "x-amz-checksum-sha1",
Value: getstring(res.ChecksumSHA1),
})
}
if getstring(res.ChecksumSHA256) != "" {
hdrs = append(hdrs, utils.CustomHeader{
Key: "x-amz-checksum-sha256",
Value: getstring(res.ChecksumSHA256),
})
}
if getstring(res.ChecksumCRC64NVME) != "" {
case res.ChecksumCRC64NVME != nil:
hdrs = append(hdrs, utils.CustomHeader{
Key: "x-amz-checksum-crc64nvme",
Value: getstring(res.ChecksumCRC64NVME),
Value: *res.ChecksumCRC64NVME,
})
case res.ChecksumSHA1 != nil:
hdrs = append(hdrs, utils.CustomHeader{
Key: "x-amz-checksum-sha1",
Value: *res.ChecksumSHA1,
})
case res.ChecksumSHA256 != nil:
hdrs = append(hdrs, utils.CustomHeader{
Key: "x-amz-checksum-sha256",
Value: *res.ChecksumSHA256,
})
}
if res.ChecksumType != "" {
@@ -3214,36 +3208,33 @@ func (c S3ApiController) HeadObject(ctx *fiber.Ctx) error {
Value: string(res.StorageClass),
})
}
if res.ChecksumCRC32 != nil {
switch {
case res.ChecksumCRC32 != nil:
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-crc32",
Value: *res.ChecksumCRC32,
})
}
if res.ChecksumCRC32C != nil {
case res.ChecksumCRC32C != nil:
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-crc32c",
Value: *res.ChecksumCRC32C,
})
}
if res.ChecksumSHA1 != nil {
case res.ChecksumCRC64NVME != nil:
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-crc64nvme",
Value: *res.ChecksumCRC64NVME,
})
case res.ChecksumSHA1 != nil:
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-sha1",
Value: *res.ChecksumSHA1,
})
}
if res.ChecksumSHA256 != nil {
case res.ChecksumSHA256 != nil:
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-sha256",
Value: *res.ChecksumSHA256,
})
}
if res.ChecksumCRC64NVME != nil {
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-crc64nvme",
Value: *res.ChecksumCRC64NVME,
})
}
if res.ChecksumType != "" {
headers = append(headers, utils.CustomHeader{
Key: "x-amz-checksum-type",