diff --git a/backend/common.go b/backend/common.go index a8ddc7c..794dece 100644 --- a/backend/common.go +++ b/backend/common.go @@ -17,7 +17,6 @@ package backend import ( "crypto/md5" "encoding/hex" - "errors" "fmt" "io/fs" "strconv" @@ -25,6 +24,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/versity/versitygw/s3err" "github.com/versity/versitygw/s3response" ) @@ -55,6 +55,12 @@ func GetTimePtr(t time.Time) *time.Time { return &t } +var ( + errInvalidRange = s3err.GetAPIError(s3err.ErrInvalidRequest) +) + +// ParseRange parses input range header and returns startoffset, length, and +// error. If no endoffset specified, then length is set to -1. func ParseRange(file fs.FileInfo, acceptRange string) (int64, int64, error) { if acceptRange == "" { return 0, file.Size(), nil @@ -63,29 +69,34 @@ func ParseRange(file fs.FileInfo, acceptRange string) (int64, int64, error) { rangeKv := strings.Split(acceptRange, "=") if len(rangeKv) < 2 { - return 0, 0, errors.New("invalid range parameter") + return 0, 0, errInvalidRange } bRange := strings.Split(rangeKv[1], "-") - if len(bRange) < 2 { - return 0, 0, errors.New("invalid range parameter") + if len(bRange) < 1 || len(bRange) > 2 { + return 0, 0, errInvalidRange } startOffset, err := strconv.ParseInt(bRange[0], 10, 64) if err != nil { - return 0, 0, errors.New("invalid range parameter") + return 0, 0, errInvalidRange } - endOffset, err := strconv.ParseInt(bRange[1], 10, 64) + endOffset := int64(-1) + if len(bRange) == 1 || bRange[1] == "" { + return startOffset, endOffset, nil + } + + endOffset, err = strconv.ParseInt(bRange[1], 10, 64) if err != nil { - return 0, 0, errors.New("invalid range parameter") + return 0, 0, errInvalidRange } if endOffset < startOffset { - return 0, 0, errors.New("invalid range parameter") + return 0, 0, errInvalidRange } - return int64(startOffset), int64(endOffset - startOffset + 1), nil + return startOffset, endOffset - startOffset + 1, nil } func GetMultipartMD5(parts []types.Part) string { diff --git a/backend/posix/posix.go b/backend/posix/posix.go index 923bf78..54add78 100644 --- a/backend/posix/posix.go +++ b/backend/posix/posix.go @@ -891,8 +891,11 @@ func (p *Posix) GetObject(bucket, object, acceptRange string, writer io.Writer) return nil, err } - if startOffset+length > fi.Size() { - // TODO: is ErrInvalidRequest correct here? + if length == -1 { + length = fi.Size() - startOffset + 1 + } + + if startOffset+length > fi.Size()+1 { return nil, s3err.GetAPIError(s3err.ErrInvalidRequest) } diff --git a/backend/scoutfs/scoutfs.go b/backend/scoutfs/scoutfs.go index 533ca63..b6577c0 100644 --- a/backend/scoutfs/scoutfs.go +++ b/backend/scoutfs/scoutfs.go @@ -440,8 +440,11 @@ func (s *ScoutFS) GetObject(bucket, object, acceptRange string, writer io.Writer return nil, err } + if length == -1 { + length = fi.Size() - startOffset + 1 + } + if startOffset+length > fi.Size() { - // TODO: is ErrInvalidRequest correct here? return nil, s3err.GetAPIError(s3err.ErrInvalidRequest) } diff --git a/integration/tests.go b/integration/tests.go index a007b13..e4866cb 100644 --- a/integration/tests.go +++ b/integration/tests.go @@ -923,7 +923,34 @@ func TestRangeGet(s *S3Conf) { } // bytes range is inclusive, go range for second value is not - if !isSame(b, data[100:201]) { + if !isEqual(b, data[100:201]) { + failF("%v: data mismatch of range", testname) + return + } + + rangeString = "bytes=100-" + + ctx, cancel = context.WithTimeout(context.Background(), shortTimeout) + out, err = s3client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: &bucket, + Key: &name, + Range: &rangeString, + }) + defer cancel() + if err != nil { + failF("%v: %v", testname, err) + return + } + defer out.Body.Close() + + b, err = io.ReadAll(out.Body) + if err != nil { + failF("%v: read body %v", testname, err) + return + } + + // bytes range is inclusive, go range for second value is not + if !isEqual(b, data[100:]) { failF("%v: data mismatch of range", testname) return } diff --git a/integration/utils.go b/integration/utils.go index d748b82..eef2f2a 100644 --- a/integration/utils.go +++ b/integration/utils.go @@ -111,18 +111,6 @@ func containsPart(part int32, list []types.Part) bool { return false } -func isSame(a, b []byte) bool { - if len(a) != len(b) { - return false - } - for i, x := range a { - if x != b[i] { - return false - } - } - return true -} - // Checks if the slices contain the same objects, if the objects doesn't // contain map, slice, channel. func areTagsSame(tags1, tags2 []types.Tag) bool {