Merge pull request #116 from versity/ben/fix_range

fix range gets with unspecified end range
This commit is contained in:
Ben McClelland
2023-06-29 09:29:06 -07:00
committed by GitHub
5 changed files with 57 additions and 25 deletions

View File

@@ -17,7 +17,6 @@ package backend
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"io/fs"
"strconv"
@@ -25,6 +24,7 @@ import (
"time"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/versity/versitygw/s3err"
"github.com/versity/versitygw/s3response"
)
@@ -55,6 +55,12 @@ func GetTimePtr(t time.Time) *time.Time {
return &t
}
var (
errInvalidRange = s3err.GetAPIError(s3err.ErrInvalidRequest)
)
// ParseRange parses input range header and returns startoffset, length, and
// error. If no endoffset specified, then length is set to -1.
func ParseRange(file fs.FileInfo, acceptRange string) (int64, int64, error) {
if acceptRange == "" {
return 0, file.Size(), nil
@@ -63,29 +69,34 @@ func ParseRange(file fs.FileInfo, acceptRange string) (int64, int64, error) {
rangeKv := strings.Split(acceptRange, "=")
if len(rangeKv) < 2 {
return 0, 0, errors.New("invalid range parameter")
return 0, 0, errInvalidRange
}
bRange := strings.Split(rangeKv[1], "-")
if len(bRange) < 2 {
return 0, 0, errors.New("invalid range parameter")
if len(bRange) < 1 || len(bRange) > 2 {
return 0, 0, errInvalidRange
}
startOffset, err := strconv.ParseInt(bRange[0], 10, 64)
if err != nil {
return 0, 0, errors.New("invalid range parameter")
return 0, 0, errInvalidRange
}
endOffset, err := strconv.ParseInt(bRange[1], 10, 64)
endOffset := int64(-1)
if len(bRange) == 1 || bRange[1] == "" {
return startOffset, endOffset, nil
}
endOffset, err = strconv.ParseInt(bRange[1], 10, 64)
if err != nil {
return 0, 0, errors.New("invalid range parameter")
return 0, 0, errInvalidRange
}
if endOffset < startOffset {
return 0, 0, errors.New("invalid range parameter")
return 0, 0, errInvalidRange
}
return int64(startOffset), int64(endOffset - startOffset + 1), nil
return startOffset, endOffset - startOffset + 1, nil
}
func GetMultipartMD5(parts []types.Part) string {

View File

@@ -891,8 +891,11 @@ func (p *Posix) GetObject(bucket, object, acceptRange string, writer io.Writer)
return nil, err
}
if startOffset+length > fi.Size() {
// TODO: is ErrInvalidRequest correct here?
if length == -1 {
length = fi.Size() - startOffset + 1
}
if startOffset+length > fi.Size()+1 {
return nil, s3err.GetAPIError(s3err.ErrInvalidRequest)
}

View File

@@ -440,8 +440,11 @@ func (s *ScoutFS) GetObject(bucket, object, acceptRange string, writer io.Writer
return nil, err
}
if length == -1 {
length = fi.Size() - startOffset + 1
}
if startOffset+length > fi.Size() {
// TODO: is ErrInvalidRequest correct here?
return nil, s3err.GetAPIError(s3err.ErrInvalidRequest)
}

View File

@@ -923,7 +923,34 @@ func TestRangeGet(s *S3Conf) {
}
// bytes range is inclusive, go range for second value is not
if !isSame(b, data[100:201]) {
if !isEqual(b, data[100:201]) {
failF("%v: data mismatch of range", testname)
return
}
rangeString = "bytes=100-"
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
out, err = s3client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &bucket,
Key: &name,
Range: &rangeString,
})
defer cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
defer out.Body.Close()
b, err = io.ReadAll(out.Body)
if err != nil {
failF("%v: read body %v", testname, err)
return
}
// bytes range is inclusive, go range for second value is not
if !isEqual(b, data[100:]) {
failF("%v: data mismatch of range", testname)
return
}

View File

@@ -111,18 +111,6 @@ func containsPart(part int32, list []types.Part) bool {
return false
}
func isSame(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i, x := range a {
if x != b[i] {
return false
}
}
return true
}
// Checks if the slices contain the same objects, if the objects doesn't
// contain map, slice, channel.
func areTagsSame(tags1, tags2 []types.Tag) bool {