Compare commits

...

32 Commits
v0.3 ... v0.4

Author SHA1 Message Date
Ben McClelland
38ddbc4712 Merge pull request #126 from versity/admin-api-routing
Admin api routing
2023-07-06 14:42:22 -07:00
jonaustin09
cb193c42b4 fix: Up to date with main 2023-07-06 21:21:59 +04:00
jonaustin09
fbafc6b34c feat: Changed admin api http methods, some cleanup in admin cli commands, bug fix in delete user IAM service 2023-07-06 21:21:20 +04:00
Ben McClelland
d26b8856c1 Merge pull request #125 from versity/v4-auth-payload-support
V4 payload header support
2023-07-06 10:17:01 -07:00
Ben McClelland
23f738f37f Merge pull request #124 from versity/ben/copy_obj
feat: implement posix UploadCopyPart
2023-07-06 10:16:20 -07:00
jonaustin09
a10729b3ff fix: Fixed staticcheck error 2023-07-06 19:14:01 +04:00
jonaustin09
0330685c5c feat: Added support for unsigned, streamable and trailign payload header in sigv4 authentication 2023-07-06 19:03:19 +04:00
Ben McClelland
47dea2db7c feat: implement posix UploadCopyPart 2023-07-05 19:06:19 -07:00
Ben McClelland
db484eb900 Merge pull request #123 from versity/unit-testing-cleanup
Unit testing cleanup
2023-07-03 12:41:09 -07:00
Ben McClelland
140d41de40 Merge pull request #122 from versity/fe-upload-part-copy
Upload-part-copy FE
2023-07-03 12:37:19 -07:00
jonaustin09
39803cb158 feat: Some cleanup in controller unit tests, removed backend unsupported unit tests, added test cases for admin controller functions 2023-07-03 20:35:40 +04:00
jonaustin09
9c858b0396 feat: Added UploadPartCopy action in FE 2023-07-03 18:47:32 +04:00
jonaustin09
f63545c9b7 feat: Added UploadPartCopy action in FE 2023-07-03 17:14:46 +04:00
Ben McClelland
2894d4d5f3 Merge pull request #119 from versity/unit-test-coverage
Unit testing coverage
2023-06-30 12:49:06 -07:00
jonaustin09
46097fbf70 fix: Up to date with main 2023-06-30 22:06:25 +04:00
jonaustin09
9db01362a0 feat: increased unit testing coverage in controllers, utility functions and server functions. Fixed bucket owner bug in putbucketacl. 2 more minor changes in controllers 2023-06-30 22:04:46 +04:00
Ben McClelland
fbd7bce530 Merge pull request #118 from versity/ben/copy_obj
posix: cleanup extra debug output
2023-06-29 11:58:45 -07:00
Ben McClelland
7e34078d6a posix: cleanup extra debug output 2023-06-29 11:18:00 -07:00
Jon Austin
3c69c6922a Integration test cases for HeadBucket, CopyObject, DeleteObject actions (#117)
* feat: Added integration test cases for HeadBucket, CopyObject, DeleteObjects
* feat: Added logger for debugging
2023-06-29 10:40:54 -07:00
Ben McClelland
08db927634 Merge pull request #116 from versity/ben/fix_range
fix range gets with unspecified end range
2023-06-29 09:29:06 -07:00
Ben McClelland
6d99c69953 fix range gets with unspecified end range
The aws cli will send range gets of an object with ranges like
the following:
bytes=0-8388607
bytes=8388608-16777215
bytes=16777216-25165823
bytes=25165824-

The last one with the end offset unspecified just means the rest of
the object. So this fixes that case where there is only one offset
in the range.
2023-06-28 23:09:49 -07:00
Jon Austin
4bfb3d84d3 Acl integration test (#115)
* feat: Added test an integration test case for acl actions(get, put), fixed PutBucketAcl actions bugs, fixed iam bugs on getting and creating user accounts

* fix: Fixed acl unit tests

* fix: Fixed cli path in exec command in acl integration test

* fix: fixed account creation bug
2023-06-28 19:38:35 -07:00
Jon Austin
30dbd02a83 Tag actions integrations tests (#114)
* feat: Added an integration test case for for tag actions(get, put, delete)
2023-06-26 14:25:24 -07:00
Ben McClelland
f8afeec0a0 Merge pull request #112 from versity/ben/readme
update README.md with some content clarifications
2023-06-26 12:30:35 -07:00
Jon Austin
45e3c0922d Tag actions FE (#113)
* feat: Added get-object-tagging, put-object-tagging, delete-object-tagging actions in fe
2023-06-26 12:29:56 -07:00
Ben McClelland
a3f95520a8 update README.md with some content clarifications 2023-06-26 10:18:50 -07:00
Ben McClelland
c45280b7db Merge pull request #111 from versity/ben/tests
add functional tests to github actions
2023-06-26 08:36:39 -07:00
Ben McClelland
77b0759f86 fix full flow mising TestRangeGet test 2023-06-25 11:00:54 -07:00
Ben McClelland
1da0c1ceba add coverage report for actions tests 2023-06-25 10:54:24 -07:00
Ben McClelland
1d476c6d4d add signal handler for clean shutdown 2023-06-25 10:29:14 -07:00
Ben McClelland
c4f5f958eb add functional tests to github actions 2023-06-23 18:38:19 -07:00
Jon Austin
f84cfe58e7 Bench test (#110)
* feat: test CLI command set up for client side testing, test cases are corresponded with subcommands, added full-flow test case

* fix: TLS configuration removed

* feat: Added benchmark test for client side testing in the CLI

* fix: Removed unused variables

* fix: fixed staticcheck error
2023-06-23 09:55:04 -07:00
36 changed files with 2162 additions and 2394 deletions

30
.github/workflows/functional.yml vendored Normal file
View File

@@ -0,0 +1,30 @@
name: functional tests
on: pull_request
jobs:
build:
name: RunTests
runs-on: ubuntu-latest
steps:
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: 'stable'
id: go
- name: Check out code into the Go module directory
uses: actions/checkout@v3
- name: Get Dependencies
run: |
go get -v -t -d ./...
- name: Build and Run
run: |
make testbin
./runtests.sh
- name: Coverage Report
run: |
go tool covdata percent -i=/tmp/covdata

View File

@@ -24,10 +24,10 @@ jobs:
go get -v -t -d ./...
- name: Build
run: go build -o versitygw cmd/versitygw/*.go
run: make
- name: Test
run: go test -v -timeout 30s -tags=github ./...
run: go test -coverprofile profile.txt -race -v -timeout 30s -tags=github ./...
- name: Install govulncheck
run: go install golang.org/x/vuln/cmd/govulncheck@latest

2
.gitignore vendored
View File

@@ -32,3 +32,5 @@ VERSION
/versitygw.spec
*.tar
*.tar.gz
**/rand.data
/profile.txt

View File

@@ -34,6 +34,9 @@ build: $(BIN)
$(BIN):
$(GOBUILD) $(LDFLAGS) -o $(BIN) cmd/$(BIN)/*.go
testbin:
$(GOBUILD) $(LDFLAGS) -o $(BIN) -cover -race cmd/$(BIN)/*.go
.PHONY: test
test:
$(GOTEST) ./...

View File

@@ -1,4 +1,4 @@
# The Versity Gateway: A High-Performance Open Source S3 to File Translation Tool
# The Versity Gateway:<br/>A High-Performance S3 to Storage System Translation Service
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://github.com/versity/versitygw/blob/assets/assets/logo-white.svg">
@@ -8,13 +8,11 @@
[![Apache V2 License](https://img.shields.io/badge/license-Apache%20V2-blue.svg)](https://github.com/versity/versitygw/blob/main/LICENSE)
The Versity Gateway: A High-Performance Open Source S3 to File Translation Tool
Current status: Alpha, in development not yet suited for production use
**Current status:** Alpha, in development not yet suited for production use
See project [documentation](https://github.com/versity/versitygw/wiki) on the wiki.
Versity Gateway, a simple to use tool for seamless inline translation between AWS S3 object commands and file-based storage systems. The Versity Gateway bridges the gap between S3-reliant applications and file storage systems, enabling enhanced compatibility and integration with file based systems while offering exceptional scalability.
Versity Gateway, a simple to use tool for seamless inline translation between AWS S3 object commands and storage systems. The Versity Gateway bridges the gap between S3-reliant applications and other storage systems, enabling enhanced compatibility and integration while offering exceptional scalability.
The server translates incoming S3 API requests and transforms them into equivalent operations to the backend service. By leveraging this gateway server, applications can interact with the S3-compatible API on top of already existing storage systems. This project enables leveraging existing infrastructure investments while seamlessly integrating with S3-compatible systems, offering increased flexibility and compatibility in managing data storage.

View File

@@ -42,7 +42,11 @@ type GetBucketAclOutput struct {
}
type AccessControlList struct {
Grants []types.Grant
Grants []types.Grant `xml:"Grant"`
}
type AccessControlPolicy struct {
AccessControlList AccessControlList `xml:"AccessControlList"`
Owner types.Owner
}
func ParseACL(data []byte) (ACL, error) {
@@ -80,69 +84,88 @@ func ParseACLOutput(data []byte) (GetBucketAclOutput, error) {
}, nil
}
func UpdateACL(input *s3.PutBucketAclInput, acl ACL, iam IAMService) error {
func UpdateACL(input *s3.PutBucketAclInput, acl ACL, iam IAMService) ([]byte, error) {
if input == nil {
return nil, s3err.GetAPIError(s3err.ErrInvalidRequest)
}
if acl.Owner != *input.AccessControlPolicy.Owner.ID {
return s3err.GetAPIError(s3err.ErrAccessDenied)
return nil, s3err.GetAPIError(s3err.ErrAccessDenied)
}
// if the ACL is specified, set the ACL, else replace the grantees
if input.ACL != "" {
acl.ACL = input.ACL
acl.Grantees = []Grantee{}
return nil
} else {
grantees := []Grantee{}
accs := []string{}
if input.GrantRead != nil {
fullControlList, readList, readACPList, writeList, writeACPList := []string{}, []string{}, []string{}, []string{}, []string{}
if *input.GrantFullControl != "" {
fullControlList = splitUnique(*input.GrantFullControl, ",")
fmt.Println(fullControlList)
for _, str := range fullControlList {
grantees = append(grantees, Grantee{Access: str, Permission: "FULL_CONTROL"})
}
}
if *input.GrantRead != "" {
readList = splitUnique(*input.GrantRead, ",")
for _, str := range readList {
grantees = append(grantees, Grantee{Access: str, Permission: "READ"})
}
}
if *input.GrantReadACP != "" {
readACPList = splitUnique(*input.GrantReadACP, ",")
for _, str := range readACPList {
grantees = append(grantees, Grantee{Access: str, Permission: "READ_ACP"})
}
}
if *input.GrantWrite != "" {
writeList = splitUnique(*input.GrantWrite, ",")
for _, str := range writeList {
grantees = append(grantees, Grantee{Access: str, Permission: "WRITE"})
}
}
if *input.GrantWriteACP != "" {
writeACPList = splitUnique(*input.GrantWriteACP, ",")
for _, str := range writeACPList {
grantees = append(grantees, Grantee{Access: str, Permission: "WRITE_ACP"})
}
}
accs = append(append(append(append(fullControlList, readList...), writeACPList...), readACPList...), writeList...)
} else {
cache := make(map[string]bool)
for _, grt := range input.AccessControlPolicy.Grants {
grantees = append(grantees, Grantee{Access: *grt.Grantee.ID, Permission: grt.Permission})
if _, ok := cache[*grt.Grantee.ID]; !ok {
cache[*grt.Grantee.ID] = true
accs = append(accs, *grt.Grantee.ID)
}
}
}
// Check if the specified accounts exist
accList, err := checkIfAccountsExist(accs, iam)
if err != nil {
return nil, err
}
if len(accList) > 0 {
return nil, fmt.Errorf("accounts does not exist: %s", strings.Join(accList, ", "))
}
acl.Grantees = grantees
acl.ACL = ""
}
grantees := []Grantee{}
fullControlList, readList, readACPList, writeList, writeACPList := []string{}, []string{}, []string{}, []string{}, []string{}
if *input.GrantFullControl != "" {
fullControlList = splitUnique(*input.GrantFullControl, ",")
fmt.Println(fullControlList)
for _, str := range fullControlList {
grantees = append(grantees, Grantee{Access: str, Permission: "FULL_CONTROL"})
}
}
if *input.GrantRead != "" {
readList = splitUnique(*input.GrantRead, ",")
for _, str := range readList {
grantees = append(grantees, Grantee{Access: str, Permission: "READ"})
}
}
if *input.GrantReadACP != "" {
readACPList = splitUnique(*input.GrantReadACP, ",")
for _, str := range readACPList {
grantees = append(grantees, Grantee{Access: str, Permission: "READ_ACP"})
}
}
if *input.GrantWrite != "" {
writeList = splitUnique(*input.GrantWrite, ",")
for _, str := range writeList {
grantees = append(grantees, Grantee{Access: str, Permission: "WRITE"})
}
}
if *input.GrantWriteACP != "" {
writeACPList = splitUnique(*input.GrantWriteACP, ",")
for _, str := range writeACPList {
grantees = append(grantees, Grantee{Access: str, Permission: "WRITE_ACP"})
}
}
accs := append(append(append(append(fullControlList, readList...), writeACPList...), readACPList...), writeList...)
// Check if the specified accounts exist
accList, err := checkIfAccountsExist(accs, iam)
result, err := json.Marshal(acl)
if err != nil {
return err
}
if len(accList) > 0 {
return fmt.Errorf("accounts does not exist: %s", strings.Join(accList, ", "))
return nil, err
}
acl.Grantees = grantees
acl.ACL = ""
return nil
return result, nil
}
func checkIfAccountsExist(accs []string, iam IAMService) ([]string, error) {
@@ -153,7 +176,7 @@ func checkIfAccountsExist(accs []string, iam IAMService) ([]string, error) {
if err != nil && err != ErrNoSuchUser {
return nil, fmt.Errorf("check user account: %w", err)
}
if err == nil {
if err == ErrNoSuchUser {
result = append(result, acc)
}
}

View File

@@ -25,6 +25,8 @@ type Account struct {
}
// IAMService is the interface for all IAM service implementations
//
//go:generate moq -out ../s3api/controllers/iam_moq_test.go -pkg controllers . IAMService
type IAMService interface {
CreateAccount(access string, account Account) error
GetUserAccount(access string) (Account, error)

View File

@@ -76,7 +76,7 @@ func (s *IAMServiceInternal) CreateAccount(access string, account Account) error
return nil, fmt.Errorf("failed to parse iam: %w", err)
}
} else {
conf.AccessAccounts = make(map[string]Account)
conf = IAMConfig{AccessAccounts: map[string]Account{}}
}
_, ok := conf.AccessAccounts[access]
@@ -85,10 +85,11 @@ func (s *IAMServiceInternal) CreateAccount(access string, account Account) error
}
conf.AccessAccounts[access] = account
b, err := json.Marshal(s.accts)
b, err := json.Marshal(conf)
if err != nil {
return nil, fmt.Errorf("failed to serialize iam: %w", err)
}
s.accts = conf
return b, nil
})
@@ -168,11 +169,13 @@ func (s *IAMServiceInternal) DeleteUserAccount(access string) error {
delete(conf.AccessAccounts, access)
b, err := json.Marshal(s.accts)
b, err := json.Marshal(conf)
if err != nil {
return nil, fmt.Errorf("failed to serialize iam: %w", err)
}
s.accts = conf
return b, nil
})
}

View File

@@ -24,7 +24,6 @@ import (
"github.com/versity/versitygw/s3response"
)
//go:generate moq -out backend_moq_test.go . Backend
//go:generate moq -out ../s3api/controllers/backend_moq_test.go -pkg controllers . Backend
type Backend interface {
fmt.Stringer
@@ -40,18 +39,17 @@ type Backend interface {
CreateMultipartUpload(*s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error)
CompleteMultipartUpload(bucket, object, uploadID string, parts []types.Part) (*s3.CompleteMultipartUploadOutput, error)
AbortMultipartUpload(*s3.AbortMultipartUploadInput) error
ListMultipartUploads(output *s3.ListMultipartUploadsInput) (s3response.ListMultipartUploadsResponse, error)
ListMultipartUploads(*s3.ListMultipartUploadsInput) (s3response.ListMultipartUploadsResponse, error)
ListObjectParts(bucket, object, uploadID string, partNumberMarker int, maxParts int) (s3response.ListPartsResponse, error)
CopyPart(srcBucket, srcObject, DstBucket, uploadID, rangeHeader string, part int) (*types.CopyPartResult, error)
PutObjectPart(bucket, object, uploadID string, part int, length int64, r io.Reader) (etag string, err error)
UploadPartCopy(*s3.UploadPartCopyInput) (*s3.UploadPartCopyOutput, error)
UploadPartCopy(*s3.UploadPartCopyInput) (s3response.CopyObjectResult, error)
PutObject(*s3.PutObjectInput) (string, error)
HeadObject(bucket, object string) (*s3.HeadObjectOutput, error)
GetObject(bucket, object, acceptRange string, writer io.Writer) (*s3.GetObjectOutput, error)
GetObjectAcl(bucket, object string) (*s3.GetObjectAclOutput, error)
GetObjectAttributes(bucket, object string, attributes []string) (*s3.GetObjectAttributesOutput, error)
CopyObject(srcBucket, srcObject, DstBucket, dstObject string) (*s3.CopyObjectOutput, error)
CopyObject(srcBucket, srcObject, dstBucket, dstObject string) (*s3.CopyObjectOutput, error)
ListObjects(bucket, prefix, marker, delim string, maxkeys int) (*s3.ListObjectsOutput, error)
ListObjectsV2(bucket, prefix, marker, delim string, maxkeys int) (*s3.ListObjectsV2Output, error)
DeleteObject(bucket, object string) error
@@ -87,8 +85,8 @@ func (BackendUnsupported) PutObjectAcl(*s3.PutObjectAclInput) error {
func (BackendUnsupported) RestoreObject(bucket, object string, restoreRequest *s3.RestoreObjectInput) error {
return s3err.GetAPIError(s3err.ErrNotImplemented)
}
func (BackendUnsupported) UploadPartCopy(*s3.UploadPartCopyInput) (*s3.UploadPartCopyOutput, error) {
return nil, s3err.GetAPIError(s3err.ErrNotImplemented)
func (BackendUnsupported) UploadPartCopy(*s3.UploadPartCopyInput) (s3response.CopyObjectResult, error) {
return s3response.CopyObjectResult{}, s3err.GetAPIError(s3err.ErrNotImplemented)
}
func (BackendUnsupported) GetBucketAcl(bucket string) ([]byte, error) {
return nil, s3err.GetAPIError(s3err.ErrNotImplemented)
@@ -118,9 +116,6 @@ func (BackendUnsupported) ListMultipartUploads(output *s3.ListMultipartUploadsIn
func (BackendUnsupported) ListObjectParts(bucket, object, uploadID string, partNumberMarker int, maxParts int) (s3response.ListPartsResponse, error) {
return s3response.ListPartsResponse{}, s3err.GetAPIError(s3err.ErrNotImplemented)
}
func (BackendUnsupported) CopyPart(srcBucket, srcObject, DstBucket, uploadID, rangeHeader string, part int) (*types.CopyPartResult, error) {
return nil, s3err.GetAPIError(s3err.ErrNotImplemented)
}
func (BackendUnsupported) PutObjectPart(bucket, object, uploadID string, part int, length int64, r io.Reader) (etag string, err error) {
return "", s3err.GetAPIError(s3err.ErrNotImplemented)
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,222 +0,0 @@
// Copyright 2023 Versity Software
// This file is licensed under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package backend
import (
"context"
"testing"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/versity/versitygw/s3err"
"github.com/versity/versitygw/s3response"
)
func TestBackend_ListBuckets(t *testing.T) {
type args struct {
ctx context.Context
}
type test struct {
name string
c Backend
args args
wantErr bool
}
var tests []test
tests = append(tests, test{
name: "list-Bucket",
c: &BackendMock{
ListBucketsFunc: func() (s3response.ListAllMyBucketsResult, error) {
return s3response.ListAllMyBucketsResult{
Buckets: s3response.ListAllMyBucketsList{
Bucket: []s3response.ListAllMyBucketsEntry{
{
Name: "t1",
},
},
},
}, s3err.GetAPIError(0)
},
},
args: args{
ctx: context.Background(),
},
wantErr: false,
}, test{
name: "list-Bucket-error",
c: &BackendMock{
ListBucketsFunc: func() (s3response.ListAllMyBucketsResult, error) {
return s3response.ListAllMyBucketsResult{}, s3err.GetAPIError(s3err.ErrNotImplemented)
},
},
args: args{
ctx: context.Background(),
},
wantErr: true,
})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, err := tt.c.ListBuckets(); (err.(s3err.APIError).Code != "") != tt.wantErr {
t.Errorf("Backend.ListBuckets() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestBackend_HeadBucket(t *testing.T) {
type args struct {
ctx context.Context
BucketName string
}
type test struct {
name string
c Backend
args args
wantErr bool
}
var tests []test
tests = append(tests, test{
name: "head-buckets-error",
c: &BackendMock{
HeadBucketFunc: func(bucket string) (*s3.HeadBucketOutput, error) {
return nil, s3err.GetAPIError(s3err.ErrNotImplemented)
},
},
args: args{
ctx: context.Background(),
BucketName: "b1",
},
wantErr: true,
})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, err := tt.c.HeadBucket(tt.args.BucketName); (err.(s3err.APIError).Code != "") != tt.wantErr {
t.Errorf("Backend.HeadBucket() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestBackend_GetBucketAcl(t *testing.T) {
type args struct {
ctx context.Context
bucketName string
}
type test struct {
name string
c Backend
args args
wantErr bool
}
var tests []test
tests = append(tests, test{
name: "get bucket acl error",
c: &BackendMock{
GetBucketAclFunc: func(bucket string) ([]byte, error) {
return nil, s3err.GetAPIError(s3err.ErrNotImplemented)
},
},
args: args{
ctx: context.Background(),
bucketName: "b1",
},
wantErr: true,
})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, err := tt.c.GetBucketAcl(tt.args.bucketName); (err.(s3err.APIError).Code != "") != tt.wantErr {
t.Errorf("Backend.GetBucketAcl() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestBackend_PutBucket(t *testing.T) {
type args struct {
ctx context.Context
bucketName string
bucketOwner string
}
type test struct {
name string
c Backend
args args
wantErr bool
}
var tests []test
tests = append(tests, test{
name: "put bucket ",
c: &BackendMock{
PutBucketFunc: func(bucket, owner string) error {
return s3err.GetAPIError(0)
},
},
args: args{
ctx: context.Background(),
bucketName: "b1",
bucketOwner: "owner",
},
wantErr: false,
}, test{
name: "put bucket error",
c: &BackendMock{
PutBucketFunc: func(bucket, owner string) error {
return s3err.GetAPIError(s3err.ErrNotImplemented)
},
},
args: args{
ctx: context.Background(),
bucketName: "b2",
bucketOwner: "owner",
},
wantErr: true,
})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.c.PutBucket(tt.args.bucketName, tt.args.bucketOwner); (err.(s3err.APIError).Code != "") != tt.wantErr {
t.Errorf("Backend.PutBucket() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestBackend_DeleteBucket(t *testing.T) {
type args struct {
ctx context.Context
bucketName string
}
type test struct {
name string
c Backend
args args
wantErr bool
}
var tests []test
tests = append(tests, test{
name: "Delete Bucket Error",
c: &BackendMock{
DeleteBucketFunc: func(bucket string) error {
return s3err.GetAPIError(s3err.ErrNotImplemented)
},
},
args: args{
ctx: context.Background(),
bucketName: "b1",
},
wantErr: true,
})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.c.DeleteBucket(tt.args.bucketName); (err.(s3err.APIError).Code != "") != tt.wantErr {
t.Errorf("Backend.DeleteBucket() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -17,7 +17,6 @@ package backend
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"io/fs"
"strconv"
@@ -25,6 +24,7 @@ import (
"time"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/versity/versitygw/s3err"
"github.com/versity/versitygw/s3response"
)
@@ -55,6 +55,12 @@ func GetTimePtr(t time.Time) *time.Time {
return &t
}
var (
errInvalidRange = s3err.GetAPIError(s3err.ErrInvalidRequest)
)
// ParseRange parses input range header and returns startoffset, length, and
// error. If no endoffset specified, then length is set to -1.
func ParseRange(file fs.FileInfo, acceptRange string) (int64, int64, error) {
if acceptRange == "" {
return 0, file.Size(), nil
@@ -63,29 +69,34 @@ func ParseRange(file fs.FileInfo, acceptRange string) (int64, int64, error) {
rangeKv := strings.Split(acceptRange, "=")
if len(rangeKv) < 2 {
return 0, 0, errors.New("invalid range parameter")
return 0, 0, errInvalidRange
}
bRange := strings.Split(rangeKv[1], "-")
if len(bRange) < 2 {
return 0, 0, errors.New("invalid range parameter")
if len(bRange) < 1 || len(bRange) > 2 {
return 0, 0, errInvalidRange
}
startOffset, err := strconv.ParseInt(bRange[0], 10, 64)
if err != nil {
return 0, 0, errors.New("invalid range parameter")
return 0, 0, errInvalidRange
}
endOffset, err := strconv.ParseInt(bRange[1], 10, 64)
endOffset := int64(-1)
if len(bRange) == 1 || bRange[1] == "" {
return startOffset, endOffset, nil
}
endOffset, err = strconv.ParseInt(bRange[1], 10, 64)
if err != nil {
return 0, 0, errors.New("invalid range parameter")
return 0, 0, errInvalidRange
}
if endOffset < startOffset {
return 0, 0, errors.New("invalid range parameter")
return 0, 0, errInvalidRange
}
return int64(startOffset), int64(endOffset - startOffset + 1), nil
return startOffset, endOffset - startOffset + 1, nil
}
func GetMultipartMD5(parts []types.Part) string {

View File

@@ -693,10 +693,6 @@ func (p *Posix) ListObjectParts(bucket, object, uploadID string, partNumberMarke
}, nil
}
// TODO: copy part
// func (p *Posix) CopyPart(srcBucket, srcObject, DstBucket, uploadID, rangeHeader string, part int) (*types.CopyPartResult, error) {
// }
func (p *Posix) PutObjectPart(bucket, object, uploadID string, part int, length int64, r io.Reader) (string, error) {
_, err := os.Stat(bucket)
if errors.Is(err, fs.ErrNotExist) {
@@ -708,6 +704,15 @@ func (p *Posix) PutObjectPart(bucket, object, uploadID string, part int, length
sum := sha256.Sum256([]byte(object))
objdir := filepath.Join(metaTmpMultipartDir, fmt.Sprintf("%x", sum))
_, err = os.Stat(filepath.Join(bucket, objdir, uploadID))
if errors.Is(err, fs.ErrNotExist) {
return "", s3err.GetAPIError(s3err.ErrNoSuchUpload)
}
if err != nil {
return "", fmt.Errorf("stat uploadid: %w", err)
}
partPath := filepath.Join(objdir, uploadID, fmt.Sprintf("%v", part))
f, err := openTmpFile(filepath.Join(bucket, objdir),
@@ -736,6 +741,111 @@ func (p *Posix) PutObjectPart(bucket, object, uploadID string, part int, length
return etag, nil
}
func (p *Posix) UploadPartCopy(upi *s3.UploadPartCopyInput) (s3response.CopyObjectResult, error) {
_, err := os.Stat(*upi.Bucket)
if errors.Is(err, fs.ErrNotExist) {
return s3response.CopyObjectResult{}, s3err.GetAPIError(s3err.ErrNoSuchBucket)
}
if err != nil {
return s3response.CopyObjectResult{}, fmt.Errorf("stat bucket: %w", err)
}
sum := sha256.Sum256([]byte(*upi.Key))
objdir := filepath.Join(metaTmpMultipartDir, fmt.Sprintf("%x", sum))
_, err = os.Stat(filepath.Join(*upi.Bucket, objdir, *upi.UploadId))
if errors.Is(err, fs.ErrNotExist) {
return s3response.CopyObjectResult{}, s3err.GetAPIError(s3err.ErrNoSuchUpload)
}
if err != nil {
return s3response.CopyObjectResult{}, fmt.Errorf("stat uploadid: %w", err)
}
partPath := filepath.Join(objdir, *upi.UploadId, fmt.Sprintf("%v", upi.PartNumber))
substrs := strings.SplitN(*upi.CopySource, "/", 2)
if len(substrs) != 2 {
return s3response.CopyObjectResult{}, s3err.GetAPIError(s3err.ErrInvalidCopySource)
}
srcBucket := substrs[0]
srcObject := substrs[1]
_, err = os.Stat(srcBucket)
if errors.Is(err, fs.ErrNotExist) {
return s3response.CopyObjectResult{}, s3err.GetAPIError(s3err.ErrNoSuchBucket)
}
if err != nil {
return s3response.CopyObjectResult{}, fmt.Errorf("stat bucket: %w", err)
}
objPath := filepath.Join(srcBucket, srcObject)
fi, err := os.Stat(objPath)
if errors.Is(err, fs.ErrNotExist) {
return s3response.CopyObjectResult{}, s3err.GetAPIError(s3err.ErrNoSuchKey)
}
if err != nil {
return s3response.CopyObjectResult{}, fmt.Errorf("stat object: %w", err)
}
startOffset, length, err := backend.ParseRange(fi, *upi.CopySourceRange)
if err != nil {
return s3response.CopyObjectResult{}, err
}
if length == -1 {
length = fi.Size() - startOffset + 1
}
if startOffset+length > fi.Size()+1 {
return s3response.CopyObjectResult{}, s3err.GetAPIError(s3err.ErrInvalidRequest)
}
f, err := openTmpFile(filepath.Join(*upi.Bucket, objdir),
*upi.Bucket, partPath, length)
if err != nil {
return s3response.CopyObjectResult{}, fmt.Errorf("open temp file: %w", err)
}
defer f.cleanup()
srcf, err := os.Open(objPath)
if errors.Is(err, fs.ErrNotExist) {
return s3response.CopyObjectResult{}, s3err.GetAPIError(s3err.ErrNoSuchKey)
}
if err != nil {
return s3response.CopyObjectResult{}, fmt.Errorf("open object: %w", err)
}
defer srcf.Close()
rdr := io.NewSectionReader(srcf, startOffset, length)
hash := md5.New()
tr := io.TeeReader(rdr, hash)
_, err = io.Copy(f, tr)
if err != nil {
return s3response.CopyObjectResult{}, fmt.Errorf("copy part data: %w", err)
}
err = f.link()
if err != nil {
return s3response.CopyObjectResult{}, fmt.Errorf("link object in namespace: %w", err)
}
dataSum := hash.Sum(nil)
etag := hex.EncodeToString(dataSum)
xattr.Set(filepath.Join(*upi.Bucket, partPath), etagkey, []byte(etag))
fi, err = os.Stat(filepath.Join(*upi.Bucket, partPath))
if err != nil {
return s3response.CopyObjectResult{}, fmt.Errorf("stat part path: %w", err)
}
return s3response.CopyObjectResult{
ETag: etag,
LastModified: fi.ModTime(),
}, nil
}
func (p *Posix) PutObject(po *s3.PutObjectInput) (string, error) {
_, err := os.Stat(*po.Bucket)
if errors.Is(err, fs.ErrNotExist) {
@@ -891,8 +1001,11 @@ func (p *Posix) GetObject(bucket, object, acceptRange string, writer io.Writer)
return nil, err
}
if startOffset+length > fi.Size() {
// TODO: is ErrInvalidRequest correct here?
if length == -1 {
length = fi.Size() - startOffset + 1
}
if startOffset+length > fi.Size()+1 {
return nil, s3err.GetAPIError(s3err.ErrInvalidRequest)
}
@@ -975,7 +1088,7 @@ func (p *Posix) HeadObject(bucket, object string) (*s3.HeadObjectOutput, error)
}, nil
}
func (p *Posix) CopyObject(srcBucket, srcObject, DstBucket, dstObject string) (*s3.CopyObjectOutput, error) {
func (p *Posix) CopyObject(srcBucket, srcObject, dstBucket, dstObject string) (*s3.CopyObjectOutput, error) {
_, err := os.Stat(srcBucket)
if errors.Is(err, fs.ErrNotExist) {
return nil, s3err.GetAPIError(s3err.ErrNoSuchBucket)
@@ -984,7 +1097,7 @@ func (p *Posix) CopyObject(srcBucket, srcObject, DstBucket, dstObject string) (*
return nil, fmt.Errorf("stat bucket: %w", err)
}
_, err = os.Stat(DstBucket)
_, err = os.Stat(dstBucket)
if errors.Is(err, fs.ErrNotExist) {
return nil, s3err.GetAPIError(s3err.ErrNoSuchBucket)
}
@@ -1002,12 +1115,17 @@ func (p *Posix) CopyObject(srcBucket, srcObject, DstBucket, dstObject string) (*
}
defer f.Close()
etag, err := p.PutObject(&s3.PutObjectInput{Bucket: &DstBucket, Key: &dstObject, Body: f})
fInfo, err := f.Stat()
if err != nil {
return nil, fmt.Errorf("stat object: %w", err)
}
etag, err := p.PutObject(&s3.PutObjectInput{Bucket: &dstBucket, Key: &dstObject, Body: f, ContentLength: fInfo.Size()})
if err != nil {
return nil, err
}
fi, err := os.Stat(filepath.Join(DstBucket, dstObject))
fi, err := os.Stat(filepath.Join(dstBucket, dstObject))
if err != nil {
return nil, fmt.Errorf("stat dst object: %w", err)
}
@@ -1257,7 +1375,7 @@ func (p *Posix) InitIAM() error {
_, err := os.ReadFile(iamFile)
if errors.Is(err, fs.ErrNotExist) {
b, err := json.Marshal(auth.IAMConfig{})
b, err := json.Marshal(auth.IAMConfig{AccessAccounts: map[string]auth.Account{}})
if err != nil {
return fmt.Errorf("marshal default iam: %w", err)
}

View File

@@ -440,8 +440,11 @@ func (s *ScoutFS) GetObject(bucket, object, acceptRange string, writer io.Writer
return nil, err
}
if length == -1 {
length = fi.Size() - startOffset + 1
}
if startOffset+length > fi.Size() {
// TODO: is ErrInvalidRequest correct here?
return nil, s3err.GetAPIError(s3err.ErrInvalidRequest)
}

View File

@@ -65,12 +65,6 @@ func adminCommand() *cli.Command {
Required: true,
Aliases: []string{"r"},
},
&cli.StringFlag{
Name: "region",
Usage: "s3 region string for the user",
Value: "us-east-1",
Aliases: []string{"rg"},
},
},
},
{
@@ -90,40 +84,40 @@ func adminCommand() *cli.Command {
Flags: []cli.Flag{
// TODO: create a configuration file for this
&cli.StringFlag{
Name: "adminAccess",
Name: "access",
Usage: "admin access account",
EnvVars: []string{"ADMIN_ACCESS_KEY_ID", "ADMIN_ACCESS_KEY"},
Aliases: []string{"aa"},
Aliases: []string{"a"},
Destination: &adminAccess,
},
&cli.StringFlag{
Name: "adminSecret",
Name: "secret",
Usage: "admin secret access key",
EnvVars: []string{"ADMIN_SECRET_ACCESS_KEY", "ADMIN_SECRET_KEY"},
Aliases: []string{"as"},
Aliases: []string{"s"},
Destination: &adminSecret,
},
&cli.StringFlag{
Name: "adminRegion",
Name: "region",
Usage: "s3 region string",
Value: "us-east-1",
Destination: &adminRegion,
Aliases: []string{"ar"},
Aliases: []string{"r"},
},
},
}
}
func createUser(ctx *cli.Context) error {
access, secret, role, region := ctx.String("access"), ctx.String("secret"), ctx.String("role"), ctx.String("region")
if access == "" || secret == "" || region == "" {
access, secret, role := ctx.String("access"), ctx.String("secret"), ctx.String("role")
if access == "" || secret == "" {
return fmt.Errorf("invalid input parameters for the new user")
}
if role != "admin" && role != "user" {
return fmt.Errorf("invalid input parameter for role")
}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:7070/create-user?access=%v&secret=%v&role=%v&region=%v", access, secret, role, region), nil)
req, err := http.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:7070/create-user?access=%v&secret=%v&role=%v", access, secret, role), nil)
if err != nil {
return fmt.Errorf("failed to send the request: %w", err)
}
@@ -163,7 +157,7 @@ func deleteUser(ctx *cli.Context) error {
return fmt.Errorf("invalid input parameter for the new user")
}
req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("http://localhost:7070/delete-user?access=%v", access), nil)
req, err := http.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:7070/delete-user?access=%v", access), nil)
if err != nil {
return fmt.Errorf("failed to send the request: %w", err)
}

View File

@@ -15,6 +15,7 @@
package main
import (
"context"
"crypto/tls"
"fmt"
"log"
@@ -47,6 +48,8 @@ var (
)
func main() {
setupSignalHandler()
app := initApp()
app.Commands = []*cli.Command{
@@ -56,7 +59,14 @@ func main() {
testCommand(),
}
if err := app.Run(os.Args); err != nil {
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-sigDone
fmt.Fprintf(os.Stderr, "terminating signal caught, shutting down\n")
cancel()
}()
if err := app.RunContext(ctx, os.Args); err != nil {
log.Fatal(err)
}
}
@@ -134,7 +144,7 @@ func initFlags() []cli.Flag {
}
}
func runGateway(be backend.Backend, s auth.Storer) error {
func runGateway(ctx *cli.Context, be backend.Backend, s auth.Storer) error {
app := fiber.New(fiber.Config{
AppName: "versitygw",
ServerHeader: "VERSITYGW",
@@ -180,5 +190,15 @@ func runGateway(be backend.Backend, s auth.Storer) error {
return fmt.Errorf("init gateway: %v", err)
}
return srv.Serve()
c := make(chan error, 1)
go func() { c <- srv.Serve() }()
select {
case <-ctx.Done():
be.Shutdown()
return ctx.Err()
case err := <-c:
be.Shutdown()
return err
}
}

View File

@@ -49,5 +49,5 @@ func runPosix(ctx *cli.Context) error {
return fmt.Errorf("init posix: %v", err)
}
return runGateway(be, be)
return runGateway(ctx, be, be)
}

View File

@@ -69,5 +69,5 @@ func runScoutfs(ctx *cli.Context) error {
return fmt.Errorf("init scoutfs: %v", err)
}
return runGateway(be, be)
return runGateway(ctx, be, be)
}

42
cmd/versitygw/singal.go Normal file
View File

@@ -0,0 +1,42 @@
// Copyright 2023 Versity Software
// This file is licensed under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package main
import (
"fmt"
"os"
"os/signal"
"syscall"
)
var (
sigDone = make(chan bool, 1)
)
func setupSignalHandler() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
go func() {
for sig := range sigs {
fmt.Fprintf(os.Stderr, "caught signal %v\n", sig)
switch sig {
case syscall.SIGINT, syscall.SIGTERM:
sigDone <- true
case syscall.SIGHUP:
}
}
}()
}

View File

@@ -8,9 +8,19 @@ import (
)
var (
awsID string
awsSecret string
endpoint string
awsID string
awsSecret string
endpoint string
prefix string
dstBucket string
partSize int64
objSize int64
concurrency int
files int
upload bool
download bool
pathStyle bool
checksumDisable bool
)
func testCommand() *cli.Command {
@@ -58,17 +68,20 @@ func initTestFlags() []cli.Flag {
func initTestCommands() []*cli.Command {
return []*cli.Command{
{
Name: "make-bucket",
Usage: "Test bucket creation.",
Name: "bucket-actions",
Usage: "Test bucket creation, checking the existence, deletes it.",
Description: `Calls s3 gateway create-bucket action to create a new bucket,
then calls delete-bucket action to delete the bucket.`,
calls head-bucket action to check the existence, then calls delete-bucket action to delete the bucket.`,
Action: getAction(integration.TestMakeBucket),
},
{
Name: "put-get-object",
Usage: "Test put & get object.",
Name: "object-actions",
Usage: "Test put/get/delete/copy objects.",
Description: `Creates a bucket with s3 gateway action, puts an object in it,
gets the object from the bucket, deletes both the object and bucket.`,
tries to copy into another bucket, that doesn't exist, creates the destination bucket for copying,
copies the object, get's the object to check the length and content,
get's the copied object to check the length and content, deletes all the objects inside the source bucket,
deletes both the objects and buckets.`,
Action: getAction(integration.TestPutGetObject),
},
{
@@ -147,12 +160,127 @@ func initTestCommands() []*cli.Command {
removes both the object and bucket`,
Action: getAction(integration.TestInvalidMultiParts),
},
{
Name: "object-tag-actions",
Usage: "Tests get/put/delete object tag actions.",
Description: `Creates a bucket with s3 gateway action, puts an object in it,
puts some tags for the object, gets the tags, compares the results, removes the tags,
gets the tags again, checks it to be empty, then removes both the object and bucket`,
Action: getAction(integration.TestPutGetRemoveTags),
},
{
Name: "bucket-acl-actions",
Usage: "Tests put/get bucket actions.",
Description: `Creates a bucket with s3 gateway action, puts some bucket acls
gets the acl, verifies it, then removes the bucket`,
Action: getAction(integration.TestAclActions),
},
{
Name: "full-flow",
Usage: "Tests the full flow of gateway.",
Description: `Runs all the available tests to test the full flow of the gateway.`,
Action: getAction(integration.TestFullFlow),
},
{
Name: "bench",
Usage: "Runs download/upload performance test on the gateway",
Description: `Uploads/downloads some number(specified by flags) of files with some capacity(bytes).
Logs the results to the console`,
Flags: []cli.Flag{
&cli.IntFlag{
Name: "files",
Usage: "Number of objects to read/write",
Value: 1,
Destination: &files,
},
&cli.Int64Flag{
Name: "objsize",
Usage: "Uploading object size",
Value: 0,
Destination: &objSize,
},
&cli.StringFlag{
Name: "prefix",
Usage: "Object name prefix",
Destination: &prefix,
},
&cli.BoolFlag{
Name: "upload",
Usage: "Upload data to the gateway",
Value: false,
Destination: &upload,
},
&cli.BoolFlag{
Name: "download",
Usage: "Download data to the gateway",
Value: false,
Destination: &download,
},
&cli.StringFlag{
Name: "bucket",
Usage: "Destination bucket name to read/write data",
Destination: &dstBucket,
},
&cli.Int64Flag{
Name: "partSize",
Usage: "Upload/download size per thread",
Value: 64 * 1024 * 1024,
Destination: &partSize,
},
&cli.IntFlag{
Name: "concurrency",
Usage: "Upload/download threads per object",
Value: 1,
Destination: &concurrency,
},
&cli.BoolFlag{
Name: "pathStyle",
Usage: "Use Pathstyle bucket addressing",
Value: false,
Destination: &pathStyle,
},
&cli.BoolFlag{
Name: "checksumDis",
Usage: "Disable server checksum",
Value: false,
Destination: &checksumDisable,
},
},
Action: func(ctx *cli.Context) error {
if upload && download {
return fmt.Errorf("must only specify one of upload or download")
}
if !upload && !download {
return fmt.Errorf("must specify one of upload or download")
}
if dstBucket == "" {
return fmt.Errorf("must specify bucket")
}
opts := []integration.Option{
integration.WithAccess(awsID),
integration.WithSecret(awsSecret),
integration.WithRegion(region),
integration.WithEndpoint(endpoint),
integration.WithConcurrency(concurrency),
integration.WithPartSize(partSize),
}
if debug {
opts = append(opts, integration.WithDebug())
}
if pathStyle {
opts = append(opts, integration.WithPathStyle())
}
if checksumDisable {
opts = append(opts, integration.WithDisableChecksum())
}
s3conf := integration.NewS3Conf(opts...)
return integration.TestPerformance(s3conf, upload, download, files, objSize, dstBucket, prefix)
},
},
}
}
@@ -175,6 +303,9 @@ func getAction(tf testFunc) func(*cli.Context) error {
fmt.Println()
fmt.Println("RAN:", integration.RunCount, "PASS:", integration.PassCount, "FAIL:", integration.FailCount)
if integration.FailCount > 0 {
return fmt.Errorf("test failed with %v errors", integration.FailCount)
}
return nil
}
}

View File

@@ -38,6 +38,26 @@ func (r *RReader) Sum() []byte {
return r.hash.Sum(nil)
}
type ZReader struct {
buf []byte
dataleft int
}
func NewZeroReader(totalsize, bufsize int) *ZReader {
b := make([]byte, bufsize)
return &ZReader{buf: b, dataleft: totalsize}
}
func (r *ZReader) Read(p []byte) (int, error) {
n := min(len(p), len(r.buf), r.dataleft)
r.dataleft -= n
err := error(nil)
if n == 0 {
err = io.EOF
}
return copy(p, r.buf[:n]), err
}
func min(values ...int) int {
if len(values) == 0 {
return 0
@@ -52,3 +72,13 @@ func min(values ...int) int {
return min
}
type NW struct{}
func NewNullWriter() NW {
return NW{}
}
func (NW) WriteAt(p []byte, off int64) (n int, err error) {
return len(p), nil
}

View File

@@ -2,6 +2,7 @@ package integration
import (
"context"
"io"
"log"
"net/http"
"os"
@@ -10,6 +11,8 @@ import (
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/smithy-go/middleware"
)
@@ -26,10 +29,7 @@ type S3Conf struct {
}
func NewS3Conf(opts ...Option) *S3Conf {
s := &S3Conf{
PartSize: 64 * 1024 * 1024, // 64B default chunksize
Concurrency: 1, // 1 default concurrency
}
s := &S3Conf{}
for _, opt := range opts {
opt(s)
@@ -123,3 +123,31 @@ func (c *S3Conf) Config() aws.Config {
return cfg
}
func (c *S3Conf) UploadData(r io.Reader, bucket, object string) error {
uploader := manager.NewUploader(s3.NewFromConfig(c.Config()))
uploader.PartSize = c.PartSize
uploader.Concurrency = c.Concurrency
upinfo := &s3.PutObjectInput{
Body: r,
Bucket: &bucket,
Key: &object,
}
_, err := uploader.Upload(context.Background(), upinfo)
return err
}
func (c *S3Conf) DownloadData(w io.WriterAt, bucket, object string) (int64, error) {
downloader := manager.NewDownloader(s3.NewFromConfig(c.Config()))
downloader.PartSize = c.PartSize
downloader.Concurrency = c.Concurrency
downinfo := &s3.GetObjectInput{
Bucket: &bucket,
Key: &object,
}
return downloader.Download(context.Background(), w, downinfo)
}

View File

@@ -7,11 +7,12 @@ import (
"crypto/sha256"
"fmt"
"io"
"math"
"os"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
)
@@ -20,80 +21,35 @@ var (
shortTimeout = 10 * time.Second
)
func setup(s *S3Conf, bucket string) error {
s3client := s3.NewFromConfig(s.Config())
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
_, err := s3client.CreateBucket(ctx, &s3.CreateBucketInput{
Bucket: &bucket,
})
cancel()
return err
}
func teardown(s *S3Conf, bucket string) error {
s3client := s3.NewFromConfig(s.Config())
deleteObject := func(bucket, key, versionId *string) error {
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
_, err := s3client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: bucket,
Key: key,
VersionId: versionId,
})
cancel()
if err != nil {
return fmt.Errorf("failed to delete object %v: %v", *key, err)
}
return nil
}
in := &s3.ListObjectsV2Input{Bucket: &bucket}
for {
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
out, err := s3client.ListObjectsV2(ctx, in)
cancel()
if err != nil {
return fmt.Errorf("failed to list objects: %v", err)
}
for _, item := range out.Contents {
err = deleteObject(&bucket, item.Key, nil)
if err != nil {
return err
}
}
if out.IsTruncated {
in.ContinuationToken = out.ContinuationToken
} else {
break
}
}
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
_, err := s3client.DeleteBucket(ctx, &s3.DeleteBucketInput{
Bucket: &bucket,
})
cancel()
return err
}
func TestMakeBucket(s *S3Conf) {
testname := "test make bucket"
testname := "test make/head/delete bucket"
runF(testname)
s3client := s3.NewFromConfig(s.Config())
bucket := "testbucket"
err := setup(s, bucket)
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
_, err := s3client.HeadBucket(ctx, &s3.HeadBucketInput{Bucket: &bucket})
cancel()
if err == nil {
failF("%v: expected error, instead got success response", testname)
return
}
err = setup(s, bucket)
if err != nil {
failF("%v: %v", testname, err)
return
}
passF(testname)
testname = "test delete empty bucket"
runF(testname)
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
_, err = s3client.HeadBucket(ctx, &s3.HeadBucketInput{Bucket: &bucket})
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
err = teardown(s, bucket)
if err != nil {
@@ -104,10 +60,16 @@ func TestMakeBucket(s *S3Conf) {
}
func TestPutGetObject(s *S3Conf) {
testname := "test put/get object"
testname := "test put/get/delete/copy objects"
runF(testname)
bucket := "testbucket1"
dstBucket := "testdstbucket"
obj := "myobject"
obj2 := "myobject2"
copySource := bucket + "/" + obj
s3client := s3.NewFromConfig(s.Config())
err := setup(s, bucket)
if err != nil {
@@ -122,13 +84,22 @@ func TestPutGetObject(s *S3Conf) {
csum := sha256.Sum256(data)
r := bytes.NewReader(data)
name := "myobject"
s3client := s3.NewFromConfig(s.Config())
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
_, err = s3client.PutObject(ctx, &s3.PutObjectInput{
Bucket: &bucket,
Key: &name,
Key: &obj,
Body: r,
})
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
_, err = s3client.PutObject(ctx, &s3.PutObjectInput{
Bucket: &bucket,
Key: &obj2,
Body: r,
})
cancel()
@@ -140,7 +111,7 @@ func TestPutGetObject(s *S3Conf) {
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
out, err := s3client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &bucket,
Key: &name,
Key: &obj,
})
defer cancel()
if err != nil {
@@ -166,11 +137,101 @@ func TestPutGetObject(s *S3Conf) {
return
}
// Expected error: destination bucket doesn't exist
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
_, err = s3client.CopyObject(ctx, &s3.CopyObjectInput{Bucket: &dstBucket, Key: &obj, CopySource: &copySource})
cancel()
if err == nil {
failF("%v: expect bucket not found error instead got success response", testname)
return
}
err = setup(s, dstBucket)
if err != nil {
failF("%v: %v", testname, err)
return
}
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
_, err = s3client.CopyObject(ctx, &s3.CopyObjectInput{Bucket: &dstBucket, Key: &obj, CopySource: &copySource})
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
copyObjOut, err := s3client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &dstBucket,
Key: &obj,
})
defer cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
defer copyObjOut.Body.Close()
if copyObjOut.ContentLength != int64(datalen) {
failF("%v: content length got %v expected %v", testname, copyObjOut.ContentLength, datalen)
return
}
b, err = io.ReadAll(copyObjOut.Body)
if err != nil {
failF("%v: read body %v", testname, err)
return
}
copysum := sha256.Sum256(b)
if csum != copysum {
failF("%v: copied object checksum got %x expected %x", testname, copysum, csum)
return
}
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
_, err = s3client.DeleteObjects(ctx, &s3.DeleteObjectsInput{Bucket: &bucket, Delete: &types.Delete{Objects: []types.ObjectIdentifier{{Key: &obj}, {Key: &obj2}}}})
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
objCount := 0
in := &s3.ListObjectsV2Input{Bucket: &bucket}
for {
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
out, err := s3client.ListObjectsV2(ctx, in)
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
objCount += len(out.Contents)
if out.IsTruncated {
in.ContinuationToken = out.ContinuationToken
} else {
break
}
}
if objCount != 2 {
failF("%v: expected object count %v instead got %v", testname, 2, objCount)
}
err = teardown(s, bucket)
if err != nil {
failF("%v: %v", testname, err)
return
}
err = teardown(s, dstBucket)
if err != nil {
failF("%v: %v", testname, err)
return
}
passF(testname)
}
@@ -193,7 +254,7 @@ func TestPutGetMPObject(s *S3Conf) {
dr := NewDataReader(datalen, 5*1024*1024)
WithPartSize(5 * 1024 * 1024)
s.PartSize = 5 * 1024 * 1024
err = uploadData(s, dr, bucket, name)
err = s.UploadData(dr, bucket, name)
if err != nil {
failF("%v: %v", testname, err)
return
@@ -244,35 +305,6 @@ func TestPutGetMPObject(s *S3Conf) {
passF(testname)
}
func isEqual(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i, d := range a {
if d != b[i] {
return false
}
}
return true
}
func uploadData(s *S3Conf, r io.Reader, bucket, object string) error {
uploader := manager.NewUploader(s3.NewFromConfig(s.Config()))
uploader.PartSize = s.PartSize
uploader.Concurrency = s.Concurrency
upinfo := &s3.PutObjectInput{
Body: r,
Bucket: &bucket,
Key: &object,
}
_, err := uploader.Upload(context.Background(), upinfo)
return err
}
func TestPutDirObject(s *S3Conf) {
testname := "test put directory object"
runF(testname)
@@ -448,16 +480,6 @@ func TestListObject(s *S3Conf) {
passF(testname)
}
func contains(name string, list []types.Object) bool {
for _, item := range list {
fmt.Println(*item.Key)
if strings.EqualFold(name, *item.Key) {
return true
}
}
return false
}
func TestListAbortMultiPartObject(s *S3Conf) {
testname := "list/abort multipart objects"
runF(testname)
@@ -542,15 +564,6 @@ func TestListAbortMultiPartObject(s *S3Conf) {
passF(testname)
}
func containsUID(name, id string, list []types.MultipartUpload) bool {
for _, item := range list {
if strings.EqualFold(name, *item.Key) && strings.EqualFold(id, *item.UploadId) {
return true
}
}
return false
}
func TestListMultiParts(s *S3Conf) {
testname := "list multipart parts"
runF(testname)
@@ -921,15 +934,6 @@ func TestIncompleteMultiParts(s *S3Conf) {
passF(testname)
}
func containsPart(part int32, list []types.Part) bool {
for _, item := range list {
if item.PartNumber == part {
return true
}
}
return false
}
func TestIncompletePutObject(s *S3Conf) {
testname := "test incomplete put object"
runF(testname)
@@ -1038,7 +1042,34 @@ func TestRangeGet(s *S3Conf) {
}
// bytes range is inclusive, go range for second value is not
if !isSame(b, data[100:201]) {
if !isEqual(b, data[100:201]) {
failF("%v: data mismatch of range", testname)
return
}
rangeString = "bytes=100-"
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
out, err = s3client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &bucket,
Key: &name,
Range: &rangeString,
})
defer cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
defer out.Body.Close()
b, err = io.ReadAll(out.Body)
if err != nil {
failF("%v: read body %v", testname, err)
return
}
// bytes range is inclusive, go range for second value is not
if !isEqual(b, data[100:]) {
failF("%v: data mismatch of range", testname)
return
}
@@ -1051,18 +1082,6 @@ func TestRangeGet(s *S3Conf) {
passF(testname)
}
func isSame(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i, x := range a {
if x != b[i] {
return false
}
}
return true
}
func TestInvalidMultiParts(s *S3Conf) {
testname := "invalid multipart parts"
runF(testname)
@@ -1148,6 +1167,288 @@ func TestInvalidMultiParts(s *S3Conf) {
passF(testname)
}
type prefResult struct {
elapsed time.Duration
size int64
err error
}
func TestPerformance(s *S3Conf, upload, download bool, files int, objectSize int64, bucket, prefix string) error {
var sg sync.WaitGroup
results := make([]prefResult, files)
start := time.Now()
if upload {
if objectSize == 0 {
return fmt.Errorf("must specify object size for upload")
}
if objectSize > (int64(10000) * s.PartSize) {
return fmt.Errorf("object size can not exceed 10000 * chunksize")
}
runF("performance test: upload/download objects")
for i := 0; i < files; i++ {
sg.Add(1)
go func(i int) {
var r io.Reader = NewDataReader(int(objectSize), int(s.PartSize))
start := time.Now()
err := s.UploadData(r, bucket, fmt.Sprintf("%v%v", prefix, i))
results[i].elapsed = time.Since(start)
results[i].err = err
results[i].size = objectSize
sg.Done()
}(i)
}
}
if download {
for i := 0; i < files; i++ {
sg.Add(1)
go func(i int) {
nw := NewNullWriter()
start := time.Now()
n, err := s.DownloadData(nw, bucket, fmt.Sprintf("%v%v", prefix, i))
results[i].elapsed = time.Since(start)
results[i].err = err
results[i].size = n
sg.Done()
}(i)
}
}
sg.Wait()
elapsed := time.Since(start)
var tot int64
for i, res := range results {
if res.err != nil {
failF("%v: %v\n", i, res.err)
break
}
tot += res.size
fmt.Printf("%v: %v in %v (%v MB/s)\n",
i, res.size, res.elapsed,
int(math.Ceil(float64(res.size)/res.elapsed.Seconds())/1048576))
}
fmt.Println()
passF("run perf: %v in %v (%v MB/s)\n",
tot, elapsed, int(math.Ceil(float64(tot)/elapsed.Seconds())/1048576))
return nil
}
func TestPutGetRemoveTags(s *S3Conf) {
testname := "test put/get/remove object tags"
runF(testname)
bucket := "testbucket13"
err := setup(s, bucket)
if err != nil {
failF("%v: %v", testname, err)
return
}
obj := "myobject"
s3client := s3.NewFromConfig(s.Config())
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
_, err = s3client.PutObject(ctx, &s3.PutObjectInput{
Bucket: &bucket,
Key: &obj,
})
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
key1 := "hello1"
key2 := "hello2"
val1 := "world1"
val2 := "world2"
tagging := types.Tagging{TagSet: []types.Tag{{Key: &key1, Value: &val1}, {Key: &key2, Value: &val2}}}
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
_, err = s3client.PutObjectTagging(ctx, &s3.PutObjectTaggingInput{
Bucket: &bucket,
Key: &obj,
Tagging: &tagging,
})
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
out, err := s3client.GetObjectTagging(ctx, &s3.GetObjectTaggingInput{
Key: &obj,
Bucket: &bucket,
})
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
ok := areTagsSame(tagging.TagSet, out.TagSet)
if !ok {
failF("%v: expected %v instead got %v", testname, tagging.TagSet, out.TagSet)
}
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
_, err = s3client.DeleteObjectTagging(ctx, &s3.DeleteObjectTaggingInput{
Key: &obj,
Bucket: &bucket,
})
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
out, err = s3client.GetObjectTagging(ctx, &s3.GetObjectTaggingInput{
Key: &obj,
Bucket: &bucket,
})
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
if len(out.TagSet) > 0 {
failF("%v: expected empty tag set instead got %v", testname, out.TagSet)
}
err = teardown(s, bucket)
if err != nil {
failF("%v: %v", testname, err)
return
}
passF(testname)
}
func TestAclActions(s *S3Conf) {
testname := "test put/get acl"
runF(testname)
bucket := "testbucket14"
err := setup(s, bucket)
if err != nil {
failF("%v: %v", testname, err)
return
}
s3client := s3.NewFromConfig(s.Config())
rootAccess := s.awsID
rootSecret := s.awsSecret
s.awsID = "grt1"
s.awsSecret = "grt1secret"
userS3Client := s3.NewFromConfig(s.Config())
s.awsID = rootAccess
s.awsSecret = rootSecret
grt1 := "grt1"
grants := []types.Grant{
{
Permission: "READ",
Grantee: &types.Grantee{
ID: &grt1,
Type: "CanonicalUser",
},
},
}
succUsrCrt := "The user has been created successfully"
failUsrCrt := "failed to create a user: update iam data: account already exists"
out, err := execCommand("admin", "-a", s.awsID, "-s", s.awsSecret, "create-user", "-a", grt1, "-s", "grt1secret", "-r", "user")
if err != nil {
failF("%v: %v", err)
return
}
if !strings.Contains(string(out), succUsrCrt) && !strings.Contains(string(out), failUsrCrt) {
failF("%v: failed to create user accounts", testname)
return
}
// Validation error case
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
_, err = s3client.PutBucketAcl(ctx, &s3.PutBucketAclInput{
Bucket: &bucket,
AccessControlPolicy: &types.AccessControlPolicy{
Grants: grants,
},
ACL: "private",
})
cancel()
if err == nil {
failF("%v: expected validation error", testname)
return
}
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
_, err = s3client.PutBucketAcl(ctx, &s3.PutBucketAclInput{
Bucket: &bucket,
AccessControlPolicy: &types.AccessControlPolicy{
Grants: grants,
Owner: &types.Owner{ID: &s.awsID},
},
})
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
acl, err := s3client.GetBucketAcl(ctx, &s3.GetBucketAclInput{
Bucket: &bucket,
})
cancel()
if err != nil {
failF("%v: %v", testname, err)
return
}
if *acl.Owner.ID != s.awsID {
failF("%v: expected bucket owner: %v, instead got: %v", testname, s.awsID, *acl.Owner.ID)
return
}
if !checkGrants(acl.Grants, grants) {
failF("%v: expected %v, instead got %v", testname, grants, acl.Grants)
return
}
ctx, cancel = context.WithTimeout(context.Background(), shortTimeout)
_, err = userS3Client.PutBucketAcl(ctx, &s3.PutBucketAclInput{
Bucket: &bucket,
})
cancel()
if err == nil {
failF("%v: expected acl access denied error", testname)
return
}
err = teardown(s, bucket)
if err != nil {
failF("%v: %v", testname, err)
return
}
passF(testname)
}
// Full flow test
func TestFullFlow(s *S3Conf) {
// TODO: add more test cases to get 100% coverage
@@ -1161,6 +1462,8 @@ func TestFullFlow(s *S3Conf) {
TestIncompleteMultiParts(s)
TestIncorrectMultiParts(s)
TestListAbortMultiPartObject(s)
TestListAbortMultiPartObject(s)
TestRangeGet(s)
TestInvalidMultiParts(s)
TestPutGetRemoveTags(s)
TestAclActions(s)
}

156
integration/utils.go Normal file
View File

@@ -0,0 +1,156 @@
package integration
import (
"context"
"fmt"
"os/exec"
"strings"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
)
func setup(s *S3Conf, bucket string) error {
s3client := s3.NewFromConfig(s.Config())
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
_, err := s3client.CreateBucket(ctx, &s3.CreateBucketInput{
Bucket: &bucket,
})
cancel()
return err
}
func teardown(s *S3Conf, bucket string) error {
s3client := s3.NewFromConfig(s.Config())
deleteObject := func(bucket, key, versionId *string) error {
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
_, err := s3client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: bucket,
Key: key,
VersionId: versionId,
})
cancel()
if err != nil {
return fmt.Errorf("failed to delete object %v: %v", *key, err)
}
return nil
}
in := &s3.ListObjectsV2Input{Bucket: &bucket}
for {
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
out, err := s3client.ListObjectsV2(ctx, in)
cancel()
if err != nil {
return fmt.Errorf("failed to list objects: %v", err)
}
for _, item := range out.Contents {
err = deleteObject(&bucket, item.Key, nil)
if err != nil {
return err
}
}
if out.IsTruncated {
in.ContinuationToken = out.ContinuationToken
} else {
break
}
}
ctx, cancel := context.WithTimeout(context.Background(), shortTimeout)
_, err := s3client.DeleteBucket(ctx, &s3.DeleteBucketInput{
Bucket: &bucket,
})
cancel()
return err
}
func isEqual(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i, d := range a {
if d != b[i] {
return false
}
}
return true
}
func contains(name string, list []types.Object) bool {
for _, item := range list {
fmt.Println(*item.Key)
if strings.EqualFold(name, *item.Key) {
return true
}
}
return false
}
func containsUID(name, id string, list []types.MultipartUpload) bool {
for _, item := range list {
if strings.EqualFold(name, *item.Key) && strings.EqualFold(id, *item.UploadId) {
return true
}
}
return false
}
func containsPart(part int32, list []types.Part) bool {
for _, item := range list {
if item.PartNumber == part {
return true
}
}
return false
}
func areTagsSame(tags1, tags2 []types.Tag) bool {
if len(tags1) != len(tags2) {
return false
}
for _, tag := range tags1 {
if !containsTag(tag, tags2) {
return false
}
}
return true
}
func containsTag(tag types.Tag, list []types.Tag) bool {
for _, item := range list {
if *item.Key == *tag.Key && *item.Value == *tag.Value {
return true
}
}
return false
}
func checkGrants(grts1, grts2 []types.Grant) bool {
if len(grts1) != len(grts2) {
return false
}
for i, grt := range grts1 {
if grt.Permission != grts2[i].Permission {
return false
}
if *grt.Grantee.ID != *grts2[i].Grantee.ID {
return false
}
}
return true
}
func execCommand(args ...string) ([]byte, error) {
cmd := exec.Command("./versitygw", args...)
return cmd.CombinedOutput()
}

37
runtests.sh Executable file
View File

@@ -0,0 +1,37 @@
#!/bin/bash
# make temp dirs
mkdir /tmp/gw
rm -rf /tmp/covdata
mkdir /tmp/covdata
# run server in background
GOCOVERDIR=/tmp/covdata ./versitygw -a user -s pass posix /tmp/gw &
GW_PID=$!
# wait a second for server to start up
sleep 1
# check if server is still running
if ! kill -0 $GW_PID; then
echo "server no longer running"
exit 1
fi
# run tests
if ! ./versitygw test -a user -s pass -e http://127.0.0.1:7070 full-flow; then
echo "tests failed"
kill $GW_PID
exit 1
fi
# kill off server
kill $GW_PID
exit 0
# if the above binary was built with -cover enabled (make testbin),
# then the following can be used for code coverage reports:
# go tool covdata percent -i=/tmp/covdata
# go tool covdata textfmt -i=/tmp/covdata -o profile.txt
# go tool cover -html=profile.txt

View File

@@ -27,11 +27,14 @@ type AdminController struct {
func (c AdminController) CreateUser(ctx *fiber.Ctx) error {
access, secret, role := ctx.Query("access"), ctx.Query("secret"), ctx.Query("role")
requesterRole := ctx.Locals("role")
requesterRole := ctx.Locals("role").(string)
if requesterRole != "admin" {
return fmt.Errorf("access denied: only admin users have access to this resource")
}
if role != "user" && role != "admin" {
return fmt.Errorf("invalid parameters: user role have to be one of the following: 'user', 'admin'")
}
user := auth.Account{Secret: secret, Role: role}
@@ -40,13 +43,12 @@ func (c AdminController) CreateUser(ctx *fiber.Ctx) error {
return fmt.Errorf("failed to create a user: %w", err)
}
ctx.SendString("The user has been created successfully")
return nil
return ctx.SendString("The user has been created successfully")
}
func (c AdminController) DeleteUser(ctx *fiber.Ctx) error {
access := ctx.Query("access")
requesterRole := ctx.Locals("role")
requesterRole := ctx.Locals("role").(string)
if requesterRole != "admin" {
return fmt.Errorf("access denied: only admin users have access to this resource")
}
@@ -56,6 +58,5 @@ func (c AdminController) DeleteUser(ctx *fiber.Ctx) error {
return err
}
ctx.SendString("The user has been created successfully")
return nil
return ctx.SendString("The user has been deleted successfully")
}

View File

@@ -0,0 +1,173 @@
// Copyright 2023 Versity Software
// This file is licensed under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package controllers
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v2"
"github.com/versity/versitygw/auth"
)
func TestAdminController_CreateUser(t *testing.T) {
type args struct {
req *http.Request
}
adminController := AdminController{
IAMService: &IAMServiceMock{
CreateAccountFunc: func(access string, account auth.Account) error {
return nil
},
},
}
app := fiber.New()
app.Use(func(ctx *fiber.Ctx) error {
ctx.Locals("role", "admin")
return ctx.Next()
})
app.Patch("/create-user", adminController.CreateUser)
appErr := fiber.New()
appErr.Use(func(ctx *fiber.Ctx) error {
ctx.Locals("role", "user")
return ctx.Next()
})
appErr.Patch("/create-user", adminController.CreateUser)
tests := []struct {
name string
app *fiber.App
args args
wantErr bool
statusCode int
}{
{
name: "Admin-create-user-success",
app: app,
args: args{
req: httptest.NewRequest(http.MethodPatch, "/create-user?access=test&secret=test&role=user", nil),
},
wantErr: false,
statusCode: 200,
},
{
name: "Admin-create-user-invalid-user-role",
app: app,
args: args{
req: httptest.NewRequest(http.MethodPatch, "/create-user?access=test&secret=test&role=invalid", nil),
},
wantErr: false,
statusCode: 500,
},
{
name: "Admin-create-user-invalid-requester-role",
app: appErr,
args: args{
req: httptest.NewRequest(http.MethodPatch, "/create-user?access=test&secret=test&role=admin", nil),
},
wantErr: false,
statusCode: 500,
},
}
for _, tt := range tests {
resp, err := tt.app.Test(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("AdminController.CreateUser() error = %v, wantErr %v", err, tt.wantErr)
}
if resp.StatusCode != tt.statusCode {
t.Errorf("AdminController.CreateUser() statusCode = %v, wantStatusCode = %v", resp.StatusCode, tt.statusCode)
}
}
}
func TestAdminController_DeleteUser(t *testing.T) {
type args struct {
req *http.Request
}
adminController := AdminController{
IAMService: &IAMServiceMock{
DeleteUserAccountFunc: func(access string) error {
return nil
},
},
}
app := fiber.New()
app.Use(func(ctx *fiber.Ctx) error {
ctx.Locals("role", "admin")
return ctx.Next()
})
app.Patch("/delete-user", adminController.DeleteUser)
appErr := fiber.New()
appErr.Use(func(ctx *fiber.Ctx) error {
ctx.Locals("role", "user")
return ctx.Next()
})
appErr.Patch("/delete-user", adminController.DeleteUser)
tests := []struct {
name string
app *fiber.App
args args
wantErr bool
statusCode int
}{
{
name: "Admin-delete-user-success",
app: app,
args: args{
req: httptest.NewRequest(http.MethodPatch, "/delete-user?access=test", nil),
},
wantErr: false,
statusCode: 200,
},
{
name: "Admin-delete-user-invalid-requester-role",
app: appErr,
args: args{
req: httptest.NewRequest(http.MethodPatch, "/delete-user?access=test", nil),
},
wantErr: false,
statusCode: 500,
},
}
for _, tt := range tests {
resp, err := tt.app.Test(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("AdminController.DeleteUser() error = %v, wantErr %v", err, tt.wantErr)
}
if resp.StatusCode != tt.statusCode {
t.Errorf("AdminController.DeleteUser() statusCode = %v, wantStatusCode = %v", resp.StatusCode, tt.statusCode)
}
}
}

View File

@@ -28,12 +28,9 @@ var _ backend.Backend = &BackendMock{}
// CompleteMultipartUploadFunc: func(bucket string, object string, uploadID string, parts []types.Part) (*s3.CompleteMultipartUploadOutput, error) {
// panic("mock out the CompleteMultipartUpload method")
// },
// CopyObjectFunc: func(srcBucket string, srcObject string, DstBucket string, dstObject string) (*s3.CopyObjectOutput, error) {
// CopyObjectFunc: func(srcBucket string, srcObject string, dstBucket string, dstObject string) (*s3.CopyObjectOutput, error) {
// panic("mock out the CopyObject method")
// },
// CopyPartFunc: func(srcBucket string, srcObject string, DstBucket string, uploadID string, rangeHeader string, part int) (*types.CopyPartResult, error) {
// panic("mock out the CopyPart method")
// },
// CreateMultipartUploadFunc: func(createMultipartUploadInput *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) {
// panic("mock out the CreateMultipartUpload method")
// },
@@ -70,7 +67,7 @@ var _ backend.Backend = &BackendMock{}
// ListBucketsFunc: func() (s3response.ListAllMyBucketsResult, error) {
// panic("mock out the ListBuckets method")
// },
// ListMultipartUploadsFunc: func(output *s3.ListMultipartUploadsInput) (s3response.ListMultipartUploadsResponse, error) {
// ListMultipartUploadsFunc: func(listMultipartUploadsInput *s3.ListMultipartUploadsInput) (s3response.ListMultipartUploadsResponse, error) {
// panic("mock out the ListMultipartUploads method")
// },
// ListObjectPartsFunc: func(bucket string, object string, uploadID string, partNumberMarker int, maxParts int) (s3response.ListPartsResponse, error) {
@@ -112,7 +109,7 @@ var _ backend.Backend = &BackendMock{}
// StringFunc: func() string {
// panic("mock out the String method")
// },
// UploadPartCopyFunc: func(uploadPartCopyInput *s3.UploadPartCopyInput) (*s3.UploadPartCopyOutput, error) {
// UploadPartCopyFunc: func(uploadPartCopyInput *s3.UploadPartCopyInput) (s3response.CopyObjectResult, error) {
// panic("mock out the UploadPartCopy method")
// },
// }
@@ -129,10 +126,7 @@ type BackendMock struct {
CompleteMultipartUploadFunc func(bucket string, object string, uploadID string, parts []types.Part) (*s3.CompleteMultipartUploadOutput, error)
// CopyObjectFunc mocks the CopyObject method.
CopyObjectFunc func(srcBucket string, srcObject string, DstBucket string, dstObject string) (*s3.CopyObjectOutput, error)
// CopyPartFunc mocks the CopyPart method.
CopyPartFunc func(srcBucket string, srcObject string, DstBucket string, uploadID string, rangeHeader string, part int) (*types.CopyPartResult, error)
CopyObjectFunc func(srcBucket string, srcObject string, dstBucket string, dstObject string) (*s3.CopyObjectOutput, error)
// CreateMultipartUploadFunc mocks the CreateMultipartUpload method.
CreateMultipartUploadFunc func(createMultipartUploadInput *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error)
@@ -171,7 +165,7 @@ type BackendMock struct {
ListBucketsFunc func() (s3response.ListAllMyBucketsResult, error)
// ListMultipartUploadsFunc mocks the ListMultipartUploads method.
ListMultipartUploadsFunc func(output *s3.ListMultipartUploadsInput) (s3response.ListMultipartUploadsResponse, error)
ListMultipartUploadsFunc func(listMultipartUploadsInput *s3.ListMultipartUploadsInput) (s3response.ListMultipartUploadsResponse, error)
// ListObjectPartsFunc mocks the ListObjectParts method.
ListObjectPartsFunc func(bucket string, object string, uploadID string, partNumberMarker int, maxParts int) (s3response.ListPartsResponse, error)
@@ -213,7 +207,7 @@ type BackendMock struct {
StringFunc func() string
// UploadPartCopyFunc mocks the UploadPartCopy method.
UploadPartCopyFunc func(uploadPartCopyInput *s3.UploadPartCopyInput) (*s3.UploadPartCopyOutput, error)
UploadPartCopyFunc func(uploadPartCopyInput *s3.UploadPartCopyInput) (s3response.CopyObjectResult, error)
// calls tracks calls to the methods.
calls struct {
@@ -239,26 +233,11 @@ type BackendMock struct {
SrcBucket string
// SrcObject is the srcObject argument value.
SrcObject string
// DstBucket is the DstBucket argument value.
// DstBucket is the dstBucket argument value.
DstBucket string
// DstObject is the dstObject argument value.
DstObject string
}
// CopyPart holds details about calls to the CopyPart method.
CopyPart []struct {
// SrcBucket is the srcBucket argument value.
SrcBucket string
// SrcObject is the srcObject argument value.
SrcObject string
// DstBucket is the DstBucket argument value.
DstBucket string
// UploadID is the uploadID argument value.
UploadID string
// RangeHeader is the rangeHeader argument value.
RangeHeader string
// Part is the part argument value.
Part int
}
// CreateMultipartUpload holds details about calls to the CreateMultipartUpload method.
CreateMultipartUpload []struct {
// CreateMultipartUploadInput is the createMultipartUploadInput argument value.
@@ -339,8 +318,8 @@ type BackendMock struct {
}
// ListMultipartUploads holds details about calls to the ListMultipartUploads method.
ListMultipartUploads []struct {
// Output is the output argument value.
Output *s3.ListMultipartUploadsInput
// ListMultipartUploadsInput is the listMultipartUploadsInput argument value.
ListMultipartUploadsInput *s3.ListMultipartUploadsInput
}
// ListObjectParts holds details about calls to the ListObjectParts method.
ListObjectParts []struct {
@@ -460,7 +439,6 @@ type BackendMock struct {
lockAbortMultipartUpload sync.RWMutex
lockCompleteMultipartUpload sync.RWMutex
lockCopyObject sync.RWMutex
lockCopyPart sync.RWMutex
lockCreateMultipartUpload sync.RWMutex
lockDeleteBucket sync.RWMutex
lockDeleteObject sync.RWMutex
@@ -567,7 +545,7 @@ func (mock *BackendMock) CompleteMultipartUploadCalls() []struct {
}
// CopyObject calls CopyObjectFunc.
func (mock *BackendMock) CopyObject(srcBucket string, srcObject string, DstBucket string, dstObject string) (*s3.CopyObjectOutput, error) {
func (mock *BackendMock) CopyObject(srcBucket string, srcObject string, dstBucket string, dstObject string) (*s3.CopyObjectOutput, error) {
if mock.CopyObjectFunc == nil {
panic("BackendMock.CopyObjectFunc: method is nil but Backend.CopyObject was just called")
}
@@ -579,13 +557,13 @@ func (mock *BackendMock) CopyObject(srcBucket string, srcObject string, DstBucke
}{
SrcBucket: srcBucket,
SrcObject: srcObject,
DstBucket: DstBucket,
DstBucket: dstBucket,
DstObject: dstObject,
}
mock.lockCopyObject.Lock()
mock.calls.CopyObject = append(mock.calls.CopyObject, callInfo)
mock.lockCopyObject.Unlock()
return mock.CopyObjectFunc(srcBucket, srcObject, DstBucket, dstObject)
return mock.CopyObjectFunc(srcBucket, srcObject, dstBucket, dstObject)
}
// CopyObjectCalls gets all the calls that were made to CopyObject.
@@ -610,58 +588,6 @@ func (mock *BackendMock) CopyObjectCalls() []struct {
return calls
}
// CopyPart calls CopyPartFunc.
func (mock *BackendMock) CopyPart(srcBucket string, srcObject string, DstBucket string, uploadID string, rangeHeader string, part int) (*types.CopyPartResult, error) {
if mock.CopyPartFunc == nil {
panic("BackendMock.CopyPartFunc: method is nil but Backend.CopyPart was just called")
}
callInfo := struct {
SrcBucket string
SrcObject string
DstBucket string
UploadID string
RangeHeader string
Part int
}{
SrcBucket: srcBucket,
SrcObject: srcObject,
DstBucket: DstBucket,
UploadID: uploadID,
RangeHeader: rangeHeader,
Part: part,
}
mock.lockCopyPart.Lock()
mock.calls.CopyPart = append(mock.calls.CopyPart, callInfo)
mock.lockCopyPart.Unlock()
return mock.CopyPartFunc(srcBucket, srcObject, DstBucket, uploadID, rangeHeader, part)
}
// CopyPartCalls gets all the calls that were made to CopyPart.
// Check the length with:
//
// len(mockedBackend.CopyPartCalls())
func (mock *BackendMock) CopyPartCalls() []struct {
SrcBucket string
SrcObject string
DstBucket string
UploadID string
RangeHeader string
Part int
} {
var calls []struct {
SrcBucket string
SrcObject string
DstBucket string
UploadID string
RangeHeader string
Part int
}
mock.lockCopyPart.RLock()
calls = mock.calls.CopyPart
mock.lockCopyPart.RUnlock()
return calls
}
// CreateMultipartUpload calls CreateMultipartUploadFunc.
func (mock *BackendMock) CreateMultipartUpload(createMultipartUploadInput *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) {
if mock.CreateMultipartUploadFunc == nil {
@@ -1082,19 +1008,19 @@ func (mock *BackendMock) ListBucketsCalls() []struct {
}
// ListMultipartUploads calls ListMultipartUploadsFunc.
func (mock *BackendMock) ListMultipartUploads(output *s3.ListMultipartUploadsInput) (s3response.ListMultipartUploadsResponse, error) {
func (mock *BackendMock) ListMultipartUploads(listMultipartUploadsInput *s3.ListMultipartUploadsInput) (s3response.ListMultipartUploadsResponse, error) {
if mock.ListMultipartUploadsFunc == nil {
panic("BackendMock.ListMultipartUploadsFunc: method is nil but Backend.ListMultipartUploads was just called")
}
callInfo := struct {
Output *s3.ListMultipartUploadsInput
ListMultipartUploadsInput *s3.ListMultipartUploadsInput
}{
Output: output,
ListMultipartUploadsInput: listMultipartUploadsInput,
}
mock.lockListMultipartUploads.Lock()
mock.calls.ListMultipartUploads = append(mock.calls.ListMultipartUploads, callInfo)
mock.lockListMultipartUploads.Unlock()
return mock.ListMultipartUploadsFunc(output)
return mock.ListMultipartUploadsFunc(listMultipartUploadsInput)
}
// ListMultipartUploadsCalls gets all the calls that were made to ListMultipartUploads.
@@ -1102,10 +1028,10 @@ func (mock *BackendMock) ListMultipartUploads(output *s3.ListMultipartUploadsInp
//
// len(mockedBackend.ListMultipartUploadsCalls())
func (mock *BackendMock) ListMultipartUploadsCalls() []struct {
Output *s3.ListMultipartUploadsInput
ListMultipartUploadsInput *s3.ListMultipartUploadsInput
} {
var calls []struct {
Output *s3.ListMultipartUploadsInput
ListMultipartUploadsInput *s3.ListMultipartUploadsInput
}
mock.lockListMultipartUploads.RLock()
calls = mock.calls.ListMultipartUploads
@@ -1616,7 +1542,7 @@ func (mock *BackendMock) StringCalls() []struct {
}
// UploadPartCopy calls UploadPartCopyFunc.
func (mock *BackendMock) UploadPartCopy(uploadPartCopyInput *s3.UploadPartCopyInput) (*s3.UploadPartCopyOutput, error) {
func (mock *BackendMock) UploadPartCopy(uploadPartCopyInput *s3.UploadPartCopyInput) (s3response.CopyObjectResult, error) {
if mock.UploadPartCopyFunc == nil {
panic("BackendMock.UploadPartCopyFunc: method is nil but Backend.UploadPartCopy was just called")
}

View File

@@ -33,6 +33,7 @@ import (
"github.com/versity/versitygw/backend"
"github.com/versity/versitygw/s3api/utils"
"github.com/versity/versitygw/s3err"
"github.com/versity/versitygw/s3response"
)
type S3ApiController struct {
@@ -40,8 +41,8 @@ type S3ApiController struct {
iam auth.IAMService
}
func New(be backend.Backend) S3ApiController {
return S3ApiController{be: be}
func New(be backend.Backend, iam auth.IAMService) S3ApiController {
return S3ApiController{be: be, iam: iam}
}
func (c S3ApiController) ListBuckets(ctx *fiber.Ctx) error {
@@ -77,6 +78,24 @@ func (c S3ApiController) GetActions(ctx *fiber.Ctx) error {
return SendResponse(ctx, err)
}
if ctx.Request().URI().QueryArgs().Has("tagging") {
if err := auth.VerifyACL(parsedAcl, bucket, access, "READ", isRoot); err != nil {
return SendXMLResponse(ctx, nil, err)
}
tags, err := c.be.GetTags(bucket, key)
if err != nil {
return SendXMLResponse(ctx, nil, err)
}
resp := s3response.Tagging{TagSet: s3response.TagSet{Tags: []s3response.Tag{}}}
for key, val := range tags {
resp.TagSet.Tags = append(resp.TagSet.Tags, s3response.Tag{Key: key, Value: val})
}
return SendXMLResponse(ctx, resp, nil)
}
if uploadId != "" {
if maxParts < 0 || (maxParts == 0 && ctx.Query("max-parts") != "") {
return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidMaxParts))
@@ -216,8 +235,9 @@ func (c S3ApiController) ListActions(ctx *fiber.Ctx) error {
}
func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error {
bucket, acl, grantFullControl, grantRead, grantReadACP, granWrite, grantWriteACP, access, isRoot :=
bucket, bucketOwner, acl, grantFullControl, grantRead, grantReadACP, granWrite, grantWriteACP, access, isRoot :=
ctx.Params("bucket"),
ctx.Get("X-Amz-Expected-Bucket-Owner"),
ctx.Get("X-Amz-Acl"),
ctx.Get("X-Amz-Grant-Full-Control"),
ctx.Get("X-Amz-Grant-Read"),
@@ -229,13 +249,51 @@ func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error {
grants := grantFullControl + grantRead + grantReadACP + granWrite + grantWriteACP
if grants != "" || acl != "" {
if grants != "" && acl != "" {
return errors.New("wrong api call")
}
if ctx.Request().URI().QueryArgs().Has("acl") {
var input *s3.PutBucketAclInput
if acl != "" && acl != "private" && acl != "public-read" && acl != "public-read-write" {
return errors.New("wrong api call")
if len(ctx.Body()) > 0 {
if grants+acl != "" {
return SendXMLResponse(ctx, nil, s3err.GetAPIError(s3err.ErrInvalidRequest))
}
var accessControlPolicy auth.AccessControlPolicy
err := xml.Unmarshal(ctx.Body(), &accessControlPolicy)
if err != nil {
return SendXMLResponse(ctx, nil, s3err.GetAPIError(s3err.ErrInvalidRequest))
}
input = &s3.PutBucketAclInput{
Bucket: &bucket,
ACL: "",
AccessControlPolicy: &types.AccessControlPolicy{Owner: &accessControlPolicy.Owner, Grants: accessControlPolicy.AccessControlList.Grants},
}
}
if acl != "" {
if acl != "private" && acl != "public-read" && acl != "public-read-write" {
return SendXMLResponse(ctx, nil, s3err.GetAPIError(s3err.ErrInvalidRequest))
}
if len(ctx.Body()) > 0 || grants != "" {
return SendXMLResponse(ctx, nil, s3err.GetAPIError(s3err.ErrInvalidRequest))
}
input = &s3.PutBucketAclInput{
Bucket: &bucket,
ACL: types.BucketCannedACL(acl),
AccessControlPolicy: &types.AccessControlPolicy{Owner: &types.Owner{ID: &bucketOwner}},
}
}
if grants != "" {
input = &s3.PutBucketAclInput{
Bucket: &bucket,
GrantFullControl: &grantFullControl,
GrantRead: &grantRead,
GrantReadACP: &grantReadACP,
GrantWrite: &granWrite,
GrantWriteACP: &grantWriteACP,
AccessControlPolicy: &types.AccessControlPolicy{Owner: &types.Owner{ID: &bucketOwner}},
ACL: "",
}
}
data, err := c.be.GetBucketAcl(bucket)
@@ -252,18 +310,12 @@ func (c S3ApiController) PutBucketActions(ctx *fiber.Ctx) error {
return SendResponse(ctx, err)
}
input := &s3.PutBucketAclInput{
Bucket: &bucket,
ACL: types.BucketCannedACL(acl),
GrantFullControl: &grantFullControl,
GrantRead: &grantRead,
GrantReadACP: &grantReadACP,
GrantWrite: &granWrite,
GrantWriteACP: &grantWriteACP,
AccessControlPolicy: &types.AccessControlPolicy{Owner: &types.Owner{ID: &access}},
updAcl, err := auth.UpdateACL(input, parsedAcl, c.iam)
if err != nil {
return SendResponse(ctx, err)
}
err = auth.UpdateACL(input, parsedAcl, c.iam)
err = c.be.PutBucketAcl(bucket, updAcl)
return SendResponse(ctx, err)
}
@@ -276,7 +328,6 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
keyStart := ctx.Params("key")
keyEnd := ctx.Params("*1")
uploadId := ctx.Query("uploadId")
partNumberStr := ctx.Query("partNumber")
access := ctx.Locals("access").(string)
isRoot := ctx.Locals("isRoot").(bool)
@@ -286,6 +337,7 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
copySrcIfNoneMatch := ctx.Get("X-Amz-Copy-Source-If-None-Match")
copySrcModifSince := ctx.Get("X-Amz-Copy-Source-If-Modified-Since")
copySrcUnmodifSince := ctx.Get("X-Amz-Copy-Source-If-Unmodified-Since")
copySrcRange := ctx.Get("X-Amz-Copy-Source-Range")
// Permission headers
acl := ctx.Get("X-Amz-Acl")
@@ -297,6 +349,7 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
// Other headers
contentLengthStr := ctx.Get("Content-Length")
bucketOwner := ctx.Get("X-Amz-Expected-Bucket-Owner")
grants := grantFullControl + grantRead + grantReadACP + granWrite + grantWriteACP
@@ -327,9 +380,49 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
return SendResponse(ctx, err)
}
if uploadId != "" && partNumberStr != "" {
if ctx.Request().URI().QueryArgs().Has("tagging") {
var objTagging s3response.Tagging
err := xml.Unmarshal(ctx.Body(), &objTagging)
if err != nil {
return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidRequest))
}
tags := make(map[string]string, len(objTagging.TagSet.Tags))
for _, tag := range objTagging.TagSet.Tags {
tags[tag.Key] = tag.Value
}
if err := auth.VerifyACL(parsedAcl, bucket, access, "WRITE", isRoot); err != nil {
return SendResponse(ctx, err)
}
err = c.be.SetTags(bucket, keyStart, tags)
return SendResponse(ctx, err)
}
if ctx.Request().URI().QueryArgs().Has("uploadId") && ctx.Request().URI().QueryArgs().Has("partNumber") && copySource != "" {
partNumber := ctx.QueryInt("partNumber", -1)
if partNumber < 1 {
if partNumber < 1 || partNumber > 10000 {
return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidPart))
}
resp, err := c.be.UploadPartCopy(&s3.UploadPartCopyInput{
Bucket: &bucket,
Key: &keyStart,
CopySource: &copySource,
PartNumber: int32(partNumber),
UploadId: &uploadId,
ExpectedBucketOwner: &bucketOwner,
CopySourceRange: &copySrcRange,
})
return SendXMLResponse(ctx, resp, err)
}
if ctx.Request().URI().QueryArgs().Has("uploadId") && ctx.Request().URI().QueryArgs().Has("partNumber") {
partNumber := ctx.QueryInt("partNumber", -1)
if partNumber < 1 || partNumber > 10000 {
return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidPart))
}
@@ -345,25 +438,57 @@ func (c S3ApiController) PutActions(ctx *fiber.Ctx) error {
return SendResponse(ctx, err)
}
if grants != "" || acl != "" {
if grants != "" && acl != "" {
return errors.New("wrong api call")
if ctx.Request().URI().QueryArgs().Has("acl") {
var input *s3.PutObjectAclInput
if len(ctx.Body()) > 0 {
if grants+acl != "" {
return SendXMLResponse(ctx, nil, s3err.GetAPIError(s3err.ErrInvalidRequest))
}
var accessControlPolicy auth.AccessControlPolicy
err := xml.Unmarshal(ctx.Body(), &accessControlPolicy)
if err != nil {
return SendXMLResponse(ctx, nil, s3err.GetAPIError(s3err.ErrInvalidRequest))
}
input = &s3.PutObjectAclInput{
Bucket: &bucket,
Key: &keyStart,
ACL: "",
AccessControlPolicy: &types.AccessControlPolicy{Owner: &accessControlPolicy.Owner, Grants: accessControlPolicy.AccessControlList.Grants},
}
}
if acl != "" {
if acl != "private" && acl != "public-read" && acl != "public-read-write" {
return SendXMLResponse(ctx, nil, s3err.GetAPIError(s3err.ErrInvalidRequest))
}
if len(ctx.Body()) > 0 || grants != "" {
return SendXMLResponse(ctx, nil, s3err.GetAPIError(s3err.ErrInvalidRequest))
}
input = &s3.PutObjectAclInput{
Bucket: &bucket,
Key: &keyStart,
ACL: types.ObjectCannedACL(acl),
AccessControlPolicy: &types.AccessControlPolicy{Owner: &types.Owner{ID: &bucketOwner}},
}
}
if grants != "" {
input = &s3.PutObjectAclInput{
Bucket: &bucket,
Key: &keyStart,
GrantFullControl: &grantFullControl,
GrantRead: &grantRead,
GrantReadACP: &grantReadACP,
GrantWrite: &granWrite,
GrantWriteACP: &grantWriteACP,
AccessControlPolicy: &types.AccessControlPolicy{Owner: &types.Owner{ID: &bucketOwner}},
ACL: "",
}
}
if err := auth.VerifyACL(parsedAcl, bucket, access, "WRITE_ACP", isRoot); err != nil {
return SendResponse(ctx, err)
}
err := c.be.PutObjectAcl(&s3.PutObjectAclInput{
Bucket: &bucket,
Key: &keyStart,
ACL: types.ObjectCannedACL(acl),
GrantFullControl: &grantFullControl,
GrantRead: &grantRead,
GrantReadACP: &grantReadACP,
GrantWrite: &granWrite,
GrantWriteACP: &grantWriteACP,
})
err = c.be.PutObjectAcl(input)
return SendResponse(ctx, err)
}
@@ -425,7 +550,7 @@ func (c S3ApiController) DeleteObjects(ctx *fiber.Ctx) error {
var dObj types.Delete
if err := xml.Unmarshal(ctx.Body(), &dObj); err != nil {
return errors.New("wrong api call")
return SendResponse(ctx, s3err.GetAPIError(s3err.ErrInvalidRequest))
}
data, err := c.be.GetBucketAcl(bucket)
@@ -468,6 +593,15 @@ func (c S3ApiController) DeleteActions(ctx *fiber.Ctx) error {
return SendResponse(ctx, err)
}
if ctx.Request().URI().QueryArgs().Has("tagging") {
if err := auth.VerifyACL(parsedAcl, bucket, access, "WRITE", isRoot); err != nil {
return SendResponse(ctx, err)
}
err = c.be.RemoveTags(bucket, key)
return SendResponse(ctx, err)
}
if uploadId != "" {
expectedBucketOwner, requestPayer := ctx.Get("X-Amz-Expected-Bucket-Owner"), ctx.Get("X-Amz-Request-Payer")
@@ -674,6 +808,7 @@ func SendResponse(ctx *fiber.Ctx, err error) error {
func SendXMLResponse(ctx *fiber.Ctx, resp any, err error) error {
if err != nil {
fmt.Println(err)
serr, ok := err.(s3err.APIError)
if ok {
ctx.Status(serr.HTTPStatusCode)

View File

@@ -16,6 +16,7 @@ package controllers
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -48,7 +49,8 @@ func init() {
func TestNew(t *testing.T) {
type args struct {
be backend.Backend
be backend.Backend
iam auth.IAMService
}
be := backend.BackendUnsupported{}
@@ -61,16 +63,18 @@ func TestNew(t *testing.T) {
{
name: "Initialize S3 api controller",
args: args{
be: be,
be: be,
iam: &auth.IAMServiceInternal{},
},
want: S3ApiController{
be: be,
be: be,
iam: &auth.IAMServiceInternal{},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := New(tt.args.be); !reflect.DeepEqual(got, tt.want) {
if got := New(tt.args.be, tt.args.iam); !reflect.DeepEqual(got, tt.want) {
t.Errorf("New() = %v, want %v", got, tt.want)
}
})
@@ -123,6 +127,16 @@ func TestS3ApiController_ListBuckets(t *testing.T) {
})
appErr.Get("/", s3ApiControllerErr.ListBuckets)
//Admin error case
admErr := fiber.New()
admErr.Use(func(ctx *fiber.Ctx) error {
ctx.Locals("access", "valid access")
ctx.Locals("isRoot", false)
ctx.Locals("isDebug", false)
return ctx.Next()
})
admErr.Get("/", s3ApiController.ListBuckets)
tests := []struct {
name string
args args
@@ -148,6 +162,15 @@ func TestS3ApiController_ListBuckets(t *testing.T) {
wantErr: false,
statusCode: 200,
},
{
name: "admin-error-case",
args: args{
req: httptest.NewRequest(http.MethodGet, "/", nil),
},
app: admErr,
wantErr: false,
statusCode: 500,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -169,6 +192,11 @@ func TestS3ApiController_GetActions(t *testing.T) {
req *http.Request
}
getPtr := func(val string) *string {
return &val
}
now := time.Now()
app := fiber.New()
s3ApiController := S3ApiController{
be: &BackendMock{
@@ -185,7 +213,18 @@ func TestS3ApiController_GetActions(t *testing.T) {
return &s3.GetObjectAttributesOutput{}, nil
},
GetObjectFunc: func(bucket, object, acceptRange string, writer io.Writer) (*s3.GetObjectOutput, error) {
return &s3.GetObjectOutput{Metadata: nil}, nil
return &s3.GetObjectOutput{
Metadata: map[string]string{"hello": "world"},
ContentType: getPtr("application/xml"),
ContentEncoding: getPtr("gzip"),
ETag: getPtr("98sda7f97sa9df798sd79f8as9df"),
ContentLength: 1000,
LastModified: &now,
StorageClass: "storage class",
}, nil
},
GetTagsFunc: func(bucket, object string) (map[string]string, error) {
return map[string]string{"hello": "world"}, nil
},
},
}
@@ -197,17 +236,9 @@ func TestS3ApiController_GetActions(t *testing.T) {
})
app.Get("/:bucket/:key/*", s3ApiController.GetActions)
// GetObjectACL
getObjectACLReq := httptest.NewRequest(http.MethodGet, "/my-bucket/key", nil)
getObjectACLReq.Header.Set("X-Amz-Object-Attributes", "attrs")
// GetObject error case
getObjectReq := httptest.NewRequest(http.MethodGet, "/my-bucket/key", nil)
getObjectReq.Header.Set("Range", "hello=")
// GetObject success case
getObjectSuccessReq := httptest.NewRequest(http.MethodGet, "/my-bucket/key", nil)
getObjectReq.Header.Set("Range", "range=13-invalid")
// GetObjectAttributes success case
getObjAttrs := httptest.NewRequest(http.MethodGet, "/my-bucket/key", nil)
getObjAttrs.Header.Set("X-Amz-Object-Attributes", "hello")
tests := []struct {
name string
@@ -217,19 +248,46 @@ func TestS3ApiController_GetActions(t *testing.T) {
statusCode int
}{
{
name: "Get-actions-invalid-max-parts",
name: "Get-actions-get-tags-success",
app: app,
args: args{
req: httptest.NewRequest(http.MethodGet, "/my-bucket/key?uploadId=hello&max-parts=InvalidMaxParts", nil),
req: httptest.NewRequest(http.MethodGet, "/my-bucket/key/key.json?tagging", nil),
},
wantErr: false,
statusCode: 200,
},
{
name: "Get-actions-invalid-max-parts-string",
app: app,
args: args{
req: httptest.NewRequest(http.MethodGet, "/my-bucket/key?uploadId=hello&max-parts=invalid", nil),
},
wantErr: false,
statusCode: 400,
},
{
name: "Get-actions-invalid-part-number-marker",
name: "Get-actions-invalid-max-parts-negative",
app: app,
args: args{
req: httptest.NewRequest(http.MethodGet, "/my-bucket/key?uploadId=hello&max-parts=200&part-number-marker=InvalidPartNumber", nil),
req: httptest.NewRequest(http.MethodGet, "/my-bucket/key?uploadId=hello&max-parts=-8", nil),
},
wantErr: false,
statusCode: 400,
},
{
name: "Get-actions-invalid-part-number-marker-string",
app: app,
args: args{
req: httptest.NewRequest(http.MethodGet, "/my-bucket/key?uploadId=hello&max-parts=200&part-number-marker=invalid", nil),
},
wantErr: false,
statusCode: 400,
},
{
name: "Get-actions-invalid-part-number-marker-negative",
app: app,
args: args{
req: httptest.NewRequest(http.MethodGet, "/my-bucket/key?uploadId=hello&max-parts=200&part-number-marker=-8", nil),
},
wantErr: false,
statusCode: 400,
@@ -247,7 +305,16 @@ func TestS3ApiController_GetActions(t *testing.T) {
name: "Get-actions-get-object-acl-success",
app: app,
args: args{
req: getObjectACLReq,
req: httptest.NewRequest(http.MethodGet, "/my-bucket/key?acl", nil),
},
wantErr: false,
statusCode: 200,
},
{
name: "Get-actions-get-object-attributes-success",
app: app,
args: args{
req: getObjAttrs,
},
wantErr: false,
statusCode: 200,
@@ -256,7 +323,7 @@ func TestS3ApiController_GetActions(t *testing.T) {
name: "Get-actions-get-object-success",
app: app,
args: args{
req: getObjectSuccessReq,
req: httptest.NewRequest(http.MethodGet, "/my-bucket/key", nil),
},
wantErr: false,
statusCode: 200,
@@ -387,11 +454,11 @@ func TestS3ApiController_ListActions(t *testing.T) {
resp, err := tt.app.Test(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("S3ApiController.GetActions() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("S3ApiController.ListActions() error = %v, wantErr %v", err, tt.wantErr)
}
if resp.StatusCode != tt.statusCode {
t.Errorf("S3ApiController.GetActions() statusCode = %v, wantStatusCode = %v", resp.StatusCode, tt.statusCode)
t.Errorf("S3ApiController.ListActions() statusCode = %v, wantStatusCode = %v", resp.StatusCode, tt.statusCode)
}
})
}
@@ -403,6 +470,31 @@ func TestS3ApiController_PutBucketActions(t *testing.T) {
}
app := fiber.New()
// Mock valid acl
acl := auth.ACL{Owner: "valid access", ACL: "public-read-write"}
acldata, err := json.Marshal(acl)
if err != nil {
t.Errorf("Failed to parse the params: %v", err.Error())
return
}
body := `
<AccessControlPolicy xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<AccessControlList>
<Grant>
<Grantee>
<ID>hell</ID>
</Grantee>
<Permission>string</Permission>
</Grant>
</AccessControlList>
<Owner>
<ID>hello</ID>
</Owner>
</AccessControlPolicy>
`
s3ApiController := S3ApiController{
be: &BackendMock{
GetBucketAclFunc: func(bucket string) ([]byte, error) {
@@ -425,14 +517,28 @@ func TestS3ApiController_PutBucketActions(t *testing.T) {
})
app.Put("/:bucket", s3ApiController.PutBucketActions)
// Error case
errorReq := httptest.NewRequest(http.MethodPut, "/my-bucket", nil)
errorReq.Header.Set("X-Amz-Acl", "restricted")
errorReq.Header.Set("X-Amz-Grant-Read", "read")
// invalid acl case
invAclReq := httptest.NewRequest(http.MethodPut, "/my-bucket?acl", nil)
invAclReq.Header.Set("X-Amz-Acl", "invalid")
// PutBucketAcl success
aclReq := httptest.NewRequest(http.MethodPut, "/my-bucket", nil)
errorReq.Header.Set("X-Amz-Acl", "full")
// invalid acl case 2
errAclReq := httptest.NewRequest(http.MethodPut, "/my-bucket?acl", nil)
errAclReq.Header.Set("X-Amz-Acl", "private")
errAclReq.Header.Set("X-Amz-Grant-Read", "hello")
// PutBucketAcl incorrect bucket owner case
incorrectBucketOwner := httptest.NewRequest(http.MethodPut, "/my-bucket?acl", nil)
incorrectBucketOwner.Header.Set("X-Amz-Acl", "private")
incorrectBucketOwner.Header.Set("X-Amz-Expected-Bucket-Owner", "invalid access")
// PutBucketAcl acl success
aclSuccReq := httptest.NewRequest(http.MethodPut, "/my-bucket?acl", nil)
aclSuccReq.Header.Set("X-Amz-Acl", "private")
aclSuccReq.Header.Set("X-Amz-Expected-Bucket-Owner", "valid access")
// Invalid acl body case
errAclBodyReq := httptest.NewRequest(http.MethodPut, "/my-bucket?acl", strings.NewReader(body))
errAclBodyReq.Header.Set("X-Amz-Grant-Read", "hello")
tests := []struct {
name string
@@ -442,19 +548,46 @@ func TestS3ApiController_PutBucketActions(t *testing.T) {
statusCode int
}{
{
name: "Put-bucket-acl-error",
name: "Put-bucket-acl-invalid-acl",
app: app,
args: args{
req: errorReq,
req: invAclReq,
},
wantErr: false,
statusCode: 500,
statusCode: 400,
},
{
name: "Put-object-acl-success",
name: "Put-bucket-acl-incorrect-acl",
app: app,
args: args{
req: aclReq,
req: errAclReq,
},
wantErr: false,
statusCode: 400,
},
{
name: "Put-bucket-acl-incorrect-acl-body",
app: app,
args: args{
req: errAclBodyReq,
},
wantErr: false,
statusCode: 400,
},
{
name: "Put-bucket-acl-incorrect-bucket-owner",
app: app,
args: args{
req: incorrectBucketOwner,
},
wantErr: false,
statusCode: 403,
},
{
name: "Put-bucket-acl-success",
app: app,
args: args{
req: aclSuccReq,
},
wantErr: false,
statusCode: 200,
@@ -473,11 +606,11 @@ func TestS3ApiController_PutBucketActions(t *testing.T) {
resp, err := tt.app.Test(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("S3ApiController.GetActions() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("S3ApiController.PutBucketActions() error = %v, wantErr %v", err, tt.wantErr)
}
if resp.StatusCode != tt.statusCode {
t.Errorf("S3ApiController.GetActions() statusCode = %v, wantStatusCode = %v", resp.StatusCode, tt.statusCode)
t.Errorf("S3ApiController.PutBucketActions() statusCode = %v, wantStatusCode = %v", resp.StatusCode, tt.statusCode)
}
}
}
@@ -487,15 +620,38 @@ func TestS3ApiController_PutActions(t *testing.T) {
req *http.Request
}
body := `
<AccessControlPolicy xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<AccessControlList>
<Grant>
<Grantee>
<ID>hell</ID>
</Grantee>
<Permission>string</Permission>
</Grant>
</AccessControlList>
<Owner>
<ID>hello</ID>
</Owner>
</AccessControlPolicy>
`
tagBody := `
<Tagging xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<TagSet>
<Tag>
<Key>string</Key>
<Value>string</Value>
</Tag>
</TagSet>
</Tagging>
`
app := fiber.New()
s3ApiController := S3ApiController{
be: &BackendMock{
GetBucketAclFunc: func(bucket string) ([]byte, error) {
return acldata, nil
},
UploadPartCopyFunc: func(*s3.UploadPartCopyInput) (*s3.UploadPartCopyOutput, error) {
return &s3.UploadPartCopyOutput{}, nil
},
PutObjectAclFunc: func(*s3.PutObjectAclInput) error {
return nil
},
@@ -505,6 +661,15 @@ func TestS3ApiController_PutActions(t *testing.T) {
PutObjectFunc: func(*s3.PutObjectInput) (string, error) {
return "Hey", nil
},
PutObjectPartFunc: func(bucket, object, uploadID string, part int, length int64, r io.Reader) (string, error) {
return "hello", nil
},
SetTagsFunc: func(bucket, object string, tags map[string]string) error {
return nil
},
UploadPartCopyFunc: func(uploadPartCopyInput *s3.UploadPartCopyInput) (s3response.CopyObjectResult, error) {
return s3response.CopyObjectResult{}, nil
},
},
}
app.Use(func(ctx *fiber.Ctx) error {
@@ -515,19 +680,39 @@ func TestS3ApiController_PutActions(t *testing.T) {
})
app.Put("/:bucket/:key/*", s3ApiController.PutActions)
//PutObjectAcl error
aclReqErr := httptest.NewRequest(http.MethodPut, "/my-bucket/my-key", nil)
aclReqErr.Header.Set("X-Amz-Acl", "acl")
aclReqErr.Header.Set("X-Amz-Grant-Write", "write")
// UploadPartCopy success
uploadPartCpyReq := httptest.NewRequest(http.MethodPut, "/my-bucket/my-key?uploadId=12asd32&partNumber=3", nil)
uploadPartCpyReq.Header.Set("X-Amz-Copy-Source", "srcBucket/srcObject")
//PutObjectAcl success
aclReq := httptest.NewRequest(http.MethodPut, "/my-bucket/my-key", nil)
aclReq.Header.Set("X-Amz-Acl", "acl")
// UploadPartCopy error case
uploadPartCpyErrReq := httptest.NewRequest(http.MethodPut, "/my-bucket/my-key?uploadId=12asd32&partNumber=invalid", nil)
uploadPartCpyErrReq.Header.Set("X-Amz-Copy-Source", "srcBucket/srcObject")
//CopyObject success
// CopyObject success
cpySrcReq := httptest.NewRequest(http.MethodPut, "/my-bucket/my-key", nil)
cpySrcReq.Header.Set("X-Amz-Copy-Source", "srcBucket/srcObject")
// PutObjectAcl success
aclReq := httptest.NewRequest(http.MethodPut, "/my-bucket/my-key", nil)
aclReq.Header.Set("X-Amz-Acl", "private")
// PutObjectAcl success grt case
aclGrtReq := httptest.NewRequest(http.MethodPut, "/my-bucket/my-key", nil)
aclGrtReq.Header.Set("X-Amz-Grant-Read", "private")
// invalid acl case 1
invAclReq := httptest.NewRequest(http.MethodPut, "/my-bucket/my-key?acl", nil)
invAclReq.Header.Set("X-Amz-Acl", "invalid")
// invalid acl case 2
errAclReq := httptest.NewRequest(http.MethodPut, "/my-bucket/my-key?acl", nil)
errAclReq.Header.Set("X-Amz-Acl", "private")
errAclReq.Header.Set("X-Amz-Grant-Read", "hello")
// invalid body & grt case
invAclBodyGrtReq := httptest.NewRequest(http.MethodPut, "/my-bucket/my-key?acl", strings.NewReader(body))
invAclBodyGrtReq.Header.Set("X-Amz-Grant-Read", "hello")
tests := []struct {
name string
app *fiber.App
@@ -536,7 +721,7 @@ func TestS3ApiController_PutActions(t *testing.T) {
statusCode int
}{
{
name: "Upload-put-part-error-case",
name: "Put-object-part-error-case",
app: app,
args: args{
req: httptest.NewRequest(http.MethodPut, "/my-bucket/my-key?uploadId=abc&partNumber=invalid", nil),
@@ -545,40 +730,49 @@ func TestS3ApiController_PutActions(t *testing.T) {
statusCode: 400,
},
{
name: "Upload-copy-part-success",
name: "Put-object-part-success",
app: app,
args: args{
req: httptest.NewRequest(http.MethodPut, "/my-bucket/my-key?partNumber=3", nil),
req: httptest.NewRequest(http.MethodPut, "/my-bucket/my-key?uploadId=4&partNumber=3", nil),
},
wantErr: false,
statusCode: 200,
},
{
name: "Upload-part-success",
name: "Set-tags-success",
app: app,
args: args{
req: httptest.NewRequest(http.MethodPut, "/my-bucket/my-key?uploadId=234234", nil),
req: httptest.NewRequest(http.MethodPut, "/my-bucket/my-key?tagging", strings.NewReader(tagBody)),
},
wantErr: false,
statusCode: 200,
},
{
name: "Put-object-acl-error",
name: "Put-object-acl-invalid-acl",
app: app,
args: args{
req: aclReqErr,
req: invAclReq,
},
wantErr: false,
statusCode: 500,
statusCode: 400,
},
{
name: "Put-object-acl-error",
name: "Put-object-acl-incorrect-acl",
app: app,
args: args{
req: aclReqErr,
req: errAclReq,
},
wantErr: false,
statusCode: 500,
statusCode: 400,
},
{
name: "Put-object-acl-incorrect-acl-body-case",
app: app,
args: args{
req: invAclBodyGrtReq,
},
wantErr: false,
statusCode: 400,
},
{
name: "Put-object-acl-success",
@@ -589,6 +783,42 @@ func TestS3ApiController_PutActions(t *testing.T) {
wantErr: false,
statusCode: 200,
},
{
name: "Put-object-acl-success-body-case",
app: app,
args: args{
req: httptest.NewRequest(http.MethodPut, "/my-bucket/my-key?acl", strings.NewReader(body)),
},
wantErr: false,
statusCode: 200,
},
{
name: "Put-object-acl-success-grt-case",
app: app,
args: args{
req: aclGrtReq,
},
wantErr: false,
statusCode: 200,
},
{
name: "Upload-part-copy-invalid-part-number",
app: app,
args: args{
req: uploadPartCpyErrReq,
},
wantErr: false,
statusCode: 400,
},
{
name: "Upload-part-copy-success",
app: app,
args: args{
req: uploadPartCpyReq,
},
wantErr: false,
statusCode: 200,
},
{
name: "Copy-object-success",
app: app,
@@ -602,7 +832,7 @@ func TestS3ApiController_PutActions(t *testing.T) {
name: "Put-object-success",
app: app,
args: args{
req: httptest.NewRequest(http.MethodPut, "/my-bucket/my-key", nil),
req: httptest.NewRequest(http.MethodPut, "/my-bucket/my-key/key2", nil),
},
wantErr: false,
statusCode: 200,
@@ -649,28 +879,6 @@ func TestS3ApiController_DeleteBucket(t *testing.T) {
app.Delete("/:bucket", s3ApiController.DeleteBucket)
// error case
appErr := fiber.New()
s3ApiControllerErr := S3ApiController{
be: &BackendMock{
GetBucketAclFunc: func(bucket string) ([]byte, error) {
return acldata, nil
},
DeleteBucketFunc: func(bucket string) error {
return s3err.GetAPIError(48)
},
},
}
appErr.Use(func(ctx *fiber.Ctx) error {
ctx.Locals("access", "valid access")
ctx.Locals("isRoot", true)
ctx.Locals("isDebug", false)
return ctx.Next()
})
appErr.Delete("/:bucket", s3ApiControllerErr.DeleteBucket)
tests := []struct {
name string
app *fiber.App
@@ -687,15 +895,6 @@ func TestS3ApiController_DeleteBucket(t *testing.T) {
wantErr: false,
statusCode: 200,
},
{
name: "Delete-bucket-error",
app: appErr,
args: args{
req: httptest.NewRequest(http.MethodDelete, "/my-bucket", nil),
},
wantErr: false,
statusCode: 400,
},
}
for _, tt := range tests {
resp, err := tt.app.Test(tt.args.req)
@@ -764,7 +963,7 @@ func TestS3ApiController_DeleteObjects(t *testing.T) {
req: httptest.NewRequest(http.MethodPost, "/my-bucket", nil),
},
wantErr: false,
statusCode: 500,
statusCode: 400,
},
}
for _, tt := range tests {
@@ -797,6 +996,9 @@ func TestS3ApiController_DeleteActions(t *testing.T) {
AbortMultipartUploadFunc: func(*s3.AbortMultipartUploadInput) error {
return nil
},
RemoveTagsFunc: func(bucket, object string) error {
return nil
},
},
}
@@ -808,7 +1010,7 @@ func TestS3ApiController_DeleteActions(t *testing.T) {
})
app.Delete("/:bucket/:key/*", s3ApiController.DeleteActions)
//Error case
// Error case
appErr := fiber.New()
s3ApiControllerErr := S3ApiController{be: &BackendMock{
@@ -826,7 +1028,7 @@ func TestS3ApiController_DeleteActions(t *testing.T) {
ctx.Locals("isDebug", false)
return ctx.Next()
})
appErr.Delete("/:bucket", s3ApiControllerErr.DeleteBucket)
appErr.Delete("/:bucket/:key/*", s3ApiControllerErr.DeleteActions)
tests := []struct {
name string
@@ -844,6 +1046,15 @@ func TestS3ApiController_DeleteActions(t *testing.T) {
wantErr: false,
statusCode: 200,
},
{
name: "Remove-object-tagging-success",
app: app,
args: args{
req: httptest.NewRequest(http.MethodDelete, "/my-bucket/my-key/key2?tagging", nil),
},
wantErr: false,
statusCode: 200,
},
{
name: "Delete-object-success",
app: app,
@@ -1271,6 +1482,16 @@ func Test_response(t *testing.T) {
wantErr: false,
statusCode: 500,
},
{
name: "Internal-server-error-not-api",
args: args{
ctx: &ctx,
resp: nil,
err: fmt.Errorf("custom error"),
},
wantErr: false,
statusCode: 500,
},
{
name: "Error-not-implemented",
args: args{

View File

@@ -0,0 +1,169 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package controllers
import (
"github.com/versity/versitygw/auth"
"sync"
)
// Ensure, that IAMServiceMock does implement auth.IAMService.
// If this is not the case, regenerate this file with moq.
var _ auth.IAMService = &IAMServiceMock{}
// IAMServiceMock is a mock implementation of auth.IAMService.
//
// func TestSomethingThatUsesIAMService(t *testing.T) {
//
// // make and configure a mocked auth.IAMService
// mockedIAMService := &IAMServiceMock{
// CreateAccountFunc: func(access string, account auth.Account) error {
// panic("mock out the CreateAccount method")
// },
// DeleteUserAccountFunc: func(access string) error {
// panic("mock out the DeleteUserAccount method")
// },
// GetUserAccountFunc: func(access string) (auth.Account, error) {
// panic("mock out the GetUserAccount method")
// },
// }
//
// // use mockedIAMService in code that requires auth.IAMService
// // and then make assertions.
//
// }
type IAMServiceMock struct {
// CreateAccountFunc mocks the CreateAccount method.
CreateAccountFunc func(access string, account auth.Account) error
// DeleteUserAccountFunc mocks the DeleteUserAccount method.
DeleteUserAccountFunc func(access string) error
// GetUserAccountFunc mocks the GetUserAccount method.
GetUserAccountFunc func(access string) (auth.Account, error)
// calls tracks calls to the methods.
calls struct {
// CreateAccount holds details about calls to the CreateAccount method.
CreateAccount []struct {
// Access is the access argument value.
Access string
// Account is the account argument value.
Account auth.Account
}
// DeleteUserAccount holds details about calls to the DeleteUserAccount method.
DeleteUserAccount []struct {
// Access is the access argument value.
Access string
}
// GetUserAccount holds details about calls to the GetUserAccount method.
GetUserAccount []struct {
// Access is the access argument value.
Access string
}
}
lockCreateAccount sync.RWMutex
lockDeleteUserAccount sync.RWMutex
lockGetUserAccount sync.RWMutex
}
// CreateAccount calls CreateAccountFunc.
func (mock *IAMServiceMock) CreateAccount(access string, account auth.Account) error {
if mock.CreateAccountFunc == nil {
panic("IAMServiceMock.CreateAccountFunc: method is nil but IAMService.CreateAccount was just called")
}
callInfo := struct {
Access string
Account auth.Account
}{
Access: access,
Account: account,
}
mock.lockCreateAccount.Lock()
mock.calls.CreateAccount = append(mock.calls.CreateAccount, callInfo)
mock.lockCreateAccount.Unlock()
return mock.CreateAccountFunc(access, account)
}
// CreateAccountCalls gets all the calls that were made to CreateAccount.
// Check the length with:
//
// len(mockedIAMService.CreateAccountCalls())
func (mock *IAMServiceMock) CreateAccountCalls() []struct {
Access string
Account auth.Account
} {
var calls []struct {
Access string
Account auth.Account
}
mock.lockCreateAccount.RLock()
calls = mock.calls.CreateAccount
mock.lockCreateAccount.RUnlock()
return calls
}
// DeleteUserAccount calls DeleteUserAccountFunc.
func (mock *IAMServiceMock) DeleteUserAccount(access string) error {
if mock.DeleteUserAccountFunc == nil {
panic("IAMServiceMock.DeleteUserAccountFunc: method is nil but IAMService.DeleteUserAccount was just called")
}
callInfo := struct {
Access string
}{
Access: access,
}
mock.lockDeleteUserAccount.Lock()
mock.calls.DeleteUserAccount = append(mock.calls.DeleteUserAccount, callInfo)
mock.lockDeleteUserAccount.Unlock()
return mock.DeleteUserAccountFunc(access)
}
// DeleteUserAccountCalls gets all the calls that were made to DeleteUserAccount.
// Check the length with:
//
// len(mockedIAMService.DeleteUserAccountCalls())
func (mock *IAMServiceMock) DeleteUserAccountCalls() []struct {
Access string
} {
var calls []struct {
Access string
}
mock.lockDeleteUserAccount.RLock()
calls = mock.calls.DeleteUserAccount
mock.lockDeleteUserAccount.RUnlock()
return calls
}
// GetUserAccount calls GetUserAccountFunc.
func (mock *IAMServiceMock) GetUserAccount(access string) (auth.Account, error) {
if mock.GetUserAccountFunc == nil {
panic("IAMServiceMock.GetUserAccountFunc: method is nil but IAMService.GetUserAccount was just called")
}
callInfo := struct {
Access string
}{
Access: access,
}
mock.lockGetUserAccount.Lock()
mock.calls.GetUserAccount = append(mock.calls.GetUserAccount, callInfo)
mock.lockGetUserAccount.Unlock()
return mock.GetUserAccountFunc(access)
}
// GetUserAccountCalls gets all the calls that were made to GetUserAccount.
// Check the length with:
//
// len(mockedIAMService.GetUserAccountCalls())
func (mock *IAMServiceMock) GetUserAccountCalls() []struct {
Access string
} {
var calls []struct {
Access string
}
mock.lockGetUserAccount.RLock()
calls = mock.calls.GetUserAccount
mock.lockGetUserAccount.RUnlock()
return calls
}

View File

@@ -50,15 +50,22 @@ func VerifyV4Signature(root RootUserConfig, iam auth.IAMService, region string,
}
// Check the signature version
authParts := strings.Split(authorization, " ")
if len(authParts) < 4 {
authParts := strings.Split(authorization, ",")
for i, el := range authParts {
authParts[i] = strings.TrimSpace(el)
}
if len(authParts) != 3 {
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrMissingFields))
}
if authParts[0] != "AWS4-HMAC-SHA256" {
startParts := strings.Split(authParts[0], " ")
if startParts[0] != "AWS4-HMAC-SHA256" {
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrSignatureVersionNotSupported))
}
credKv := strings.Split(authParts[1], "=")
credKv := strings.Split(startParts[1], "=")
if len(credKv) != 2 {
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrCredMalformed))
}
@@ -67,7 +74,7 @@ func VerifyV4Signature(root RootUserConfig, iam auth.IAMService, region string,
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrCredMalformed))
}
signHdrKv := strings.Split(authParts[2][:len(authParts[2])-1], "=")
signHdrKv := strings.Split(authParts[1], "=")
if len(signHdrKv) != 2 {
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrCredMalformed))
}
@@ -93,15 +100,18 @@ func VerifyV4Signature(root RootUserConfig, iam auth.IAMService, region string,
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrMalformedDate))
}
// Calculate the hash of the request payload
hashedPayload := sha256.Sum256(ctx.Body())
hexPayload := hex.EncodeToString(hashedPayload[:])
hashPayloadHeader := ctx.Get("X-Amz-Content-Sha256")
ok := isSpecialPayload(hashPayloadHeader)
// Compare the calculated hash with the hash provided
if hashPayloadHeader != hexPayload {
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrContentSHA256Mismatch))
if !ok {
// Calculate the hash of the request payload
hashedPayload := sha256.Sum256(ctx.Body())
hexPayload := hex.EncodeToString(hashedPayload[:])
// Compare the calculated hash with the hash provided
if hashPayloadHeader != hexPayload {
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrContentSHA256Mismatch))
}
}
// Create a new http request instance from fasthttp request
@@ -115,7 +125,7 @@ func VerifyV4Signature(root RootUserConfig, iam auth.IAMService, region string,
signErr := signer.SignHTTP(req.Context(), aws.Credentials{
AccessKeyID: creds[0],
SecretAccessKey: account.Secret,
}, req, hexPayload, creds[3], region, tdate, func(options *v4.SignerOptions) {
}, req, hashPayloadHeader, creds[3], region, tdate, func(options *v4.SignerOptions) {
if debug {
options.LogSigning = true
options.Logger = logging.NewStandardLogger(os.Stderr)
@@ -130,7 +140,7 @@ func VerifyV4Signature(root RootUserConfig, iam auth.IAMService, region string,
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrMissingFields))
}
calculatedSign := strings.Split(parts[3], "=")[1]
expectedSign := strings.Split(authParts[3], "=")[1]
expectedSign := strings.Split(authParts[2], "=")[1]
if expectedSign != calculatedSign {
return controllers.SendResponse(ctx, s3err.GetAPIError(s3err.ErrSignatureDoesNotMatch))
@@ -159,3 +169,16 @@ func (a accounts) getAccount(access string) (auth.Account, error) {
return a.iam.GetUserAccount(access)
}
func isSpecialPayload(str string) bool {
specialValues := map[string]bool{
"UNSIGNED-PAYLOAD": true,
"STREAMING-UNSIGNED-PAYLOAD-TRAILER": true,
"STREAMING-AWS4-HMAC-SHA256-PAYLOAD": true,
"STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER": true,
"STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD": true,
"STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD-TRAILER": true,
}
return specialValues[str]
}

View File

@@ -24,14 +24,13 @@ import (
type S3ApiRouter struct{}
func (sa *S3ApiRouter) Init(app *fiber.App, be backend.Backend, iam auth.IAMService) {
s3ApiController := controllers.New(be)
s3ApiController := controllers.New(be, iam)
adminController := controllers.AdminController{IAMService: iam}
// TODO: think of better routing system
app.Post("/create-user", adminController.CreateUser)
app.Patch("/create-user", adminController.CreateUser)
// Admin Delete api
app.Delete("/delete-user", adminController.DeleteUser)
app.Patch("/delete-user", adminController.DeleteUser)
// ListBuckets action
app.Get("/", s3ApiController.ListBuckets)

View File

@@ -15,6 +15,7 @@
package s3api
import (
"crypto/tls"
"reflect"
"testing"
@@ -82,15 +83,26 @@ func TestS3ApiServer_Serve(t *testing.T) {
wantErr bool
}{
{
name: "Return error when serving S3 api server with invalid address",
name: "Serve-invalid-address",
wantErr: true,
sa: &S3ApiServer{
app: fiber.New(),
backend: backend.BackendUnsupported{},
port: "Wrong address",
port: "Invalid address",
router: &S3ApiRouter{},
},
},
{
name: "Serve-invalid-address-with-certificate",
wantErr: true,
sa: &S3ApiServer{
app: fiber.New(),
backend: backend.BackendUnsupported{},
port: "Invalid address",
router: &S3ApiRouter{},
cert: &tls.Certificate{},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -117,3 +117,39 @@ func TestGetUserMetaData(t *testing.T) {
})
}
}
func Test_includeHeader(t *testing.T) {
type args struct {
hdr string
signedHdrs []string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "include-header-falsy-case",
args: args{
hdr: "Content-Type",
signedHdrs: []string{"X-Amz-Acl", "Content-Encoding"},
},
want: false,
},
{
name: "include-header-falsy-case",
args: args{
hdr: "Content-Type",
signedHdrs: []string{"X-Amz-Acl", "Content-Type"},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := includeHeader(tt.args.hdr, tt.args.signedHdrs); got != tt.want {
t.Errorf("includeHeader() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -94,3 +94,16 @@ type Owner struct {
ID string
DisplayName string
}
type Tag struct {
Key string `xml:"Key"`
Value string `xml:"Value"`
}
type TagSet struct {
Tags []Tag `xml:"Tag"`
}
type Tagging struct {
TagSet TagSet `xml:"TagSet"`
}