From ecd28bc2f78bc3c3a432d18d5ebd4d5501bfff70 Mon Sep 17 00:00:00 2001 From: jonaustin09 Date: Wed, 31 May 2023 22:20:58 +0400 Subject: [PATCH] feat: Completed SigV4 authentication for the root user --- go.mod | 2 +- s3api/controllers/base.go | 52 ++++++++++---------- s3api/controllers/base_test.go | 2 +- s3api/middlewares/authentication.go | 75 ++++++++++++++++++++++++++++- s3api/server.go | 3 +- s3api/utils/utils.go | 29 +++++++++++ 6 files changed, 132 insertions(+), 31 deletions(-) diff --git a/go.mod b/go.mod index 32d30a1..88c311b 100644 --- a/go.mod +++ b/go.mod @@ -6,9 +6,9 @@ require ( github.com/aws/aws-sdk-go-v2 v1.18.0 github.com/aws/aws-sdk-go-v2/service/s3 v1.33.1 github.com/gofiber/fiber/v2 v2.45.0 - github.com/valyala/fasthttp v1.47.0 github.com/google/uuid v1.3.0 github.com/pkg/xattr v0.4.9 + github.com/valyala/fasthttp v1.47.0 golang.org/x/sys v0.8.0 ) diff --git a/s3api/controllers/base.go b/s3api/controllers/base.go index 5569e52..601a802 100644 --- a/s3api/controllers/base.go +++ b/s3api/controllers/base.go @@ -28,7 +28,7 @@ func New(be backend.Backend) S3ApiController { func (c S3ApiController) ListBuckets(ctx *fiber.Ctx) error { res, err := c.be.ListBuckets() - return responce(ctx, res, err) + return Responce(ctx, res, err) } func (c S3ApiController) GetActions(ctx *fiber.Ctx) error { @@ -49,17 +49,17 @@ func (c S3ApiController) GetActions(ctx *fiber.Ctx) error { } res, err := c.be.ListObjectParts(bucket, "", uploadId, partNumberMarker, maxParts) - return responce(ctx, res, err) + return Responce(ctx, res, err) } if ctx.Request().URI().QueryArgs().Has("acl") { res, err := c.be.GetObjectAcl(bucket, key) - return responce(ctx, res, err) + return Responce(ctx, res, err) } if attrs := ctx.Get("X-Amz-Object-Attributes"); attrs != "" { res, err := c.be.GetObjectAttributes(bucket, key, strings.Split(attrs, ",")) - return responce(ctx, res, err) + return Responce(ctx, res, err) } acceptRange := ctx.Get("Range") @@ -85,27 +85,27 @@ func (c S3ApiController) GetActions(ctx *fiber.Ctx) error { } res, err := c.be.GetObject(bucket, key, acceptRange, int64(startOffset), int64(length), ctx.Response().BodyWriter()) - return responce(ctx, res, err) + return Responce(ctx, res, err) } func (c S3ApiController) ListActions(ctx *fiber.Ctx) error { if ctx.Request().URI().QueryArgs().Has("acl") { res, err := c.be.GetBucketAcl(ctx.Params("bucket")) - return responce(ctx, res, err) + return Responce(ctx, res, err) } if ctx.Request().URI().QueryArgs().Has("uploads") { res, err := c.be.ListMultipartUploads(&s3.ListMultipartUploadsInput{Bucket: aws.String(ctx.Params("bucket"))}) - return responce(ctx, res, err) + return Responce(ctx, res, err) } if ctx.QueryInt("list-type") == 2 { res, err := c.be.ListObjectsV2(ctx.Params("bucket"), "", "", "", 1) - return responce(ctx, res, err) + return Responce(ctx, res, err) } res, err := c.be.ListObjects(ctx.Params("bucket"), "", "", "", 1) - return responce(ctx, res, err) + return Responce(ctx, res, err) } func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error { @@ -134,11 +134,11 @@ func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error { GrantWriteACP: &grantWriteACP, }) - return responce[any](ctx, nil, err) + return Responce[any](ctx, nil, err) } err := c.be.PutBucket(bucket) - return responce[any](ctx, nil, err) + return Responce[any](ctx, nil, err) } func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { @@ -197,13 +197,13 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { CopySourceIfUnmodifiedSince: ©SrcUnmodifSinceDate, }) - return responce(ctx, res, err) + return Responce(ctx, res, err) } if uploadId != "" { body := io.ReadSeeker(bytes.NewReader([]byte(ctx.Body()))) res, err := c.be.UploadPart(dstBucket, dstKeyStart, uploadId, body) - return responce(ctx, res, err) + return Responce(ctx, res, err) } if grants != "" || acl != "" { @@ -221,7 +221,7 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { GrantWrite: &granWrite, GrantWriteACP: &grantWriteACP, }) - return responce[any](ctx, nil, err) + return Responce[any](ctx, nil, err) } if copySource != "" { @@ -229,7 +229,7 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { srcBucket, srcObject := copySourceSplit[0], copySourceSplit[1:] res, err := c.be.CopyObject(srcBucket, strings.Join(srcObject, "/"), dstBucket, dstKeyStart) - return responce(ctx, res, err) + return Responce(ctx, res, err) } contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64) @@ -246,12 +246,12 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error { Metadata: metadata, Body: bytes.NewReader(ctx.Request().Body()), }) - return responce(ctx, res, err) + return Responce(ctx, res, err) } func (c S3ApiController) DeleteBucket(ctx *fiber.Ctx) error { err := c.be.DeleteBucket(ctx.Params("bucket")) - return responce[any](ctx, nil, err) + return Responce[any](ctx, nil, err) } func (c S3ApiController) DeleteObjects(ctx *fiber.Ctx) error { @@ -261,7 +261,7 @@ func (c S3ApiController) DeleteObjects(ctx *fiber.Ctx) error { } err := c.be.DeleteObjects(ctx.Params("bucket"), &s3.DeleteObjectsInput{Delete: &dObj}) - return responce[any](ctx, nil, err) + return Responce[any](ctx, nil, err) } func (c S3ApiController) DeleteActions(ctx *fiber.Ctx) error { @@ -281,16 +281,16 @@ func (c S3ApiController) DeleteActions(ctx *fiber.Ctx) error { ExpectedBucketOwner: &expectedBucketOwner, RequestPayer: types.RequestPayer(requestPayer), }) - return responce[any](ctx, nil, err) + return Responce[any](ctx, nil, err) } err := c.be.DeleteObject(bucket, key) - return responce[any](ctx, nil, err) + return Responce[any](ctx, nil, err) } func (c S3ApiController) HeadBucket(ctx *fiber.Ctx) error { res, err := c.be.HeadBucket(ctx.Params("bucket")) - return responce(ctx, res, err) + return Responce(ctx, res, err) } func (c S3ApiController) HeadObject(ctx *fiber.Ctx) error { @@ -300,7 +300,7 @@ func (c S3ApiController) HeadObject(ctx *fiber.Ctx) error { } res, err := c.be.HeadObject(bucket, key, "") - return responce(ctx, res, err) + return Responce(ctx, res, err) } func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error { @@ -313,7 +313,7 @@ func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error { if err := xml.Unmarshal(ctx.Body(), &restoreRequest); err == nil { err := c.be.RestoreObject(bucket, key, &restoreRequest) - return responce[any](ctx, nil, err) + return Responce[any](ctx, nil, err) } if uploadId != "" { @@ -324,13 +324,13 @@ func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error { } res, err := c.be.CompleteMultipartUpload(bucket, "", uploadId, parts) - return responce(ctx, res, err) + return Responce(ctx, res, err) } res, err := c.be.CreateMultipartUpload(&s3.CreateMultipartUploadInput{Bucket: &bucket, Key: &key}) - return responce(ctx, res, err) + return Responce(ctx, res, err) } -func responce[R comparable](ctx *fiber.Ctx, resp R, err error) error { +func Responce[R comparable](ctx *fiber.Ctx, resp R, err error) error { if err != nil { serr, ok := err.(s3err.APIError) if ok { diff --git a/s3api/controllers/base_test.go b/s3api/controllers/base_test.go index d3f1012..7963d50 100644 --- a/s3api/controllers/base_test.go +++ b/s3api/controllers/base_test.go @@ -701,7 +701,7 @@ func Test_responce(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := responce(tt.args.ctx, tt.args.resp, tt.args.err); (err != nil) != tt.wantErr { + if err := Responce(tt.args.ctx, tt.args.resp, tt.args.err); (err != nil) != tt.wantErr { t.Errorf("responce() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/s3api/middlewares/authentication.go b/s3api/middlewares/authentication.go index fad5703..131a5ab 100644 --- a/s3api/middlewares/authentication.go +++ b/s3api/middlewares/authentication.go @@ -1,12 +1,85 @@ package middlewares import ( + "crypto/sha256" + "encoding/hex" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/gofiber/fiber/v2" + "github.com/versity/scoutgw/s3api/controllers" "github.com/versity/scoutgw/s3api/utils" + "github.com/versity/scoutgw/s3err" ) -func CheckUserCreds(user utils.RootUser) fiber.Handler { +const ( + iso8601Format = "20060102T150405Z" +) + +func VerifyV4Signature(user utils.RootUser) fiber.Handler { return func(ctx *fiber.Ctx) error { + authorization := ctx.Get("Authorization") + if authorization == "" { + return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrAuthHeaderEmpty)) + } + + // Check the signature version + authParts := strings.Split(authorization, " ") + if authParts[0] != "AWS4-HMAC-SHA256" { + return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrSignatureVersionNotSupported)) + } + + creds := strings.Split(strings.Split(authParts[1], "=")[1], "/") + + // Check X-Amz-Date header + date := ctx.Get("X-Amz-Date") + if date == "" { + return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMissingDateHeader)) + } + + // Parse the date and check the date validity + tdate, err := time.Parse(iso8601Format, date) + if err != nil { + return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMalformedDate)) + } + + // Calculate the hash of the request payload + hashedPayload := sha256.Sum256(ctx.Body()) + hexPayload := hex.EncodeToString(hashedPayload[:]) + + hashPayloadHeader := ctx.Get("X-Amz-Content-Sha256") + + // Compare the calculated hash with the hash provided + if hashPayloadHeader != hexPayload { + return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrContentSHA256Mismatch)) + } + + // Create a new http request instance from fasthttp request + req, err := utils.CreateHttpRequestFromCtx(ctx) + if err != nil { + return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrAccessDenied)) + } + + signer := v4.NewSigner() + + signErr := signer.SignHTTP(req.Context(), aws.Credentials{ + AccessKeyID: user.Login, + SecretAccessKey: user.Password, + }, req, hexPayload, creds[3], creds[2], tdate) + if signErr != nil { + return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrAccessDenied)) + } + + parts := strings.Split(req.Header.Get("Authorization"), " ") + calculatedSign := strings.Split(parts[3], "=")[1] + expectedSign := strings.Split(authParts[3], "=")[1] + + if expectedSign != calculatedSign { + return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrSignatureDoesNotMatch)) + } + return ctx.Next() } } diff --git a/s3api/server.go b/s3api/server.go index bbecbc4..fe8d9d6 100644 --- a/s3api/server.go +++ b/s3api/server.go @@ -17,9 +17,8 @@ type S3ApiServer struct { func New(app *fiber.App, be backend.Backend, port string, rootUser utils.RootUser) (s3ApiServer *S3ApiServer, err error) { s3ApiServer = &S3ApiServer{app, be, new(S3ApiRouter), port} - utils.GetRootUserCreds() - app.Use(middlewares.CheckUserCreds(rootUser)) + app.Use(middlewares.VerifyV4Signature(rootUser)) app.Use(logger.New()) s3ApiServer.router.Init(app, be) return diff --git a/s3api/utils/utils.go b/s3api/utils/utils.go index 0d7aebf..7e3f358 100644 --- a/s3api/utils/utils.go +++ b/s3api/utils/utils.go @@ -1,10 +1,14 @@ package utils import ( + "bytes" + "errors" "flag" + "net/http" "os" "strings" + "github.com/gofiber/fiber/v2" "github.com/valyala/fasthttp" ) @@ -42,3 +46,28 @@ func GetRootUserCreds() (rootUser RootUser) { } return } + +func CreateHttpRequestFromCtx(ctx *fiber.Ctx) (*http.Request, error) { + req := ctx.Request() + + httpReq, err := http.NewRequest(string(req.Header.Method()), req.URI().String(), bytes.NewReader(req.Body())) + if err != nil { + return nil, errors.New("error in creating an http request") + } + + // Set the request headers + req.Header.VisitAll(func(key, value []byte) { + keyStr := string(key) + if keyStr == "X-Amz-Date" || keyStr == "X-Amz-Content-Sha256" || keyStr == "Host" { + httpReq.Header.Add(keyStr, string(value)) + } + }) + + // Set the Content-Length header + httpReq.ContentLength = int64(len(req.Body())) + + // Set the Host header + httpReq.Host = string(req.Header.Host()) + + return httpReq, nil +}