// File: at-container-registry/pkg/s3/mock.go (Go, 430 lines, 11 KiB)

package s3
import (
	"bytes"
	"context"
	"fmt"
	"io"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	awss3 "github.com/aws/aws-sdk-go-v2/service/s3"
	s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
	"github.com/google/uuid"
)
// MockS3Client implements S3Client for testing without real S3 credentials.
// It keeps objects in memory, records selected calls so tests can assert on
// them, and generates fake presigned URLs that point to a test server.
//
// Stored objects are looked up by object key only; the bucket name is copied
// into the call records but is never used to address the in-memory store.
type MockS3Client struct {
// TestServerURL is the base URL for generating fake presigned URLs.
// Requests to these URLs should be handled by a test server (httptest.Server).
TestServerURL string
// UploadID is returned by CreateMultipartUpload.
// If empty, a fresh "mock-upload-<uuid>" value is generated on each call.
UploadID string
// Objects stores in-memory blobs for PutObject/HeadObject/GetObject/DeleteObject/CopyObject/ListObjectsV2.
// NewMockS3Client initializes it to an empty map.
Objects map[string][]byte
// mu is locked by every method of the mock for the duration of the call.
// NOTE(review): tests that read the exported fields below directly do not
// take mu — confirm they only do so once concurrent calls have finished.
mu sync.Mutex
// Call logs for verification in tests. Each slice is appended to by the
// method named next to it.
CreateMultipartCalls []CreateMultipartCall // CreateMultipartUpload
CompleteCalls []CompleteCall // CompleteMultipartUpload
AbortCalls []AbortCall // AbortMultipartUpload
UploadPartCalls []UploadPartCall // PresignUploadPart
GetObjectCalls []GetObjectCall // PresignGetObject (GetObject does not record)
HeadObjectCalls []HeadObjectCall // PresignHeadObject (HeadObject does not record)
PutObjectCalls []PutObjectCall // PutObject and PresignPutObject
// Error injection for testing error handling. A non-nil value is returned
// as-is by the corresponding method instead of its normal result.
CreateMultipartError error
CompleteError error
AbortError error
HeadObjectError error
CopyObjectError error
}
// CreateMultipartCall records the bucket and key passed to one
// CreateMultipartUpload call. Nil input pointers are recorded as "".
type CreateMultipartCall struct {
Bucket string
Key string
}
// CompleteCall records one CompleteMultipartUpload call.
type CompleteCall struct {
Bucket string
Key string
UploadID string
// Parts is the number of parts listed in the request's MultipartUpload,
// or 0 when MultipartUpload was nil.
Parts int
}
// AbortCall records the bucket, key and upload ID passed to one
// AbortMultipartUpload call.
type AbortCall struct {
Bucket string
Key string
UploadID string
}
// UploadPartCall records one PresignUploadPart call, including the part
// number the URL was requested for (0 when the input left it nil).
type UploadPartCall struct {
Bucket string
Key string
UploadID string
PartNumber int32
}
// GetObjectCall records the bucket and key passed to one PresignGetObject
// call. Direct GetObject calls are not recorded.
type GetObjectCall struct {
Bucket string
Key string
}
// HeadObjectCall records the bucket and key passed to one PresignHeadObject
// call. Direct HeadObject calls are not recorded.
type HeadObjectCall struct {
Bucket string
Key string
}
// PutObjectCall records the bucket and key passed to one PutObject or
// PresignPutObject call; both methods append to the same log.
type PutObjectCall struct {
Bucket string
Key string
}
// NewMockS3Client returns a mock S3 client whose fake presigned URLs are
// rooted at testServerURL. The object store and every call log start out
// empty but non-nil.
func NewMockS3Client(testServerURL string) *MockS3Client {
	mock := new(MockS3Client)
	mock.TestServerURL = testServerURL
	mock.Objects = map[string][]byte{}

	// Call logs: allocated up front so they are never nil.
	mock.CreateMultipartCalls = make([]CreateMultipartCall, 0)
	mock.CompleteCalls = make([]CompleteCall, 0)
	mock.AbortCalls = make([]AbortCall, 0)
	mock.UploadPartCalls = make([]UploadPartCall, 0)
	mock.GetObjectCalls = make([]GetObjectCall, 0)
	mock.HeadObjectCalls = make([]HeadObjectCall, 0)
	mock.PutObjectCalls = make([]PutObjectCall, 0)

	return mock
}
// CreateMultipartUpload implements S3Client.
//
// The call is always appended to CreateMultipartCalls. If
// CreateMultipartError is set it is returned afterwards; otherwise the output
// carries m.UploadID, or a generated "mock-upload-<uuid>" when that is empty.
func (m *MockS3Client) CreateMultipartUpload(ctx context.Context, input *awss3.CreateMultipartUploadInput, opts ...func(*awss3.Options)) (*awss3.CreateMultipartUploadOutput, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	call := CreateMultipartCall{
		Bucket: aws.ToString(input.Bucket),
		Key:    aws.ToString(input.Key),
	}
	m.CreateMultipartCalls = append(m.CreateMultipartCalls, call)

	if injected := m.CreateMultipartError; injected != nil {
		return nil, injected
	}

	id := m.UploadID
	if id == "" {
		id = "mock-upload-" + uuid.New().String()
	}

	out := &awss3.CreateMultipartUploadOutput{UploadId: aws.String(id)}
	return out, nil
}
// CompleteMultipartUpload implements S3Client.
//
// The call (including the number of parts, 0 when input.MultipartUpload is
// nil) is always appended to CompleteCalls. If CompleteError is set it is
// returned afterwards and nothing is stored. Otherwise a placeholder object
// is stored under the key, unless one already exists, and a freshly generated
// quoted mock ETag is returned.
func (m *MockS3Client) CompleteMultipartUpload(ctx context.Context, input *awss3.CompleteMultipartUploadInput, opts ...func(*awss3.Options)) (*awss3.CompleteMultipartUploadOutput, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	key := aws.ToString(input.Key)

	partsCount := 0
	if input.MultipartUpload != nil {
		partsCount = len(input.MultipartUpload.Parts)
	}
	m.CompleteCalls = append(m.CompleteCalls, CompleteCall{
		Bucket:   aws.ToString(input.Bucket),
		Key:      key,
		UploadID: aws.ToString(input.UploadId),
		Parts:    partsCount,
	})

	if m.CompleteError != nil {
		return nil, m.CompleteError
	}

	// Store a placeholder object at the key so Stat/HeadObject works after
	// complete. The map is allocated on demand so this also holds for a mock
	// built as a struct literal rather than via NewMockS3Client.
	if m.Objects == nil {
		m.Objects = make(map[string][]byte)
	}
	if _, exists := m.Objects[key]; !exists {
		m.Objects[key] = []byte("completed-multipart")
	}

	// Return a mock ETag (quoted, as S3 ETags are).
	etag := "\"mock-etag-" + uuid.New().String() + "\""
	return &awss3.CompleteMultipartUploadOutput{
		ETag: aws.String(etag),
	}, nil
}
// AbortMultipartUpload implements S3Client.
//
// The call is always appended to AbortCalls; AbortError, when set, is then
// returned instead of an empty output. The object store is not touched.
func (m *MockS3Client) AbortMultipartUpload(ctx context.Context, input *awss3.AbortMultipartUploadInput, opts ...func(*awss3.Options)) (*awss3.AbortMultipartUploadOutput, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	record := AbortCall{
		Bucket:   aws.ToString(input.Bucket),
		Key:      aws.ToString(input.Key),
		UploadID: aws.ToString(input.UploadId),
	}
	m.AbortCalls = append(m.AbortCalls, record)

	if injected := m.AbortError; injected != nil {
		return nil, injected
	}
	return new(awss3.AbortMultipartUploadOutput), nil
}
// HeadObject implements S3Client.
//
// HeadObjectError, when set, is returned before any lookup. Otherwise the
// object is looked up by key alone; a missing key yields a "NoSuchKey" error
// and a present one yields its byte length as ContentLength. This method
// does not append to HeadObjectCalls.
func (m *MockS3Client) HeadObject(ctx context.Context, input *awss3.HeadObjectInput, opts ...func(*awss3.Options)) (*awss3.HeadObjectOutput, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if injected := m.HeadObjectError; injected != nil {
		return nil, injected
	}

	key := aws.ToString(input.Key)
	blob, found := m.Objects[key]
	if !found {
		return nil, fmt.Errorf("NoSuchKey: object %s not found", key)
	}

	out := &awss3.HeadObjectOutput{
		ContentLength: aws.Int64(int64(len(blob))),
	}
	return out, nil
}
// PutObject implements S3Client.
//
// The call is always appended to PutObjectCalls. The body is then read in
// full and stored under the key, replacing any existing object; a nil body
// stores an empty object. If reading the body fails the error is returned
// and nothing is stored.
func (m *MockS3Client) PutObject(ctx context.Context, input *awss3.PutObjectInput, opts ...func(*awss3.Options)) (*awss3.PutObjectOutput, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	key := aws.ToString(input.Key)
	m.PutObjectCalls = append(m.PutObjectCalls, PutObjectCall{
		Bucket: aws.ToString(input.Bucket),
		Key:    key,
	})

	data := []byte{}
	if input.Body != nil {
		var err error
		if data, err = io.ReadAll(input.Body); err != nil {
			return nil, err
		}
	}

	// Allocate the store on demand: writing to a nil map panics, which would
	// otherwise happen for a mock built as a struct literal.
	if m.Objects == nil {
		m.Objects = make(map[string][]byte)
	}
	m.Objects[key] = data

	return &awss3.PutObjectOutput{}, nil
}
// CopyObject implements S3Client.
//
// CopyObjectError, when set, is returned first. Otherwise CopySource is
// treated as "bucket/key": everything after the first "/" is the source key
// (the whole string when there is no "/"). A missing source yields a
// "NoSuchKey" error; otherwise an independent copy of the bytes is stored
// under input.Key.
func (m *MockS3Client) CopyObject(ctx context.Context, input *awss3.CopyObjectInput, opts ...func(*awss3.Options)) (*awss3.CopyObjectOutput, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if injected := m.CopyObjectError; injected != nil {
		return nil, injected
	}

	// Drop the leading bucket segment, if any, to obtain the source key.
	source := aws.ToString(input.CopySource)
	srcKey := source
	if _, afterBucket, hasSlash := strings.Cut(source, "/"); hasSlash {
		srcKey = afterBucket
	}

	blob, found := m.Objects[srcKey]
	if !found {
		return nil, fmt.Errorf("NoSuchKey: source object %s not found", srcKey)
	}

	// Duplicate the bytes so source and destination do not share storage.
	duplicate := append([]byte{}, blob...)
	m.Objects[aws.ToString(input.Key)] = duplicate

	return new(awss3.CopyObjectOutput), nil
}
// DeleteObject implements S3Client.
//
// Removes the object stored under input.Key, if any. Deleting a key that
// does not exist is not an error.
func (m *MockS3Client) DeleteObject(ctx context.Context, input *awss3.DeleteObjectInput, opts ...func(*awss3.Options)) (*awss3.DeleteObjectOutput, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	delete(m.Objects, aws.ToString(input.Key))

	out := new(awss3.DeleteObjectOutput)
	return out, nil
}
// ListObjectsV2 implements S3Client.
//
// Every stored key that starts with input.Prefix is considered. When a
// delimiter is given and occurs in the part of the key after the prefix, the
// key is rolled up into a common prefix (prefix through the first delimiter,
// reported once) instead of being listed as an object. All other keys are
// returned in Contents with their size.
//
// Results are returned in ascending key order so listings are deterministic
// rather than following Go's randomized map iteration. The mock never
// paginates: IsTruncated is always false and the whole result is returned in
// one response.
func (m *MockS3Client) ListObjectsV2(ctx context.Context, input *awss3.ListObjectsV2Input, opts ...func(*awss3.Options)) (*awss3.ListObjectsV2Output, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	prefix := aws.ToString(input.Prefix)
	delimiter := aws.ToString(input.Delimiter)

	// Collect the matching keys and sort them first; map iteration order is
	// random, which would otherwise make the listing order vary between runs.
	keys := make([]string, 0, len(m.Objects))
	for key := range m.Objects {
		if strings.HasPrefix(key, prefix) {
			keys = append(keys, key)
		}
	}
	sort.Strings(keys)

	var contents []s3types.Object
	var cps []s3types.CommonPrefix
	seenPrefixes := map[string]struct{}{}

	for _, key := range keys {
		if delimiter != "" {
			// Check if there's a delimiter after the prefix.
			rest := strings.TrimPrefix(key, prefix)
			if idx := strings.Index(rest, delimiter); idx >= 0 {
				// Has delimiter — this is a common prefix, not a content object.
				// Keys are visited in sorted order, so first-seen order of the
				// distinct common prefixes is sorted as well.
				cp := prefix + rest[:idx+len(delimiter)]
				if _, dup := seenPrefixes[cp]; !dup {
					seenPrefixes[cp] = struct{}{}
					cps = append(cps, s3types.CommonPrefix{Prefix: aws.String(cp)})
				}
				continue
			}
		}

		contents = append(contents, s3types.Object{
			Key:  aws.String(key),
			Size: aws.Int64(int64(len(m.Objects[key]))),
		})
	}

	return &awss3.ListObjectsV2Output{
		Contents:       contents,
		CommonPrefixes: cps,
		IsTruncated:    aws.Bool(false),
	}, nil
}
// PresignUploadPart implements S3Client.
//
// Appends the call to UploadPartCalls and returns a fake presigned URL of the
// form "<TestServerURL>/upload/<key>?partNumber=<n>&uploadId=<id>". The key
// and upload ID are inserted verbatim, without URL escaping. The expires
// argument is ignored and the error is always nil.
func (m *MockS3Client) PresignUploadPart(ctx context.Context, input *awss3.UploadPartInput, expires time.Duration) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	key := aws.ToString(input.Key)
	uploadID := aws.ToString(input.UploadId)
	partNumber := aws.ToInt32(input.PartNumber)

	m.UploadPartCalls = append(m.UploadPartCalls, UploadPartCall{
		Bucket:     aws.ToString(input.Bucket),
		Key:        key,
		UploadID:   uploadID,
		PartNumber: partNumber,
	})

	return fmt.Sprintf("%s/upload/%s?partNumber=%d&uploadId=%s", m.TestServerURL, key, partNumber, uploadID), nil
}
// PresignGetObject implements S3Client.
//
// Appends the call to GetObjectCalls and returns the fake presigned URL
// "<TestServerURL>/get/<key>". The object store is not consulted, the expires
// argument is ignored, and the error is always nil.
func (m *MockS3Client) PresignGetObject(ctx context.Context, input *awss3.GetObjectInput, expires time.Duration) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	key := aws.ToString(input.Key)
	record := GetObjectCall{Bucket: aws.ToString(input.Bucket), Key: key}
	m.GetObjectCalls = append(m.GetObjectCalls, record)

	return fmt.Sprintf("%s/get/%s", m.TestServerURL, key), nil
}
// PresignHeadObject implements S3Client.
//
// Appends the call to HeadObjectCalls and returns the fake presigned URL
// "<TestServerURL>/head/<key>". The object store is not consulted, the
// expires argument is ignored, and the error is always nil.
func (m *MockS3Client) PresignHeadObject(ctx context.Context, input *awss3.HeadObjectInput, expires time.Duration) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	key := aws.ToString(input.Key)
	record := HeadObjectCall{Bucket: aws.ToString(input.Bucket), Key: key}
	m.HeadObjectCalls = append(m.HeadObjectCalls, record)

	return fmt.Sprintf("%s/head/%s", m.TestServerURL, key), nil
}
// PresignPutObject implements S3Client.
//
// Appends the call to PutObjectCalls and returns the fake presigned URL
// "<TestServerURL>/put/<key>". The expires argument is ignored.
//
// When input.Body is non-nil it is also read in full and stored under the
// key, for tests that presign and then inspect the object store. A failure
// to read the body is returned as the error, matching PutObject, rather than
// silently storing a truncated object.
func (m *MockS3Client) PresignPutObject(ctx context.Context, input *awss3.PutObjectInput, expires time.Duration) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	key := aws.ToString(input.Key)
	m.PutObjectCalls = append(m.PutObjectCalls, PutObjectCall{
		Bucket: aws.ToString(input.Bucket),
		Key:    key,
	})

	if input.Body != nil {
		data, err := io.ReadAll(input.Body)
		if err != nil {
			return "", err
		}
		// Allocate the store on demand: writing to a nil map panics, which
		// would otherwise happen for a mock built as a struct literal.
		if m.Objects == nil {
			m.Objects = make(map[string][]byte)
		}
		m.Objects[key] = data
	}

	return fmt.Sprintf("%s/put/%s", m.TestServerURL, key), nil
}
// SetObject is a test helper to pre-populate an object in the mock store.
//
// The bytes are copied, so later changes to data do not affect the stored
// object. A nil data slice is stored as an empty, non-nil object, which keeps
// it distinguishable from "not found" in GetObjectBytes. The store is
// allocated on demand, so this is safe on a mock built as a struct literal.
func (m *MockS3Client) SetObject(key string, data []byte) {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Writing to a nil map panics; create it on first use.
	if m.Objects == nil {
		m.Objects = make(map[string][]byte)
	}
	m.Objects[key] = append([]byte{}, data...)
}
// GetObject implements S3Client.
//
// Looks the object up by key alone. A missing key yields a "NoSuchKey"
// error; otherwise the output body reads from a private copy of the stored
// bytes and ContentLength is their length. This method does not append to
// GetObjectCalls.
func (m *MockS3Client) GetObject(ctx context.Context, input *awss3.GetObjectInput, opts ...func(*awss3.Options)) (*awss3.GetObjectOutput, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	key := aws.ToString(input.Key)
	stored, found := m.Objects[key]
	if !found {
		return nil, fmt.Errorf("NoSuchKey: %s", key)
	}

	// Hand out a copy so the reader stays valid if the store changes later.
	snapshot := bytes.Clone(stored)
	out := &awss3.GetObjectOutput{
		Body:          io.NopCloser(bytes.NewReader(snapshot)),
		ContentLength: aws.Int64(int64(len(stored))),
	}
	return out, nil
}
// GetObjectBytes is a test helper that returns a copy of the object stored
// under key, or nil when no such object exists.
func (m *MockS3Client) GetObjectBytes(key string) []byte {
	m.mu.Lock()
	defer m.mu.Unlock()

	if stored, found := m.Objects[key]; found {
		return bytes.Clone(stored)
	}
	return nil
}