Merge pull request #78 from versity/ben/cleanup_base

Ben/cleanup base
This commit is contained in:
Ben McClelland
2023-06-12 08:00:05 -07:00
committed by GitHub
4 changed files with 190 additions and 75 deletions

View File

@@ -44,7 +44,7 @@ func New(be backend.Backend) S3ApiController {
func (c S3ApiController) ListBuckets(ctx *fiber.Ctx) error {
res, err := c.be.ListBuckets()
return Responce(ctx, res, err)
return SendXMLResponse(ctx, res, err)
}
func (c S3ApiController) GetActions(ctx *fiber.Ctx) error {
@@ -61,30 +61,68 @@ func (c S3ApiController) GetActions(ctx *fiber.Ctx) error {
if uploadId != "" {
if maxParts < 0 || (maxParts == 0 && ctx.Query("max-parts") != "") {
return ErrorResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidMaxParts))
return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidMaxParts))
}
if partNumberMarker < 0 || (partNumberMarker == 0 && ctx.Query("part-number-marker") != "") {
return ErrorResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidPartNumberMarker))
return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidPartNumberMarker))
}
res, err := c.be.ListObjectParts(bucket, key, uploadId, partNumberMarker, maxParts)
return Responce(ctx, res, err)
return SendXMLResponse(ctx, res, err)
}
if ctx.Request().URI().QueryArgs().Has("acl") {
res, err := c.be.GetObjectAcl(bucket, key)
return Responce(ctx, res, err)
return SendXMLResponse(ctx, res, err)
}
if attrs := ctx.Get("X-Amz-Object-Attributes"); attrs != "" {
res, err := c.be.GetObjectAttributes(bucket, key, strings.Split(attrs, ","))
return Responce(ctx, res, err)
return SendXMLResponse(ctx, res, err)
}
res, err := c.be.GetObject(bucket, key, acceptRange, ctx.Response().BodyWriter())
if err != nil {
return Responce(ctx, res, err)
return SendResponse(ctx, err)
}
return nil
if res == nil {
return SendResponse(ctx, fmt.Errorf("get object nil response"))
}
utils.SetMetaHeaders(ctx, res.Metadata)
var lastmod string
if res.LastModified != nil {
lastmod = res.LastModified.Format(timefmt)
}
utils.SetResponseHeaders(ctx, []utils.CustomHeader{
{
Key: "Content-Length",
Value: fmt.Sprint(res.ContentLength),
},
{
Key: "Content-Type",
Value: getstring(res.ContentType),
},
{
Key: "Content-Encoding",
Value: getstring(res.ContentEncoding),
},
{
Key: "ETag",
Value: getstring(res.ETag),
},
{
Key: "Last-Modified",
Value: lastmod,
},
})
return ctx.SendStatus(http.StatusOK)
}
func getstring(s *string) string {
if s == nil {
return ""
}
return *s
}
func (c S3ApiController) ListActions(ctx *fiber.Ctx) error {
@@ -96,21 +134,21 @@ func (c S3ApiController) ListActions(ctx *fiber.Ctx) error {
if ctx.Request().URI().QueryArgs().Has("acl") {
res, err := c.be.GetBucketAcl(ctx.Params("bucket"))
return Responce(ctx, res, err)
return SendXMLResponse(ctx, res, err)
}
if ctx.Request().URI().QueryArgs().Has("uploads") {
res, err := c.be.ListMultipartUploads(&s3.ListMultipartUploadsInput{Bucket: aws.String(ctx.Params("bucket"))})
return Responce(ctx, res, err)
return SendXMLResponse(ctx, res, err)
}
if ctx.QueryInt("list-type") == 2 {
res, err := c.be.ListObjectsV2(bucket, prefix, marker, delimiter, maxkeys)
return Responce(ctx, res, err)
return SendXMLResponse(ctx, res, err)
}
res, err := c.be.ListObjects(bucket, prefix, marker, delimiter, maxkeys)
return Responce(ctx, res, err)
return SendXMLResponse(ctx, res, err)
}
func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error {
@@ -139,11 +177,11 @@ func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error {
GrantWriteACP: &grantWriteACP,
})
return Responce[any](ctx, nil, err)
return SendResponse(ctx, err)
}
err := c.be.PutBucket(bucket)
return Responce[any](ctx, nil, err)
return SendResponse(ctx, err)
}
func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
@@ -186,20 +224,21 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
var err error
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
if err != nil {
return ErrorResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidRequest))
return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidRequest))
}
}
if uploadId != "" && partNumberStr != "" {
partNumber := ctx.QueryInt("partNumber", -1)
if partNumber < 1 {
return ErrorResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidPart))
return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidPart))
}
body := io.ReadSeeker(bytes.NewReader([]byte(ctx.Body())))
res, err := c.be.PutObjectPart(bucket, keyStart, uploadId,
etag, err := c.be.PutObjectPart(bucket, keyStart, uploadId,
partNumber, contentLength, body)
return Responce(ctx, res, err)
ctx.Response().Header.Set("Etag", etag)
return SendResponse(ctx, err)
}
if grants != "" || acl != "" {
@@ -217,7 +256,7 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
GrantWrite: &granWrite,
GrantWriteACP: &grantWriteACP,
})
return Responce[any](ctx, nil, err)
return SendResponse(ctx, err)
}
if copySource != "" {
@@ -227,24 +266,25 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
srcBucket, srcObject := copySourceSplit[0], copySourceSplit[1:]
res, err := c.be.CopyObject(srcBucket, strings.Join(srcObject, "/"), bucket, keyStart)
return Responce(ctx, res, err)
return SendXMLResponse(ctx, res, err)
}
metadata := utils.GetUserMetaData(&ctx.Request().Header)
res, err := c.be.PutObject(&s3.PutObjectInput{
etag, err := c.be.PutObject(&s3.PutObjectInput{
Bucket: &bucket,
Key: &keyStart,
ContentLength: contentLength,
Metadata: metadata,
Body: bytes.NewReader(ctx.Request().Body()),
})
return Responce(ctx, res, err)
ctx.Response().Header.Set("ETag", etag)
return SendResponse(ctx, err)
}
func (c S3ApiController) DeleteBucket(ctx *fiber.Ctx) error {
err := c.be.DeleteBucket(ctx.Params("bucket"))
return Responce[any](ctx, nil, err)
return SendResponse(ctx, err)
}
func (c S3ApiController) DeleteObjects(ctx *fiber.Ctx) error {
@@ -254,7 +294,7 @@ func (c S3ApiController) DeleteObjects(ctx *fiber.Ctx) error {
}
err := c.be.DeleteObjects(ctx.Params("bucket"), &s3.DeleteObjectsInput{Delete: &dObj})
return Responce[any](ctx, nil, err)
return SendResponse(ctx, err)
}
func (c S3ApiController) DeleteActions(ctx *fiber.Ctx) error {
@@ -277,16 +317,17 @@ func (c S3ApiController) DeleteActions(ctx *fiber.Ctx) error {
ExpectedBucketOwner: &expectedBucketOwner,
RequestPayer: types.RequestPayer(requestPayer),
})
return Responce[any](ctx, nil, err)
return SendResponse(ctx, err)
}
err := c.be.DeleteObject(bucket, key)
return Responce[any](ctx, nil, err)
return SendResponse(ctx, err)
}
func (c S3ApiController) HeadBucket(ctx *fiber.Ctx) error {
res, err := c.be.HeadBucket(ctx.Params("bucket"))
return Responce(ctx, res, err)
_, err := c.be.HeadBucket(ctx.Params("bucket"))
// TODO: set bucket response headers
return SendResponse(ctx, err)
}
const (
@@ -303,10 +344,17 @@ func (c S3ApiController) HeadObject(ctx *fiber.Ctx) error {
res, err := c.be.HeadObject(bucket, key)
if err != nil {
return ErrorResponse(ctx, err)
return SendResponse(ctx, err)
}
if res == nil {
return SendResponse(ctx, fmt.Errorf("head object nil response"))
}
utils.SetMetaHeaders(ctx, res.Metadata)
var lastmod string
if res.LastModified != nil {
lastmod = res.LastModified.Format(timefmt)
}
utils.SetResponseHeaders(ctx, []utils.CustomHeader{
{
Key: "Content-Length",
@@ -314,26 +362,23 @@ func (c S3ApiController) HeadObject(ctx *fiber.Ctx) error {
},
{
Key: "Content-Type",
Value: *res.ContentType,
Value: getstring(res.ContentType),
},
{
Key: "Content-Encoding",
Value: *res.ContentEncoding,
Value: getstring(res.ContentEncoding),
},
{
Key: "ETag",
Value: *res.ETag,
Value: getstring(res.ETag),
},
{
Key: "Last-Modified",
Value: res.LastModified.Format(timefmt),
Value: lastmod,
},
})
// https://github.com/gofiber/fiber/issues/2080
// ctx.SendStatus() sets incorrect content length on HEAD request
ctx.Status(http.StatusOK)
return nil
return SendResponse(ctx, nil)
}
func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error {
@@ -353,7 +398,7 @@ func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error {
return errors.New("wrong api call")
}
err := c.be.RestoreObject(bucket, key, &restoreRequest)
return Responce[any](ctx, nil, err)
return SendResponse(ctx, err)
}
if uploadId != "" {
@@ -366,13 +411,30 @@ func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error {
}
res, err := c.be.CompleteMultipartUpload(bucket, key, uploadId, data.Parts)
return Responce(ctx, res, err)
return SendXMLResponse(ctx, res, err)
}
res, err := c.be.CreateMultipartUpload(&s3.CreateMultipartUploadInput{Bucket: &bucket, Key: &key})
return Responce(ctx, res, err)
return SendXMLResponse(ctx, res, err)
}
func Responce[R any](ctx *fiber.Ctx, resp R, err error) error {
func SendResponse(ctx *fiber.Ctx, err error) error {
if err != nil {
serr, ok := err.(s3err.APIError)
if ok {
ctx.Status(serr.HTTPStatusCode)
return ctx.Send(s3err.GetAPIErrorResponse(serr, "", "", ""))
}
return ctx.Send(s3err.GetAPIErrorResponse(
s3err.GetAPIError(s3err.ErrInternalError), "", "", ""))
}
// https://github.com/gofiber/fiber/issues/2080
// ctx.SendStatus() sets incorrect content length on HEAD request
ctx.Status(http.StatusOK)
return nil
}
func SendXMLResponse(ctx *fiber.Ctx, resp any, err error) error {
if err != nil {
serr, ok := err.(s3err.APIError)
if ok {
@@ -388,23 +450,16 @@ func Responce[R any](ctx *fiber.Ctx, resp R, err error) error {
}
var b []byte
if b, err = xml.Marshal(resp); err != nil {
return err
}
if len(b) > 0 {
ctx.Response().Header.SetContentType(fiber.MIMEApplicationXML)
if resp != nil {
if b, err = xml.Marshal(resp); err != nil {
return err
}
if len(b) > 0 {
ctx.Response().Header.SetContentType(fiber.MIMEApplicationXML)
}
}
return ctx.Send(b)
}
func ErrorResponse(ctx *fiber.Ctx, err error) error {
serr, ok := err.(s3err.APIError)
if ok {
ctx.Status(serr.HTTPStatusCode)
return ctx.Send(s3err.GetAPIErrorResponse(serr, "", "", ""))
}
return ctx.Send(s3err.GetAPIErrorResponse(
s3err.GetAPIError(s3err.ErrInternalError), "", "", ""))
}

View File

@@ -954,7 +954,7 @@ func TestS3ApiController_CreateActions(t *testing.T) {
}
}
func Test_responce(t *testing.T) {
func Test_XMLresponse(t *testing.T) {
type args struct {
ctx *fiber.Ctx
resp any
@@ -1011,14 +1011,74 @@ func Test_responce(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := Responce(tt.args.ctx, tt.args.resp, tt.args.err); (err != nil) != tt.wantErr {
t.Errorf("responce() error = %v, wantErr %v", err, tt.wantErr)
if err := SendXMLResponse(tt.args.ctx, tt.args.resp, tt.args.err); (err != nil) != tt.wantErr {
t.Errorf("responce() %v error = %v, wantErr %v", tt.name, err, tt.wantErr)
}
statusCode := tt.args.ctx.Response().StatusCode()
if statusCode != tt.statusCode {
t.Errorf("responce() code = %v, wantErr %v", statusCode, tt.wantErr)
t.Errorf("responce() %v code = %v, wantErr %v", tt.name, statusCode, tt.wantErr)
}
})
}
}
func Test_response(t *testing.T) {
type args struct {
ctx *fiber.Ctx
resp any
err error
}
app := fiber.New()
tests := []struct {
name string
args args
wantErr bool
statusCode int
}{
{
name: "Internal-server-error",
args: args{
ctx: app.AcquireCtx(&fasthttp.RequestCtx{}),
resp: nil,
err: s3err.GetAPIError(16),
},
wantErr: false,
statusCode: 500,
},
{
name: "Error-not-implemented",
args: args{
ctx: app.AcquireCtx(&fasthttp.RequestCtx{}),
resp: nil,
err: s3err.GetAPIError(50),
},
wantErr: false,
statusCode: 501,
},
{
name: "Successful-response",
args: args{
ctx: app.AcquireCtx(&fasthttp.RequestCtx{}),
resp: "Valid response",
err: nil,
},
wantErr: false,
statusCode: 200,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := SendResponse(tt.args.ctx, tt.args.err); (err != nil) != tt.wantErr {
t.Errorf("responce() %v error = %v, wantErr %v", tt.name, err, tt.wantErr)
}
statusCode := tt.args.ctx.Response().StatusCode()
if statusCode != tt.statusCode {
t.Errorf("responce() %v code = %v, wantErr %v", tt.name, statusCode, tt.wantErr)
}
})
}

View File

@@ -47,48 +47,48 @@ func VerifyV4Signature(root RootUserConfig, iam auth.IAMService, debug bool) fib
return func(ctx *fiber.Ctx) error {
authorization := ctx.Get("Authorization")
if authorization == "" {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrAuthHeaderEmpty))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrAuthHeaderEmpty))
}
// Check the signature version
authParts := strings.Split(authorization, " ")
if len(authParts) < 4 {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMissingFields))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrMissingFields))
}
if authParts[0] != "AWS4-HMAC-SHA256" {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrSignatureVersionNotSupported))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrSignatureVersionNotSupported))
}
credKv := strings.Split(authParts[1], "=")
if len(credKv) != 2 {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrCredMalformed))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrCredMalformed))
}
creds := strings.Split(credKv[1], "/")
if len(creds) < 4 {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrCredMalformed))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrCredMalformed))
}
signHdrKv := strings.Split(authParts[2], "=")
if len(signHdrKv) != 2 {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrCredMalformed))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrCredMalformed))
}
signedHdrs := strings.Split(signHdrKv[1], ";")
account := acct.getAccount(creds[0])
if account == nil {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrInvalidAccessKeyID))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidAccessKeyID))
}
// Check X-Amz-Date header
date := ctx.Get("X-Amz-Date")
if date == "" {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMissingDateHeader))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrMissingDateHeader))
}
// Parse the date and check the date validity
tdate, err := time.Parse(iso8601Format, date)
if err != nil {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMalformedDate))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrMalformedDate))
}
// Calculate the hash of the request payload
@@ -99,13 +99,13 @@ func VerifyV4Signature(root RootUserConfig, iam auth.IAMService, debug bool) fib
// Compare the calculated hash with the hash provided
if hashPayloadHeader != hexPayload {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrContentSHA256Mismatch))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrContentSHA256Mismatch))
}
// Create a new http request instance from fasthttp request
req, err := utils.CreateHttpRequestFromCtx(ctx, signedHdrs)
if err != nil {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrInternalError))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrInternalError))
}
signer := v4.NewSigner()
@@ -120,18 +120,18 @@ func VerifyV4Signature(root RootUserConfig, iam auth.IAMService, debug bool) fib
}
})
if signErr != nil {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrInternalError))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrInternalError))
}
parts := strings.Split(req.Header.Get("Authorization"), " ")
if len(parts) < 4 {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMissingFields))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrMissingFields))
}
calculatedSign := strings.Split(parts[3], "=")[1]
expectedSign := strings.Split(authParts[3], "=")[1]
if expectedSign != calculatedSign {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrSignatureDoesNotMatch))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrSignatureDoesNotMatch))
}
ctx.Locals("role", account.Role)

View File

@@ -34,7 +34,7 @@ func VerifyMD5Body() fiber.Handler {
calculatedSum := base64.StdEncoding.EncodeToString(sum[:])
if incomingSum != calculatedSum {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrInvalidDigest))
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidDigest))
}
return ctx.Next()