diff --git a/s3api/utils/signed-chunk-reader.go b/s3api/utils/signed-chunk-reader.go index fd00b78..0610504 100644 --- a/s3api/utils/signed-chunk-reader.go +++ b/s3api/utils/signed-chunk-reader.go @@ -30,6 +30,7 @@ import ( "strings" "time" + "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/versity/versitygw/s3err" ) @@ -190,7 +191,8 @@ func (cr *ChunkReader) verifyChecksum() error { checksumHash := cr.checksumHash.Sum(nil) checksum := base64.StdEncoding.EncodeToString(checksumHash) if checksum != cr.parsedChecksum { - return fmt.Errorf("actual checksum: %v, expected checksum: %v", checksum, cr.parsedChecksum) + algo := types.ChecksumAlgorithm(strings.ToUpper(strings.TrimPrefix(string(cr.trailer), "x-amz-checksum-"))) + return s3err.GetChecksumBadDigestErr(algo) } return nil @@ -380,12 +382,18 @@ func (cr *ChunkReader) parseChunkHeaderBytes(header []byte, l *int) (int64, stri return 0, "", 0, errInvalidChunkFormat } + algo := types.ChecksumAlgorithm(strings.ToUpper(strings.TrimPrefix(trailer, "x-amz-checksum-"))) + // parse the checksum checksum, err := readAndTrim(rdr, '\r') if err != nil { return cr.handleRdrErr(err, header) } + if !IsValidChecksum(checksum, algo) { + return 0, "", 0, s3err.GetInvalidTrailingChecksumHeaderErr(trailer) + } + err = readAndSkip(rdr, '\n') if err != nil { return cr.handleRdrErr(err, header) diff --git a/s3err/s3err.go b/s3err/s3err.go index 68a1081..1663644 100644 --- a/s3err/s3err.go +++ b/s3err/s3err.go @@ -770,6 +770,14 @@ func GetInvalidChecksumHeaderErr(header string) APIError { } } +func GetInvalidTrailingChecksumHeaderErr(header string) APIError { + return APIError{ + Code: "InvalidRequest", + Description: fmt.Sprintf("Value for %v trailing header is invalid.", header), + HTTPStatusCode: http.StatusBadRequest, + } +} + // Returns checksum type mismatch APIError func GetChecksumTypeMismatchErr(expected, actual types.ChecksumAlgorithm) APIError { return APIError{ @@ -783,7 +791,7 @@ func GetChecksumTypeMismatchErr(expected, actual types.ChecksumAlgorithm) APIErr func GetChecksumBadDigestErr(algo types.ChecksumAlgorithm) APIError { return APIError{ Code: "BadDigest", - Description: fmt.Sprintf("The %v you specified did not match the calculated checksum.", strings.ToLower(string(algo))), + Description: fmt.Sprintf("The %v you specified did not match the calculated checksum.", algo), HTTPStatusCode: http.StatusBadRequest, } }