diff --git a/s3api/controllers/base.go b/s3api/controllers/base.go index da019ee..1dce0ca 100644 --- a/s3api/controllers/base.go +++ b/s3api/controllers/base.go @@ -44,7 +44,7 @@ func New(be backend.Backend) S3ApiController { func (c S3ApiController) ListBuckets(ctx *fiber.Ctx) error { res, err := c.be.ListBuckets() - return Responce(ctx, res, err) + return SendXMLResponse(ctx, res, err) } func (c S3ApiController) GetActions(ctx *fiber.Ctx) error { @@ -61,30 +61,68 @@ func (c S3ApiController) GetActions(ctx *fiber.Ctx) error { if uploadId != "" { if maxParts < 0 || (maxParts == 0 && ctx.Query("max-parts") != "") { - return ErrorResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidMaxParts)) + return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidMaxParts)) } if partNumberMarker < 0 || (partNumberMarker == 0 && ctx.Query("part-number-marker") != "") { - return ErrorResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidPartNumberMarker)) + return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidPartNumberMarker)) } res, err := c.be.ListObjectParts(bucket, key, uploadId, partNumberMarker, maxParts) - return Responce(ctx, res, err) + return SendXMLResponse(ctx, res, err) } if ctx.Request().URI().QueryArgs().Has("acl") { res, err := c.be.GetObjectAcl(bucket, key) - return Responce(ctx, res, err) + return SendXMLResponse(ctx, res, err) } if attrs := ctx.Get("X-Amz-Object-Attributes"); attrs != "" { res, err := c.be.GetObjectAttributes(bucket, key, strings.Split(attrs, ",")) - return Responce(ctx, res, err) + return SendXMLResponse(ctx, res, err) } res, err := c.be.GetObject(bucket, key, acceptRange, ctx.Response().BodyWriter()) if err != nil { - return Responce(ctx, res, err) + return SendResponse(ctx, err) } - return nil + if res == nil { + return SendResponse(ctx, fmt.Errorf("get object nil response")) + } + + utils.SetMetaHeaders(ctx, res.Metadata) + var lastmod string + if res.LastModified != nil { + lastmod = res.LastModified.Format(timefmt) + } + utils.SetResponseHeaders(ctx, []utils.CustomHeader{ + { + Key: "Content-Length", + Value: fmt.Sprint(res.ContentLength), + }, + { + Key: "Content-Type", + Value: getstring(res.ContentType), + }, + { + Key: "Content-Encoding", + Value: getstring(res.ContentEncoding), + }, + { + Key: "ETag", + Value: getstring(res.ETag), + }, + { + Key: "Last-Modified", + Value: lastmod, + }, + }) + return ctx.SendStatus(http.StatusOK) +} + +func getstring(s *string) string { + if s == nil { + return "" + } + return *s } func (c S3ApiController) ListActions(ctx *fiber.Ctx) error { @@ -96,21 +134,21 @@ func (c S3ApiController) ListActions(ctx *fiber.Ctx) error { if ctx.Request().URI().QueryArgs().Has("acl") { res, err := c.be.GetBucketAcl(ctx.Params("bucket")) - return Responce(ctx, res, err) + return SendXMLResponse(ctx, res, err) } if ctx.Request().URI().QueryArgs().Has("uploads") { res, err := c.be.ListMultipartUploads(&s3.ListMultipartUploadsInput{Bucket: aws.String(ctx.Params("bucket"))}) - return Responce(ctx, res, err) + return SendXMLResponse(ctx, res, err) } if ctx.QueryInt("list-type") == 2 { res, err := c.be.ListObjectsV2(bucket, prefix, marker, delimiter, maxkeys) - return Responce(ctx, res, err) + return SendXMLResponse(ctx, res, err) } res, err := c.be.ListObjects(bucket, prefix, marker, delimiter, maxkeys) - return Responce(ctx, res, err) + return SendXMLResponse(ctx, res, err) } func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error { @@ -139,11 +177,11 @@ func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error { GrantWriteACP: &grantWriteACP, }) - return Responce[any](ctx, nil, err) + return SendResponse(ctx, err) } err := c.be.PutBucket(bucket) - return Responce[any](ctx, nil, err) + return SendResponse(ctx, err) } func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { @@ -186,20 +224,21 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { var err error contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) if err != nil { - return ErrorResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidRequest)) + return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidRequest)) } } if uploadId != "" && partNumberStr != "" { partNumber := ctx.QueryInt("partNumber", -1) if partNumber < 1 { - return ErrorResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidPart)) + return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidPart)) } body := io.ReadSeeker(bytes.NewReader([]byte(ctx.Body()))) - res, err := c.be.PutObjectPart(bucket, keyStart, uploadId, + etag, err := c.be.PutObjectPart(bucket, keyStart, uploadId, partNumber, contentLength, body) - return Responce(ctx, res, err) + ctx.Response().Header.Set("Etag", etag) + return SendResponse(ctx, err) } if grants != "" || acl != "" { @@ -217,7 +256,7 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { GrantWrite: &granWrite, GrantWriteACP: &grantWriteACP, }) - return Responce[any](ctx, nil, err) + return SendResponse(ctx, err) } if copySource != "" { @@ -227,24 +266,25 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { srcBucket, srcObject := copySourceSplit[0], copySourceSplit[1:] res, err := c.be.CopyObject(srcBucket, strings.Join(srcObject, "/"), bucket, keyStart) - return Responce(ctx, res, err) + return SendXMLResponse(ctx, res, err) } metadata := utils.GetUserMetaData(&ctx.Request().Header) - res, err := c.be.PutObject(&s3.PutObjectInput{ + etag, err := c.be.PutObject(&s3.PutObjectInput{ Bucket: &bucket, Key: &keyStart, ContentLength: contentLength, Metadata: metadata, Body: bytes.NewReader(ctx.Request().Body()), }) - return Responce(ctx, res, err) + ctx.Response().Header.Set("ETag", etag) + return SendResponse(ctx, err) } func (c S3ApiController) DeleteBucket(ctx *fiber.Ctx) error { err := c.be.DeleteBucket(ctx.Params("bucket")) - return Responce[any](ctx, nil, err) + return SendResponse(ctx, err) } func (c S3ApiController) DeleteObjects(ctx *fiber.Ctx) error { @@ -254,7 +294,7 @@ func (c S3ApiController) DeleteObjects(ctx *fiber.Ctx) error { } err := c.be.DeleteObjects(ctx.Params("bucket"), &s3.DeleteObjectsInput{Delete: &dObj}) - return Responce[any](ctx, nil, err) + return SendResponse(ctx, err) } func (c S3ApiController) DeleteActions(ctx *fiber.Ctx) error { @@ -277,16 +317,17 @@ func (c S3ApiController) DeleteActions(ctx *fiber.Ctx) error { ExpectedBucketOwner: &expectedBucketOwner, RequestPayer: types.RequestPayer(requestPayer), }) - return Responce[any](ctx, nil, err) + return SendResponse(ctx, err) } err := c.be.DeleteObject(bucket, key) - return Responce[any](ctx, nil, err) + return SendResponse(ctx, err) } func (c S3ApiController) HeadBucket(ctx *fiber.Ctx) error { - res, err := c.be.HeadBucket(ctx.Params("bucket")) - return Responce(ctx, res, err) + _, err := c.be.HeadBucket(ctx.Params("bucket")) + // TODO: set bucket response headers + return SendResponse(ctx, err) } const ( @@ -303,10 +344,17 @@ func (c S3ApiController) HeadObject(ctx *fiber.Ctx) error { res, err := c.be.HeadObject(bucket, key) if err != nil { - return ErrorResponse(ctx, err) + return SendResponse(ctx, err) + } + if res == nil { + return SendResponse(ctx, fmt.Errorf("head object nil response")) } utils.SetMetaHeaders(ctx, res.Metadata) + var lastmod string + if res.LastModified != nil { + lastmod = res.LastModified.Format(timefmt) + } utils.SetResponseHeaders(ctx, []utils.CustomHeader{ { Key: "Content-Length", @@ -314,26 +362,23 @@ func (c S3ApiController) HeadObject(ctx *fiber.Ctx) error { }, { Key: "Content-Type", - Value: *res.ContentType, + Value: getstring(res.ContentType), }, { Key: "Content-Encoding", - Value: *res.ContentEncoding, + Value: getstring(res.ContentEncoding), }, { Key: "ETag", - Value: *res.ETag, + Value: getstring(res.ETag), }, { Key: "Last-Modified", - Value: res.LastModified.Format(timefmt), + Value: lastmod, }, }) - // https://github.com/gofiber/fiber/issues/2080 - // ctx.SendStatus() sets incorrect content length on HEAD request - ctx.Status(http.StatusOK) - return nil + return SendResponse(ctx, nil) } func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error { @@ -353,7 +398,7 @@ func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error { return errors.New("wrong api call") } err := c.be.RestoreObject(bucket, key, &restoreRequest) - return Responce[any](ctx, nil, err) + return SendResponse(ctx, err) } if uploadId != "" { @@ -366,13 +411,30 @@ func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error { } res, err := c.be.CompleteMultipartUpload(bucket, key, uploadId, data.Parts) - return Responce(ctx, res, err) + return SendXMLResponse(ctx, res, err) } res, err := c.be.CreateMultipartUpload(&s3.CreateMultipartUploadInput{Bucket: &bucket, Key: &key}) - return Responce(ctx, res, err) + return SendXMLResponse(ctx, res, err) } -func Responce[R any](ctx *fiber.Ctx, resp R, err error) error { +func SendResponse(ctx *fiber.Ctx, err error) error { + if err != nil { + serr, ok := err.(s3err.APIError) + if ok { + ctx.Status(serr.HTTPStatusCode) + return ctx.Send(s3err.GetAPIErrorResponse(serr, "", "", "")) + } + return ctx.Send(s3err.GetAPIErrorResponse( + s3err.GetAPIError(s3err.ErrInternalError), "", "", "")) + } + + // https://github.com/gofiber/fiber/issues/2080 + // ctx.SendStatus() sets incorrect content length on HEAD request + ctx.Status(http.StatusOK) + return nil +} + +func SendXMLResponse(ctx *fiber.Ctx, resp any, err error) error { if err != nil { serr, ok := err.(s3err.APIError) if ok { @@ -388,23 +450,16 @@ func Responce[R any](ctx *fiber.Ctx, resp R, err error) error { } var b []byte - if b, err = xml.Marshal(resp); err != nil { - return err - } - if len(b) > 0 { - ctx.Response().Header.SetContentType(fiber.MIMEApplicationXML) + if resp != nil { + if b, err = xml.Marshal(resp); err != nil { + return err + } + + if len(b) > 0 { + ctx.Response().Header.SetContentType(fiber.MIMEApplicationXML) + } } return ctx.Send(b) } - -func ErrorResponse(ctx *fiber.Ctx, err error) error { - serr, ok := err.(s3err.APIError) - if ok { - ctx.Status(serr.HTTPStatusCode) - return ctx.Send(s3err.GetAPIErrorResponse(serr, "", "", "")) - } - return ctx.Send(s3err.GetAPIErrorResponse( - s3err.GetAPIError(s3err.ErrInternalError), "", "", "")) -} diff --git a/s3api/controllers/base_test.go b/s3api/controllers/base_test.go index 48df055..898ecfb 100644 --- a/s3api/controllers/base_test.go +++ b/s3api/controllers/base_test.go @@ -954,7 +954,7 @@ func TestS3ApiController_CreateActions(t *testing.T) { } } -func Test_responce(t *testing.T) { +func Test_XMLresponse(t *testing.T) { type args struct { ctx *fiber.Ctx resp any @@ -1011,14 +1011,74 @@ func Test_responce(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := Responce(tt.args.ctx, tt.args.resp, tt.args.err); (err != nil) != tt.wantErr { - t.Errorf("responce() error = %v, wantErr %v", err, tt.wantErr) + if err := SendXMLResponse(tt.args.ctx, tt.args.resp, tt.args.err); (err != nil) != tt.wantErr { + t.Errorf("responce() %v error = %v, wantErr %v", tt.name, err, tt.wantErr) } statusCode := tt.args.ctx.Response().StatusCode() if statusCode != tt.statusCode { - t.Errorf("responce() code = %v, wantErr %v", statusCode, tt.wantErr) + t.Errorf("responce() %v code = %v, wantErr %v", tt.name, statusCode, tt.wantErr) + } + }) + } +} + +func Test_response(t *testing.T) { + type args struct { + ctx *fiber.Ctx + resp any + err error + } + app := fiber.New() + + tests := []struct { + name string + args args + wantErr bool + statusCode int + }{ + { + name: "Internal-server-error", + args: args{ + ctx: app.AcquireCtx(&fasthttp.RequestCtx{}), + resp: nil, + err: s3err.GetAPIError(16), + }, + wantErr: false, + statusCode: 500, + }, + { + name: "Error-not-implemented", + args: args{ + ctx: app.AcquireCtx(&fasthttp.RequestCtx{}), + resp: nil, + err: s3err.GetAPIError(50), + }, + wantErr: false, + statusCode: 501, + }, + { + name: "Successful-response", + args: args{ + ctx: app.AcquireCtx(&fasthttp.RequestCtx{}), + resp: "Valid response", + err: nil, + }, + wantErr: false, + statusCode: 200, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := SendResponse(tt.args.ctx, tt.args.err); (err != nil) != tt.wantErr { + t.Errorf("responce() %v error = %v, wantErr %v", tt.name, err, tt.wantErr) + } + + statusCode := tt.args.ctx.Response().StatusCode() + + if statusCode != tt.statusCode { + t.Errorf("responce() %v code = %v, wantErr %v", tt.name, statusCode, tt.wantErr) } }) } diff --git a/s3api/middlewares/authentication.go b/s3api/middlewares/authentication.go index 4fcf629..90f1195 100644 --- a/s3api/middlewares/authentication.go +++ b/s3api/middlewares/authentication.go @@ -47,48 +47,48 @@ func VerifyV4Signature(root RootUserConfig, iam auth.IAMService, debug bool) fib return func(ctx *fiber.Ctx) error { authorization := ctx.Get("Authorization") if authorization == "" { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrAuthHeaderEmpty)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrAuthHeaderEmpty)) } // Check the signature version authParts := strings.Split(authorization, " ") if len(authParts) < 4 { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMissingFields)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrMissingFields)) } if authParts[0] != "AWS4-HMAC-SHA256" { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrSignatureVersionNotSupported)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrSignatureVersionNotSupported)) } credKv := strings.Split(authParts[1], "=") if len(credKv) != 2 { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrCredMalformed)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrCredMalformed)) } creds := strings.Split(credKv[1], "/") if len(creds) < 4 { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrCredMalformed)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrCredMalformed)) } signHdrKv := strings.Split(authParts[2], "=") if len(signHdrKv) != 2 { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrCredMalformed)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrCredMalformed)) } signedHdrs := strings.Split(signHdrKv[1], ";") account := acct.getAccount(creds[0]) if account == nil { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrInvalidAccessKeyID)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidAccessKeyID)) } // Check X-Amz-Date header date := ctx.Get("X-Amz-Date") if date == "" { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMissingDateHeader)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrMissingDateHeader)) } // Parse the date and check the date validity tdate, err := time.Parse(iso8601Format, date) if err != nil { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMalformedDate)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrMalformedDate)) } // Calculate the hash of the request payload @@ -99,13 +99,13 @@ func VerifyV4Signature(root RootUserConfig, iam auth.IAMService, debug bool) fib // Compare the calculated hash with the hash provided if hashPayloadHeader != hexPayload { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrContentSHA256Mismatch)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrContentSHA256Mismatch)) } // Create a new http request instance from fasthttp request req, err := utils.CreateHttpRequestFromCtx(ctx, signedHdrs) if err != nil { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrInternalError)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrInternalError)) } signer := v4.NewSigner() @@ -120,18 +120,18 @@ func VerifyV4Signature(root RootUserConfig, iam auth.IAMService, debug bool) fib } }) if signErr != nil { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrInternalError)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrInternalError)) } parts := strings.Split(req.Header.Get("Authorization"), " ") if len(parts) < 4 { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMissingFields)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrMissingFields)) } calculatedSign := strings.Split(parts[3], "=")[1] expectedSign := strings.Split(authParts[3], "=")[1] if expectedSign != calculatedSign { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrSignatureDoesNotMatch)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrSignatureDoesNotMatch)) } ctx.Locals("role", account.Role) diff --git a/s3api/middlewares/md5.go b/s3api/middlewares/md5.go index 33bdcdf..e5b8233 100644 --- a/s3api/middlewares/md5.go +++ b/s3api/middlewares/md5.go @@ -34,7 +34,7 @@ func VerifyMD5Body() fiber.Handler { calculatedSum := base64.StdEncoding.EncodeToString(sum[:]) if incomingSum != calculatedSum { - return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrInvalidDigest)) + return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidDigest)) } return ctx.Next()