feat: Completed SigV4 authentication for the root user

This commit is contained in:
jonaustin09
2023-05-31 22:20:58 +04:00
parent 510cf6ed57
commit ecd28bc2f7
6 changed files with 132 additions and 31 deletions

2
go.mod
View File

@@ -6,9 +6,9 @@ require (
github.com/aws/aws-sdk-go-v2 v1.18.0
github.com/aws/aws-sdk-go-v2/service/s3 v1.33.1
github.com/gofiber/fiber/v2 v2.45.0
github.com/valyala/fasthttp v1.47.0
github.com/google/uuid v1.3.0
github.com/pkg/xattr v0.4.9
github.com/valyala/fasthttp v1.47.0
golang.org/x/sys v0.8.0
)

View File

@@ -28,7 +28,7 @@ func New(be backend.Backend) S3ApiController {
func (c S3ApiController) ListBuckets(ctx *fiber.Ctx) error {
res, err := c.be.ListBuckets()
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
func (c S3ApiController) GetActions(ctx *fiber.Ctx) error {
@@ -49,17 +49,17 @@ func (c S3ApiController) GetActions(ctx *fiber.Ctx) error {
}
res, err := c.be.ListObjectParts(bucket, "", uploadId, partNumberMarker, maxParts)
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
if ctx.Request().URI().QueryArgs().Has("acl") {
res, err := c.be.GetObjectAcl(bucket, key)
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
if attrs := ctx.Get("X-Amz-Object-Attributes"); attrs != "" {
res, err := c.be.GetObjectAttributes(bucket, key, strings.Split(attrs, ","))
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
acceptRange := ctx.Get("Range")
@@ -85,27 +85,27 @@ func (c S3ApiController) GetActions(ctx *fiber.Ctx) error {
}
res, err := c.be.GetObject(bucket, key, acceptRange, int64(startOffset), int64(length), ctx.Response().BodyWriter())
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
func (c S3ApiController) ListActions(ctx *fiber.Ctx) error {
if ctx.Request().URI().QueryArgs().Has("acl") {
res, err := c.be.GetBucketAcl(ctx.Params("bucket"))
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
if ctx.Request().URI().QueryArgs().Has("uploads") {
res, err := c.be.ListMultipartUploads(&s3.ListMultipartUploadsInput{Bucket: aws.String(ctx.Params("bucket"))})
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
if ctx.QueryInt("list-type") == 2 {
res, err := c.be.ListObjectsV2(ctx.Params("bucket"), "", "", "", 1)
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
res, err := c.be.ListObjects(ctx.Params("bucket"), "", "", "", 1)
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error {
@@ -134,11 +134,11 @@ func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error {
GrantWriteACP: &grantWriteACP,
})
return responce[any](ctx, nil, err)
return Responce[any](ctx, nil, err)
}
err := c.be.PutBucket(bucket)
return responce[any](ctx, nil, err)
return Responce[any](ctx, nil, err)
}
func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
@@ -197,13 +197,13 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
CopySourceIfUnmodifiedSince: &copySrcUnmodifSinceDate,
})
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
if uploadId != "" {
body := io.ReadSeeker(bytes.NewReader([]byte(ctx.Body())))
res, err := c.be.UploadPart(dstBucket, dstKeyStart, uploadId, body)
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
if grants != "" || acl != "" {
@@ -221,7 +221,7 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
GrantWrite: &granWrite,
GrantWriteACP: &grantWriteACP,
})
return responce[any](ctx, nil, err)
return Responce[any](ctx, nil, err)
}
if copySource != "" {
@@ -229,7 +229,7 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
srcBucket, srcObject := copySourceSplit[0], copySourceSplit[1:]
res, err := c.be.CopyObject(srcBucket, strings.Join(srcObject, "/"), dstBucket, dstKeyStart)
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
@@ -246,12 +246,12 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
Metadata: metadata,
Body: bytes.NewReader(ctx.Request().Body()),
})
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
func (c S3ApiController) DeleteBucket(ctx *fiber.Ctx) error {
err := c.be.DeleteBucket(ctx.Params("bucket"))
return responce[any](ctx, nil, err)
return Responce[any](ctx, nil, err)
}
func (c S3ApiController) DeleteObjects(ctx *fiber.Ctx) error {
@@ -261,7 +261,7 @@ func (c S3ApiController) DeleteObjects(ctx *fiber.Ctx) error {
}
err := c.be.DeleteObjects(ctx.Params("bucket"), &s3.DeleteObjectsInput{Delete: &dObj})
return responce[any](ctx, nil, err)
return Responce[any](ctx, nil, err)
}
func (c S3ApiController) DeleteActions(ctx *fiber.Ctx) error {
@@ -281,16 +281,16 @@ func (c S3ApiController) DeleteActions(ctx *fiber.Ctx) error {
ExpectedBucketOwner: &expectedBucketOwner,
RequestPayer: types.RequestPayer(requestPayer),
})
return responce[any](ctx, nil, err)
return Responce[any](ctx, nil, err)
}
err := c.be.DeleteObject(bucket, key)
return responce[any](ctx, nil, err)
return Responce[any](ctx, nil, err)
}
func (c S3ApiController) HeadBucket(ctx *fiber.Ctx) error {
res, err := c.be.HeadBucket(ctx.Params("bucket"))
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
func (c S3ApiController) HeadObject(ctx *fiber.Ctx) error {
@@ -300,7 +300,7 @@ func (c S3ApiController) HeadObject(ctx *fiber.Ctx) error {
}
res, err := c.be.HeadObject(bucket, key, "")
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error {
@@ -313,7 +313,7 @@ func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error {
if err := xml.Unmarshal(ctx.Body(), &restoreRequest); err == nil {
err := c.be.RestoreObject(bucket, key, &restoreRequest)
return responce[any](ctx, nil, err)
return Responce[any](ctx, nil, err)
}
if uploadId != "" {
@@ -324,13 +324,13 @@ func (c S3ApiController) CreateActions(ctx *fiber.Ctx) error {
}
res, err := c.be.CompleteMultipartUpload(bucket, "", uploadId, parts)
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
res, err := c.be.CreateMultipartUpload(&s3.CreateMultipartUploadInput{Bucket: &bucket, Key: &key})
return responce(ctx, res, err)
return Responce(ctx, res, err)
}
func responce[R comparable](ctx *fiber.Ctx, resp R, err error) error {
func Responce[R comparable](ctx *fiber.Ctx, resp R, err error) error {
if err != nil {
serr, ok := err.(s3err.APIError)
if ok {

View File

@@ -701,7 +701,7 @@ func Test_responce(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := responce(tt.args.ctx, tt.args.resp, tt.args.err); (err != nil) != tt.wantErr {
if err := Responce(tt.args.ctx, tt.args.resp, tt.args.err); (err != nil) != tt.wantErr {
t.Errorf("responce() error = %v, wantErr %v", err, tt.wantErr)
}

View File

@@ -1,12 +1,85 @@
package middlewares
import (
"crypto/sha256"
"encoding/hex"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/gofiber/fiber/v2"
"github.com/versity/scoutgw/s3api/controllers"
"github.com/versity/scoutgw/s3api/utils"
"github.com/versity/scoutgw/s3err"
)
func CheckUserCreds(user utils.RootUser) fiber.Handler {
const (
iso8601Format = "20060102T150405Z"
)
func VerifyV4Signature(user utils.RootUser) fiber.Handler {
return func(ctx *fiber.Ctx) error {
authorization := ctx.Get("Authorization")
if authorization == "" {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrAuthHeaderEmpty))
}
// Check the signature version
authParts := strings.Split(authorization, " ")
if authParts[0] != "AWS4-HMAC-SHA256" {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrSignatureVersionNotSupported))
}
creds := strings.Split(strings.Split(authParts[1], "=")[1], "/")
// Check X-Amz-Date header
date := ctx.Get("X-Amz-Date")
if date == "" {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMissingDateHeader))
}
// Parse the date and check the date validity
tdate, err := time.Parse(iso8601Format, date)
if err != nil {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrMalformedDate))
}
// Calculate the hash of the request payload
hashedPayload := sha256.Sum256(ctx.Body())
hexPayload := hex.EncodeToString(hashedPayload[:])
hashPayloadHeader := ctx.Get("X-Amz-Content-Sha256")
// Compare the calculated hash with the hash provided
if hashPayloadHeader != hexPayload {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrContentSHA256Mismatch))
}
// Create a new http request instance from fasthttp request
req, err := utils.CreateHttpRequestFromCtx(ctx)
if err != nil {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrAccessDenied))
}
signer := v4.NewSigner()
signErr := signer.SignHTTP(req.Context(), aws.Credentials{
AccessKeyID: user.Login,
SecretAccessKey: user.Password,
}, req, hexPayload, creds[3], creds[2], tdate)
if signErr != nil {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrAccessDenied))
}
parts := strings.Split(req.Header.Get("Authorization"), " ")
calculatedSign := strings.Split(parts[3], "=")[1]
expectedSign := strings.Split(authParts[3], "=")[1]
if expectedSign != calculatedSign {
return controllers.Responce[any](ctx, nil, s3err.GetAPIError(s3err.ErrSignatureDoesNotMatch))
}
return ctx.Next()
}
}

