From da3c6211bd7f66b15d1827bec5c66cdfb854172f Mon Sep 17 00:00:00 2001 From: niksis02 Date: Thu, 30 Jan 2025 23:41:03 +0400 Subject: [PATCH] feat: Implements streaming unsigned payload reader with trailers --- s3api/middlewares/chunk.go | 2 +- s3api/utils/auth-reader.go | 16 -- s3api/utils/chunk-reader.go | 303 ++++++--------------------- s3api/utils/signed-chunk-reader.go | 273 ++++++++++++++++++++++++ s3api/utils/unsigned-chunk-reader.go | 235 +++++++++++++++++++++ 5 files changed, 573 insertions(+), 256 deletions(-) create mode 100644 s3api/utils/signed-chunk-reader.go create mode 100644 s3api/utils/unsigned-chunk-reader.go diff --git a/s3api/middlewares/chunk.go b/s3api/middlewares/chunk.go index 24b3e5f..0f81f4c 100644 --- a/s3api/middlewares/chunk.go +++ b/s3api/middlewares/chunk.go @@ -47,7 +47,7 @@ func ProcessChunkedBody(root RootUserConfig, iam auth.IAMService, logger s3log.A if utils.IsBigDataAction(ctx) { var err error wrapBodyReader(ctx, func(r io.Reader) io.Reader { - var cr *utils.ChunkReader + var cr io.Reader cr, err = utils.NewChunkReader(ctx, r, authData, region, acct.Secret, date) return cr }) diff --git a/s3api/utils/auth-reader.go b/s3api/utils/auth-reader.go index 1694c16..2bbdce3 100644 --- a/s3api/utils/auth-reader.go +++ b/s3api/utils/auth-reader.go @@ -260,19 +260,3 @@ func removeSpace(str string) string { } return b.String() } - -var ( - specialValues = map[string]bool{ - "UNSIGNED-PAYLOAD": true, - "STREAMING-UNSIGNED-PAYLOAD-TRAILER": true, - "STREAMING-AWS4-HMAC-SHA256-PAYLOAD": true, - "STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER": true, - "STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD": true, - "STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD-TRAILER": true, - } -) - -// IsSpecialPayload checks for streaming/unsigned authorization types -func IsSpecialPayload(str string) bool { - return specialValues[str] -} diff --git a/s3api/utils/chunk-reader.go b/s3api/utils/chunk-reader.go index 918d143..3bacaf3 100644 --- a/s3api/utils/chunk-reader.go +++ b/s3api/utils/chunk-reader.go @@ -15,260 +15,85 @@ package utils import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "encoding/hex" - "errors" "fmt" - "hash" "io" - "math" - "strconv" "time" "github.com/gofiber/fiber/v2" - "github.com/versity/versitygw/s3err" ) -// chunked uploads described in: -// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html +type payloadType string const ( - chunkHdrStr = ";chunk-signature=" - chunkHdrDelim = "\r\n" - zeroLenSig = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - awsV4 = "AWS4" - awsS3Service = "s3" - awsV4Request = "aws4_request" - streamPayloadAlgo = "AWS4-HMAC-SHA256-PAYLOAD" + payloadTypeUnsigned payloadType = "UNSIGNED-PAYLOAD" + payloadTypeStreamingUnsignedTrailer payloadType = "STREAMING-UNSIGNED-PAYLOAD-TRAILER" + payloadTypeStreamingSigned payloadType = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD" + payloadTypeStreamingSignedTrailer payloadType = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER" + payloadTypeStreamingEcdsa payloadType = "STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD" + payloadTypeStreamingEcdsaTrailer payloadType = "STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD-TRAILER" ) -// ChunkReader reads from chunked upload request body, and returns -// object data stream -type ChunkReader struct { - r io.Reader - signingKey []byte - prevSig string - parsedSig string - currentChunkSize int64 - chunkDataLeft int64 - trailerExpected int - stash []byte - chunkHash hash.Hash - strToSignPrefix string - skipcheck bool -} - -// NewChunkReader reads from request 
body io.Reader and parses out the -// chunk metadata in stream. The headers are validated for proper signatures. -// Reading from the chunk reader will read only the object data stream -// without the chunk headers/trailers. -func NewChunkReader(ctx *fiber.Ctx, r io.Reader, authdata AuthData, region, secret string, date time.Time) (*ChunkReader, error) { - return &ChunkReader{ - r: r, - signingKey: getSigningKey(secret, region, date), - // the authdata.Signature is validated in the auth-reader, - // so we can use that here without any other checks - prevSig: authdata.Signature, - chunkHash: sha256.New(), - strToSignPrefix: getStringToSignPrefix(date, region), - }, nil -} - -// Read satisfies the io.Reader for this type -func (cr *ChunkReader) Read(p []byte) (int, error) { - n, err := cr.r.Read(p) - if err != nil && err != io.EOF { - return n, err - } - - if cr.chunkDataLeft < int64(n) { - chunkSize := cr.chunkDataLeft - if chunkSize > 0 { - cr.chunkHash.Write(p[:chunkSize]) - } - n, err := cr.parseAndRemoveChunkInfo(p[chunkSize:n]) - n += int(chunkSize) - return n, err - } - - cr.chunkDataLeft -= int64(n) - cr.chunkHash.Write(p[:n]) - return n, err -} - -// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html#sigv4-chunked-body-definition -// This part is the same for all chunks, -// only the previous signature and hash of current chunk changes -func getStringToSignPrefix(date time.Time, region string) string { - credentialScope := fmt.Sprintf("%s/%s/%s/%s", - date.Format("20060102"), - region, - awsS3Service, - awsV4Request) - - return fmt.Sprintf("%s\n%s\n%s", - streamPayloadAlgo, - date.Format("20060102T150405Z"), - credentialScope) -} - -// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html#sigv4-chunked-body-definition -// signature For each chunk, you calculate the signature using the following -// string to sign. For the first chunk, you use the seed-signature as the -// previous signature. -func getChunkStringToSign(prefix, prevSig string, chunkHash []byte) string { - return fmt.Sprintf("%s\n%s\n%s\n%s", - prefix, - prevSig, - zeroLenSig, - hex.EncodeToString(chunkHash)) -} - -// The provided p should have all of the previous chunk data and trailer -// consumed already. The positioning here is expected that p[0] starts the -// new chunk size with the ";chunk-signature=" following. The only exception -// is if we started consuming the trailer, but hit the end of the read buffer. -// In this case, parseAndRemoveChunkInfo is called with skipcheck=true to -// finish consuming the final trailer bytes. -// This parses the chunk metadata in situ without allocating an extra buffer. -// It will just read and validate the chunk metadata and then move the -// following chunk data to overwrite the metadata in the provided buffer. -func (cr *ChunkReader) parseAndRemoveChunkInfo(p []byte) (int, error) { - n := len(p) - - if !cr.skipcheck && cr.parsedSig != "" { - chunkhash := cr.chunkHash.Sum(nil) - cr.chunkHash.Reset() - - sigstr := getChunkStringToSign(cr.strToSignPrefix, cr.prevSig, chunkhash) - cr.prevSig = hex.EncodeToString(hmac256(cr.signingKey, []byte(sigstr))) - - if cr.currentChunkSize != 0 && cr.prevSig != cr.parsedSig { - return 0, s3err.GetAPIError(s3err.ErrSignatureDoesNotMatch) - } - } - - if cr.trailerExpected != 0 { - if len(p) < len(chunkHdrDelim) { - // This is the special case where we need to consume the - // trailer, but instead hit the end of the buffer. The - // subsequent call will finish consuming the trailer. 
- cr.chunkDataLeft = 0 - cr.trailerExpected -= len(p) - cr.skipcheck = true - return 0, nil - } - // move data up to remove trailer - copy(p, p[cr.trailerExpected:]) - n -= cr.trailerExpected - } - - cr.skipcheck = false - - chunkSize, sig, bufOffset, err := cr.parseChunkHeaderBytes(p[:n]) - cr.currentChunkSize = chunkSize - cr.parsedSig = sig - if err == errskipHeader { - cr.chunkDataLeft = 0 - return 0, nil - } - if err != nil { - return 0, err - } - if chunkSize == 0 { - return 0, io.EOF - } - - cr.trailerExpected = len(chunkHdrDelim) - - // move data up to remove chunk header - copy(p, p[bufOffset:n]) - n -= bufOffset - - // if remaining buffer larger than chunk data, - // parse next header in buffer - if int64(n) > chunkSize { - cr.chunkDataLeft = 0 - cr.chunkHash.Write(p[:chunkSize]) - n, err := cr.parseAndRemoveChunkInfo(p[chunkSize:n]) - if (chunkSize + int64(n)) > math.MaxInt { - return 0, s3err.GetAPIError(s3err.ErrSignatureDoesNotMatch) - } - return n + int(chunkSize), err - } - - cr.chunkDataLeft = chunkSize - int64(n) - cr.chunkHash.Write(p[:n]) - - return n, nil -} - -// https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html -// Task 3: Calculate Signature -// https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-authenticating-requests.html#signing-request-intro -func getSigningKey(secret, region string, date time.Time) []byte { - dateKey := hmac256([]byte(awsV4+secret), []byte(date.Format(yyyymmdd))) - dateRegionKey := hmac256(dateKey, []byte(region)) - dateRegionServiceKey := hmac256(dateRegionKey, []byte(awsS3Service)) - signingKey := hmac256(dateRegionServiceKey, []byte(awsV4Request)) - return signingKey -} - -func hmac256(key []byte, data []byte) []byte { - hash := hmac.New(sha256.New, key) - hash.Write(data) - return hash.Sum(nil) -} - var ( - errInvalidChunkFormat = errors.New("invalid chunk header format") - errskipHeader = errors.New("skip to next header") + specialValues = map[payloadType]bool{ + payloadTypeUnsigned: true, + payloadTypeStreamingUnsignedTrailer: true, + payloadTypeStreamingSigned: true, + payloadTypeStreamingSignedTrailer: true, + payloadTypeStreamingEcdsa: true, + payloadTypeStreamingEcdsaTrailer: true, + } ) +func (pt payloadType) isValid() bool { + return pt == payloadTypeUnsigned || + pt == payloadTypeStreamingUnsignedTrailer || + pt == payloadTypeStreamingSigned || + pt == payloadTypeStreamingSignedTrailer || + pt == payloadTypeStreamingEcdsa || + pt == payloadTypeStreamingEcdsaTrailer +} + +type checksumType string + const ( - maxHeaderSize = 1024 + checksumTypeCrc32 checksumType = "x-amz-checksum-crc32" + checksumTypeCrc32c checksumType = "x-amz-checksum-crc32c" + checksumTypeSha1 checksumType = "x-amz-checksum-sha1" + checksumTypeSha256 checksumType = "x-amz-checksum-sha256" + checksumTypeCrc64nvme checksumType = "x-amz-checksum-crc64nvme" ) -// Theis returns the chunk payload size, signature, data start offset, and -// error if any. See the AWS documentation for the chunk header format. The -// header[0] byte is expected to be the first byte of the chunk size here. 
-func (cr *ChunkReader) parseChunkHeaderBytes(header []byte) (int64, string, int, error) {
-	stashLen := len(cr.stash)
-	if cr.stash != nil {
-		tmp := make([]byte, maxHeaderSize)
-		copy(tmp, cr.stash)
-		copy(tmp[len(cr.stash):], header)
-		header = tmp
-		cr.stash = nil
-	}
-
-	semicolonIndex := bytes.Index(header, []byte(chunkHdrStr))
-	if semicolonIndex == -1 {
-		cr.stash = make([]byte, len(header))
-		copy(cr.stash, header)
-		cr.trailerExpected = 0
-		return 0, "", 0, errskipHeader
-	}
-
-	sigIndex := semicolonIndex + len(chunkHdrStr)
-	sigEndIndex := bytes.Index(header[sigIndex:], []byte(chunkHdrDelim))
-	if sigEndIndex == -1 {
-		cr.stash = make([]byte, len(header))
-		copy(cr.stash, header)
-		cr.trailerExpected = 0
-		return 0, "", 0, errskipHeader
-	}
-
-	chunkSizeBytes := header[:semicolonIndex]
-	chunkSize, err := strconv.ParseInt(string(chunkSizeBytes), 16, 64)
-	if err != nil {
-		return 0, "", 0, errInvalidChunkFormat
-	}
-
-	signature := string(header[sigIndex:(sigIndex + sigEndIndex)])
-	dataStartOffset := sigIndex + sigEndIndex + len(chunkHdrDelim)
-
-	return chunkSize, signature, dataStartOffset - stashLen, nil
+func (c checksumType) isValid() bool {
+	return c == checksumTypeCrc32 ||
+		c == checksumTypeCrc32c ||
+		c == checksumTypeSha1 ||
+		c == checksumTypeSha256 ||
+		c == checksumTypeCrc64nvme
+}
+
+// IsSpecialPayload checks for streaming/unsigned authorization types
+func IsSpecialPayload(str string) bool {
+	return specialValues[payloadType(str)]
+}
+
+// NewChunkReader inspects the X-Amz-Content-Sha256 and X-Amz-Trailer
+// headers and returns the streaming body reader matching the declared
+// payload type.
+func NewChunkReader(ctx *fiber.Ctx, r io.Reader, authdata AuthData, region, secret string, date time.Time) (io.Reader, error) {
+	contentSha256 := payloadType(ctx.Get("X-Amz-Content-Sha256"))
+	if !contentSha256.isValid() {
+		//TODO: Add proper APIError
+		return nil, fmt.Errorf("invalid x-amz-content-sha256: %v", string(contentSha256))
+	}
+
+	trailer := checksumType(ctx.Get("X-Amz-Trailer"))
+	if trailer != "" && !trailer.isValid() {
+		//TODO: Add proper APIError
+		return nil, fmt.Errorf("invalid X-Amz-Trailer: %v", trailer)
+	}
+
+	switch contentSha256 {
+	case payloadTypeStreamingUnsignedTrailer:
+		return NewUnsignedChunkReader(r, trailer)
+		//TODO: Add other chunk readers
+	}
+
+	return NewSignedChunkReader(r, authdata, region, secret, date)
}
diff --git a/s3api/utils/signed-chunk-reader.go b/s3api/utils/signed-chunk-reader.go
new file mode 100644
index 0000000..2b2c6c3
--- /dev/null
+++ b/s3api/utils/signed-chunk-reader.go
@@ -0,0 +1,273 @@
+// Copyright 2024 Versity Software
+// This file is licensed under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
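For orientation before the reader implementations: a signed streaming upload frames each chunk as `<hex-size>;chunk-signature=<sig>\r\n`, followed by the chunk data and a `\r\n`, ending with a zero-size chunk, per the sigv4-streaming docs linked in the code. A minimal sketch of that framing, with made-up signature values; `chunkHeader` is a hypothetical helper for illustration only, not part of this patch:

```go
package main

import "fmt"

// chunkHeader renders the frame header the signed reader parses:
// "<hex-size>;chunk-signature=<sig>\r\n" precedes each chunk's data,
// and a zero-size chunk terminates the body.
func chunkHeader(size int64, sig string) string {
	return fmt.Sprintf("%x;chunk-signature=%s\r\n", size, sig)
}

func main() {
	fmt.Printf("%q\n", chunkHeader(65536, "ad80c7...")) // "10000;chunk-signature=ad80c7...\r\n"
	fmt.Printf("%q\n", chunkHeader(0, "b6c6ea..."))     // terminal zero-size chunk
}
```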
+
+package utils
+
+import (
+	"bytes"
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"hash"
+	"io"
+	"math"
+	"strconv"
+	"time"
+
+	"github.com/versity/versitygw/s3err"
+)
+
+// chunked uploads described in:
+// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html
+
+const (
+	chunkHdrStr       = ";chunk-signature="
+	chunkHdrDelim     = "\r\n"
+	zeroLenSig        = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
+	awsV4             = "AWS4"
+	awsS3Service      = "s3"
+	awsV4Request      = "aws4_request"
+	streamPayloadAlgo = "AWS4-HMAC-SHA256-PAYLOAD"
+)
+
+// ChunkReader reads from chunked upload request body, and returns
+// object data stream
+type ChunkReader struct {
+	r                io.Reader
+	signingKey       []byte
+	prevSig          string
+	parsedSig        string
+	currentChunkSize int64
+	chunkDataLeft    int64
+	trailerExpected  int
+	stash            []byte
+	chunkHash        hash.Hash
+	strToSignPrefix  string
+	skipcheck        bool
+}
+
+// NewSignedChunkReader reads from request body io.Reader and parses out the
+// chunk metadata in stream. The headers are validated for proper signatures.
+// Reading from the chunk reader will read only the object data stream
+// without the chunk headers/trailers.
+func NewSignedChunkReader(r io.Reader, authdata AuthData, region, secret string, date time.Time) (io.Reader, error) {
+	return &ChunkReader{
+		r:          r,
+		signingKey: getSigningKey(secret, region, date),
+		// the authdata.Signature is validated in the auth-reader,
+		// so we can use that here without any other checks
+		prevSig:         authdata.Signature,
+		chunkHash:       sha256.New(),
+		strToSignPrefix: getStringToSignPrefix(date, region),
+	}, nil
+}
+
+// Read satisfies the io.Reader for this type
+func (cr *ChunkReader) Read(p []byte) (int, error) {
+	n, err := cr.r.Read(p)
+	if err != nil && err != io.EOF {
+		return n, err
+	}
+
+	if cr.chunkDataLeft < int64(n) {
+		chunkSize := cr.chunkDataLeft
+		if chunkSize > 0 {
+			cr.chunkHash.Write(p[:chunkSize])
+		}
+		n, err := cr.parseAndRemoveChunkInfo(p[chunkSize:n])
+		n += int(chunkSize)
+		return n, err
+	}
+
+	cr.chunkDataLeft -= int64(n)
+	cr.chunkHash.Write(p[:n])
+	return n, err
+}
+
+// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html#sigv4-chunked-body-definition
+// This part is the same for all chunks,
+// only the previous signature and hash of current chunk changes
+func getStringToSignPrefix(date time.Time, region string) string {
+	credentialScope := fmt.Sprintf("%s/%s/%s/%s",
+		date.Format("20060102"),
+		region,
+		awsS3Service,
+		awsV4Request)
+
+	return fmt.Sprintf("%s\n%s\n%s",
+		streamPayloadAlgo,
+		date.Format("20060102T150405Z"),
+		credentialScope)
+}
+
+// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html#sigv4-chunked-body-definition
+// signature For each chunk, you calculate the signature using the following
+// string to sign. For the first chunk, you use the seed-signature as the
+// previous signature.
+func getChunkStringToSign(prefix, prevSig string, chunkHash []byte) string {
+	return fmt.Sprintf("%s\n%s\n%s\n%s",
+		prefix,
+		prevSig,
+		zeroLenSig,
+		hex.EncodeToString(chunkHash))
+}
+
+// The provided p should have all of the previous chunk data and trailer
+// consumed already. p[0] is expected to be the first byte of the new chunk
+// size, with the ";chunk-signature=" following. The only exception
+// is if we started consuming the trailer, but hit the end of the read buffer.
+// In this case, parseAndRemoveChunkInfo is called with skipcheck=true to
+// finish consuming the final trailer bytes.
+// This parses the chunk metadata in situ without allocating an extra buffer.
+// It will just read and validate the chunk metadata and then move the
+// following chunk data to overwrite the metadata in the provided buffer.
+func (cr *ChunkReader) parseAndRemoveChunkInfo(p []byte) (int, error) {
+	n := len(p)
+
+	if !cr.skipcheck && cr.parsedSig != "" {
+		chunkhash := cr.chunkHash.Sum(nil)
+		cr.chunkHash.Reset()
+
+		sigstr := getChunkStringToSign(cr.strToSignPrefix, cr.prevSig, chunkhash)
+		cr.prevSig = hex.EncodeToString(hmac256(cr.signingKey, []byte(sigstr)))
+
+		if cr.currentChunkSize != 0 && cr.prevSig != cr.parsedSig {
+			return 0, s3err.GetAPIError(s3err.ErrSignatureDoesNotMatch)
+		}
+	}
+
+	if cr.trailerExpected != 0 {
+		if len(p) < len(chunkHdrDelim) {
+			// This is the special case where we need to consume the
+			// trailer, but instead hit the end of the buffer. The
+			// subsequent call will finish consuming the trailer.
+			cr.chunkDataLeft = 0
+			cr.trailerExpected -= len(p)
+			cr.skipcheck = true
+			return 0, nil
+		}
+		// move data up to remove trailer
+		copy(p, p[cr.trailerExpected:])
+		n -= cr.trailerExpected
+	}
+
+	cr.skipcheck = false
+
+	chunkSize, sig, bufOffset, err := cr.parseChunkHeaderBytes(p[:n])
+	cr.currentChunkSize = chunkSize
+	cr.parsedSig = sig
+	if err == errskipHeader {
+		cr.chunkDataLeft = 0
+		return 0, nil
+	}
+	if err != nil {
+		return 0, err
+	}
+	if chunkSize == 0 {
+		return 0, io.EOF
+	}
+
+	cr.trailerExpected = len(chunkHdrDelim)
+
+	// move data up to remove chunk header
+	copy(p, p[bufOffset:n])
+	n -= bufOffset
+
+	// if remaining buffer larger than chunk data,
+	// parse next header in buffer
+	if int64(n) > chunkSize {
+		cr.chunkDataLeft = 0
+		cr.chunkHash.Write(p[:chunkSize])
+		n, err := cr.parseAndRemoveChunkInfo(p[chunkSize:n])
+		if (chunkSize + int64(n)) > math.MaxInt {
+			return 0, s3err.GetAPIError(s3err.ErrSignatureDoesNotMatch)
+		}
+		return n + int(chunkSize), err
+	}
+
+	cr.chunkDataLeft = chunkSize - int64(n)
+	cr.chunkHash.Write(p[:n])
+
+	return n, nil
+}
+
+// https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html
+// Task 3: Calculate Signature
+// https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-authenticating-requests.html#signing-request-intro
+func getSigningKey(secret, region string, date time.Time) []byte {
+	dateKey := hmac256([]byte(awsV4+secret), []byte(date.Format(yyyymmdd)))
+	dateRegionKey := hmac256(dateKey, []byte(region))
+	dateRegionServiceKey := hmac256(dateRegionKey, []byte(awsS3Service))
+	signingKey := hmac256(dateRegionServiceKey, []byte(awsV4Request))
+	return signingKey
+}
+
+func hmac256(key []byte, data []byte) []byte {
+	hash := hmac.New(sha256.New, key)
+	hash.Write(data)
+	return hash.Sum(nil)
+}
+
+var (
+	errInvalidChunkFormat = errors.New("invalid chunk header format")
+	errskipHeader         = errors.New("skip to next header")
+)
+
+const (
+	maxHeaderSize = 1024
+)
+
+// This returns the chunk payload size, signature, data start offset, and
+// error if any. See the AWS documentation for the chunk header format. The
+// header[0] byte is expected to be the first byte of the chunk size here.
+func (cr *ChunkReader) parseChunkHeaderBytes(header []byte) (int64, string, int, error) { + stashLen := len(cr.stash) + if cr.stash != nil { + tmp := make([]byte, maxHeaderSize) + copy(tmp, cr.stash) + copy(tmp[len(cr.stash):], header) + header = tmp + cr.stash = nil + } + + semicolonIndex := bytes.Index(header, []byte(chunkHdrStr)) + if semicolonIndex == -1 { + cr.stash = make([]byte, len(header)) + copy(cr.stash, header) + cr.trailerExpected = 0 + return 0, "", 0, errskipHeader + } + + sigIndex := semicolonIndex + len(chunkHdrStr) + sigEndIndex := bytes.Index(header[sigIndex:], []byte(chunkHdrDelim)) + if sigEndIndex == -1 { + cr.stash = make([]byte, len(header)) + copy(cr.stash, header) + cr.trailerExpected = 0 + return 0, "", 0, errskipHeader + } + + chunkSizeBytes := header[:semicolonIndex] + chunkSize, err := strconv.ParseInt(string(chunkSizeBytes), 16, 64) + if err != nil { + return 0, "", 0, errInvalidChunkFormat + } + + signature := string(header[sigIndex:(sigIndex + sigEndIndex)]) + dataStartOffset := sigIndex + sigEndIndex + len(chunkHdrDelim) + + return chunkSize, signature, dataStartOffset - stashLen, nil +} diff --git a/s3api/utils/unsigned-chunk-reader.go b/s3api/utils/unsigned-chunk-reader.go new file mode 100644 index 0000000..9648d4b --- /dev/null +++ b/s3api/utils/unsigned-chunk-reader.go @@ -0,0 +1,235 @@ +// Copyright 2024 Versity Software +// This file is licensed under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
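Before the unsigned reader: the signature chaining that `parseAndRemoveChunkInfo` verifies above can be condensed into a few lines. A minimal sketch, assuming it sits in the same `utils` package so it can reuse `getChunkStringToSign` and `hmac256` from the file above; `chainChunkSigs` itself is hypothetical and only for illustration:

```go
package utils

import (
	"crypto/sha256"
	"encoding/hex"
)

// chainChunkSigs shows how per-chunk signatures chain: the seed signature
// from the Authorization header signs chunk 1, chunk 1's signature signs
// chunk 2, and so on, which is what the reader verifies chunk by chunk.
func chainChunkSigs(signingKey []byte, prefix, seedSig string, chunks [][]byte) []string {
	sigs := make([]string, 0, len(chunks))
	prev := seedSig
	for _, data := range chunks {
		sum := sha256.Sum256(data)
		strToSign := getChunkStringToSign(prefix, prev, sum[:])
		prev = hex.EncodeToString(hmac256(signingKey, []byte(strToSign)))
		sigs = append(sigs, prev)
	}
	return sigs
}
```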
+
+package utils
+
+import (
+	"bufio"
+	"bytes"
+	"crypto/sha1"
+	"crypto/sha256"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"hash"
+	"hash/crc32"
+	"hash/crc64"
+	"io"
+	"math/bits"
+	"strconv"
+	"strings"
+)
+
+var (
+	trailerDelim         = []byte{'\n', '\r', '\n'}
+	errMalformedEncoding = errors.New("malformed chunk encoding")
+)
+
+type UnsignedChunkReader struct {
+	reader           *bufio.Reader
+	checksumType     checksumType
+	expectedChecksum string
+	hasher           hash.Hash
+	stash            []byte
+	chunkCounter     int
+	offset           int
+}
+
+func NewUnsignedChunkReader(r io.Reader, ct checksumType) (*UnsignedChunkReader, error) {
+	hasher, err := getHasher(ct)
+	if err != nil {
+		return nil, err
+	}
+	return &UnsignedChunkReader{
+		reader:       bufio.NewReader(r),
+		checksumType: ct,
+		stash:        make([]byte, 0),
+		hasher:       hasher,
+		chunkCounter: 1,
+	}, nil
+}
+
+func (ucr *UnsignedChunkReader) Read(p []byte) (int, error) {
+	// First drain any data stashed by the previous call
+	if len(ucr.stash) != 0 {
+		n := copy(p, ucr.stash)
+		ucr.stash = ucr.stash[n:]
+		ucr.offset += n
+
+		if len(ucr.stash) != 0 {
+			// p filled up; the rest stays stashed for the next call
+			ucr.offset = 0
+			return n, nil
+		}
+	}
+
+	for {
+		// Read the chunk size
+		chunkSize, err := ucr.extractChunkSize()
+		if err != nil {
+			return 0, err
+		}
+
+		if chunkSize == 0 {
+			// Stop parsing payloads once the zero-sized chunk is reached
+			break
+		}
+		rdr := io.TeeReader(ucr.reader, ucr.hasher)
+		payload := make([]byte, chunkSize)
+		// Read the payload in full, teeing it through the checksum hasher
+		_, err = io.ReadFull(rdr, payload)
+		if err != nil {
+			return 0, err
+		}
+
+		// Skip the trailing "\r\n"
+		if err := ucr.readAndSkip('\r', '\n'); err != nil {
+			return 0, err
+		}
+
+		// Copy the payload into the io.Reader buffer
+		n := copy(p[ucr.offset:], payload)
+		ucr.offset += n
+		ucr.chunkCounter++
+
+		if int64(n) < chunkSize {
+			// stash the remaining data
+			ucr.stash = payload[n:]
+			dataRead := ucr.offset
+			ucr.offset = 0
+			return dataRead, nil
+		}
+	}
+
+	// Read and validate trailers
+	if err := ucr.readTrailer(); err != nil {
+		return 0, err
+	}
+
+	dataRead := ucr.offset
+	ucr.offset = 0
+	return dataRead, io.EOF
+}
+
+// Reads and validates the bytes provided from the underlying io.Reader
+func (ucr *UnsignedChunkReader) readAndSkip(data ...byte) error {
+	for _, d := range data {
+		b, err := ucr.reader.ReadByte()
+		if err != nil {
+			if err == io.EOF {
+				return io.ErrUnexpectedEOF
+			}
+			return err
+		}
+
+		if b != d {
+			return errMalformedEncoding
+		}
+	}
+
+	return nil
+}
+
+// Extracts the chunk size from the payload
+func (ucr *UnsignedChunkReader) extractChunkSize() (int64, error) {
+	line, err := ucr.reader.ReadString('\n')
+	if err != nil {
+		return 0, errMalformedEncoding
+	}
+	line = strings.TrimSpace(line)
+
+	chunkSize, err := strconv.ParseInt(line, 16, 64)
+	if err != nil {
+		return 0, errMalformedEncoding
+	}
+
+	return chunkSize, nil
+}
+
+// Reads and validates the trailer at the end
+func (ucr *UnsignedChunkReader) readTrailer() error {
+	var trailerBuffer bytes.Buffer
+
+	for {
+		v, err := ucr.reader.ReadByte()
+		if err != nil {
+			if err == io.EOF {
+				return io.ErrUnexpectedEOF
+			}
+			return err
+		}
+		if v != '\r' {
+			trailerBuffer.WriteByte(v)
+			continue
+		}
+		var tmp [3]byte
+		_, err = io.ReadFull(ucr.reader, tmp[:])
+		if err != nil {
+			if err == io.EOF {
+				return io.ErrUnexpectedEOF
+			}
+			return err
+		}
+		if !bytes.Equal(tmp[:], trailerDelim) {
+			return errMalformedEncoding
+		}
+		break
+	}
+
+	// Parse the trailer
+	trailerHeader := trailerBuffer.String()
+	trailerHeader = strings.TrimSpace(trailerHeader)
+	trailerHeaderParts := strings.Split(trailerHeader, ":")
+	if len(trailerHeaderParts) != 2 {
+		return errMalformedEncoding
+	}
+
+	if trailerHeaderParts[0] != string(ucr.checksumType) {
+		//TODO: handle the error
+		return errMalformedEncoding
+	}
+
+	ucr.expectedChecksum = trailerHeaderParts[1]
+
+	// Validate checksum
+	return ucr.validateChecksum()
+}
+
+// Validates the trailing checksum sent at the end
+func (ucr *UnsignedChunkReader) validateChecksum() error {
+	csum := ucr.hasher.Sum(nil)
+	checksum := base64.StdEncoding.EncodeToString(csum)
+
+	if checksum != ucr.expectedChecksum {
+		return fmt.Errorf("actual checksum: %v, expected checksum: %v", checksum, ucr.expectedChecksum)
+	}
+
+	return nil
+}
+
+// Returns the hash calculator based on the checksum type provided
+func getHasher(ct checksumType) (hash.Hash, error) {
+	switch ct {
+	case checksumTypeCrc32:
+		return crc32.NewIEEE(), nil
+	case checksumTypeCrc32c:
+		return crc32.New(crc32.MakeTable(crc32.Castagnoli)), nil
+	case checksumTypeCrc64nvme:
+		// CRC-64/NVME polynomial in normal form; hash/crc64 expects the
+		// reversed form, hence bits.Reverse64
+		table := crc64.MakeTable(bits.Reverse64(0xad93d23594c935a9))
+		return crc64.New(table), nil
+	case checksumTypeSha1:
+		return sha1.New(), nil
+	case checksumTypeSha256:
+		return sha256.New(), nil
+	default:
+		return nil, errors.New("unsupported checksum type")
+	}
+}
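Finally, an end-to-end sanity sketch for the unsigned reader. The body below is a hand-built STREAMING-UNSIGNED-PAYLOAD-TRAILER stream: one 5-byte chunk, the zero-size terminator, then a CRC32 trailer, where `NhCmhg==` is the base64-encoded CRC32 of "hello". It assumes placement in a `_test.go` file inside the `utils` package, since `checksumTypeCrc32` is unexported:

```go
package utils

import (
	"fmt"
	"io"
	"strings"
)

func ExampleNewUnsignedChunkReader() {
	body := "5\r\nhello\r\n" + // one chunk: hex size 5, then 5 data bytes
		"0\r\n" + // zero-size chunk terminates the payload
		"x-amz-checksum-crc32:NhCmhg==\r\n\r\n" // trailer: CRC32("hello")

	r, err := NewUnsignedChunkReader(strings.NewReader(body), checksumTypeCrc32)
	if err != nil {
		panic(err)
	}
	data, err := io.ReadAll(r) // io.ReadAll treats the final io.EOF as success
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s\n", data)
	// Output: hello
}
```

Reading returns only the object bytes; the size lines, chunk delimiters, and trailer are consumed and validated internally, with the checksum checked against the hash accumulated by the tee during payload reads.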