View File

@@ -17,9 +17,8 @@ type S3ApiServer struct {
func New(app *fiber.App, be backend.Backend, port string, rootUser utils.RootUser) (s3ApiServer *S3ApiServer, err error) {
s3ApiServer = &S3ApiServer{app, be, new(S3ApiRouter), port}
utils.GetRootUserCreds()
app.Use(middlewares.CheckUserCreds(rootUser))
app.Use(middlewares.VerifyV4Signature(rootUser))
app.Use(logger.New())
s3ApiServer.router.Init(app, be)
return

View File

@@ -1,10 +1,14 @@
package utils
import (
"bytes"
"errors"
"flag"
"net/http"
"os"
"strings"
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"
)
@@ -42,3 +46,28 @@ func GetRootUserCreds() (rootUser RootUser) {
}
return
}
func CreateHttpRequestFromCtx(ctx *fiber.Ctx) (*http.Request, error) {
req := ctx.Request()
httpReq, err := http.NewRequest(string(req.Header.Method()), req.URI().String(), bytes.NewReader(req.Body()))
if err != nil {
return nil, errors.New("error in creating an http request")
}
// Set the request headers
req.Header.VisitAll(func(key, value []byte) {
keyStr := string(key)
if keyStr == "X-Amz-Date" || keyStr == "X-Amz-Content-Sha256" || keyStr == "Host" {
httpReq.Header.Add(keyStr, string(value))
}
})
// Set the Content-Length header
httpReq.ContentLength = int64(len(req.Body()))
// Set the Host header
httpReq.Host = string(req.Header.Host())
return httpReq, nil
}