mirror of
https://github.com/versity/versitygw.git
synced 2025-12-23 05:05:16 +00:00
feat: Closes #431, Refactored aws signer: removed unnecessary codes, fixed staticcheck errors
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
AWS SDK for Go
|
||||
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Copyright 2014-2015 Stripe, Inc.
|
||||
Copyright 2024 Versity Software
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/aws/smithy-go/auth"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
)
|
||||
|
||||
// HTTPAuthScheme is the SDK's internal implementation of smithyhttp.AuthScheme
|
||||
// for pre-existing implementations where the signer was added to client
|
||||
// config. SDK clients will key off of this type and ensure per-operation
|
||||
// updates to those signers persist on the scheme itself.
|
||||
type HTTPAuthScheme struct {
|
||||
schemeID string
|
||||
signer smithyhttp.Signer
|
||||
}
|
||||
|
||||
var _ smithyhttp.AuthScheme = (*HTTPAuthScheme)(nil)
|
||||
|
||||
// NewHTTPAuthScheme returns an auth scheme instance with the given config.
|
||||
func NewHTTPAuthScheme(schemeID string, signer smithyhttp.Signer) *HTTPAuthScheme {
|
||||
return &HTTPAuthScheme{
|
||||
schemeID: schemeID,
|
||||
signer: signer,
|
||||
}
|
||||
}
|
||||
|
||||
// SchemeID identifies the auth scheme.
|
||||
func (s *HTTPAuthScheme) SchemeID() string {
|
||||
return s.schemeID
|
||||
}
|
||||
|
||||
// IdentityResolver gets the identity resolver for the auth scheme.
|
||||
func (s *HTTPAuthScheme) IdentityResolver(o auth.IdentityResolverOptions) auth.IdentityResolver {
|
||||
return o.GetIdentityResolver(s.schemeID)
|
||||
}
|
||||
|
||||
// Signer gets the signer for the auth scheme.
|
||||
func (s *HTTPAuthScheme) Signer() smithyhttp.Signer {
|
||||
return s.signer
|
||||
}
|
||||
|
||||
// WithSigner returns a new instance of the auth scheme with the updated signer.
|
||||
func (s *HTTPAuthScheme) WithSigner(signer smithyhttp.Signer) *HTTPAuthScheme {
|
||||
return NewHTTPAuthScheme(s.schemeID, signer)
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
smithy "github.com/aws/smithy-go"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
)
|
||||
|
||||
// SigV4 is a constant representing
|
||||
// Authentication Scheme Signature Version 4
|
||||
const SigV4 = "sigv4"
|
||||
|
||||
// SigV4A is a constant representing
|
||||
// Authentication Scheme Signature Version 4A
|
||||
const SigV4A = "sigv4a"
|
||||
|
||||
// SigV4S3Express identifies the S3 S3Express auth scheme.
|
||||
const SigV4S3Express = "sigv4-s3express"
|
||||
|
||||
// None is a constant representing the
|
||||
// None Authentication Scheme
|
||||
const None = "none"
|
||||
|
||||
// SupportedSchemes is a data structure
|
||||
// that indicates the list of supported AWS
|
||||
// authentication schemes
|
||||
var SupportedSchemes = map[string]bool{
|
||||
SigV4: true,
|
||||
SigV4A: true,
|
||||
SigV4S3Express: true,
|
||||
None: true,
|
||||
}
|
||||
|
||||
// AuthenticationScheme is a representation of
|
||||
// AWS authentication schemes
|
||||
type AuthenticationScheme interface {
|
||||
isAuthenticationScheme()
|
||||
}
|
||||
|
||||
// AuthenticationSchemeV4 is a AWS SigV4 representation
|
||||
type AuthenticationSchemeV4 struct {
|
||||
Name string
|
||||
SigningName *string
|
||||
SigningRegion *string
|
||||
DisableDoubleEncoding *bool
|
||||
}
|
||||
|
||||
func (a *AuthenticationSchemeV4) isAuthenticationScheme() {}
|
||||
|
||||
// AuthenticationSchemeV4A is a AWS SigV4A representation
|
||||
type AuthenticationSchemeV4A struct {
|
||||
Name string
|
||||
SigningName *string
|
||||
SigningRegionSet []string
|
||||
DisableDoubleEncoding *bool
|
||||
}
|
||||
|
||||
func (a *AuthenticationSchemeV4A) isAuthenticationScheme() {}
|
||||
|
||||
// AuthenticationSchemeNone is a representation for the none auth scheme
|
||||
type AuthenticationSchemeNone struct{}
|
||||
|
||||
func (a *AuthenticationSchemeNone) isAuthenticationScheme() {}
|
||||
|
||||
// NoAuthenticationSchemesFoundError is used in signaling
|
||||
// that no authentication schemes have been specified.
|
||||
type NoAuthenticationSchemesFoundError struct{}
|
||||
|
||||
func (e *NoAuthenticationSchemesFoundError) Error() string {
|
||||
return fmt.Sprint("No authentication schemes specified.")
|
||||
}
|
||||
|
||||
// UnSupportedAuthenticationSchemeSpecifiedError is used in
|
||||
// signaling that only unsupported authentication schemes
|
||||
// were specified.
|
||||
type UnSupportedAuthenticationSchemeSpecifiedError struct {
|
||||
UnsupportedSchemes []string
|
||||
}
|
||||
|
||||
func (e *UnSupportedAuthenticationSchemeSpecifiedError) Error() string {
|
||||
return fmt.Sprint("Unsupported authentication scheme specified.")
|
||||
}
|
||||
|
||||
// GetAuthenticationSchemes extracts the relevant authentication scheme data
|
||||
// into a custom strongly typed Go data structure.
|
||||
func GetAuthenticationSchemes(p *smithy.Properties) ([]AuthenticationScheme, error) {
|
||||
var result []AuthenticationScheme
|
||||
if !p.Has("authSchemes") {
|
||||
return nil, &NoAuthenticationSchemesFoundError{}
|
||||
}
|
||||
|
||||
authSchemes, _ := p.Get("authSchemes").([]interface{})
|
||||
|
||||
var unsupportedSchemes []string
|
||||
for _, scheme := range authSchemes {
|
||||
authScheme, _ := scheme.(map[string]interface{})
|
||||
|
||||
version := authScheme["name"].(string)
|
||||
switch version {
|
||||
case SigV4, SigV4S3Express:
|
||||
v4Scheme := AuthenticationSchemeV4{
|
||||
Name: version,
|
||||
SigningName: getSigningName(authScheme),
|
||||
SigningRegion: getSigningRegion(authScheme),
|
||||
DisableDoubleEncoding: getDisableDoubleEncoding(authScheme),
|
||||
}
|
||||
result = append(result, AuthenticationScheme(&v4Scheme))
|
||||
case SigV4A:
|
||||
v4aScheme := AuthenticationSchemeV4A{
|
||||
Name: SigV4A,
|
||||
SigningName: getSigningName(authScheme),
|
||||
SigningRegionSet: getSigningRegionSet(authScheme),
|
||||
DisableDoubleEncoding: getDisableDoubleEncoding(authScheme),
|
||||
}
|
||||
result = append(result, AuthenticationScheme(&v4aScheme))
|
||||
case None:
|
||||
noneScheme := AuthenticationSchemeNone{}
|
||||
result = append(result, AuthenticationScheme(&noneScheme))
|
||||
default:
|
||||
unsupportedSchemes = append(unsupportedSchemes, authScheme["name"].(string))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, &UnSupportedAuthenticationSchemeSpecifiedError{
|
||||
UnsupportedSchemes: unsupportedSchemes,
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type disableDoubleEncoding struct{}
|
||||
|
||||
// SetDisableDoubleEncoding sets or modifies the disable double encoding option
|
||||
// on the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func SetDisableDoubleEncoding(ctx context.Context, value bool) context.Context {
|
||||
return middleware.WithStackValue(ctx, disableDoubleEncoding{}, value)
|
||||
}
|
||||
|
||||
// GetDisableDoubleEncoding retrieves the disable double encoding option
|
||||
// from the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func GetDisableDoubleEncoding(ctx context.Context) (value bool, ok bool) {
|
||||
value, ok = middleware.GetStackValue(ctx, disableDoubleEncoding{}).(bool)
|
||||
return value, ok
|
||||
}
|
||||
|
||||
func getSigningName(authScheme map[string]interface{}) *string {
|
||||
signingName, ok := authScheme["signingName"].(string)
|
||||
if !ok || signingName == "" {
|
||||
return nil
|
||||
}
|
||||
return &signingName
|
||||
}
|
||||
|
||||
func getSigningRegionSet(authScheme map[string]interface{}) []string {
|
||||
untypedSigningRegionSet, ok := authScheme["signingRegionSet"].([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
signingRegionSet := []string{}
|
||||
for _, item := range untypedSigningRegionSet {
|
||||
signingRegionSet = append(signingRegionSet, item.(string))
|
||||
}
|
||||
return signingRegionSet
|
||||
}
|
||||
|
||||
func getSigningRegion(authScheme map[string]interface{}) *string {
|
||||
signingRegion, ok := authScheme["signingRegion"].(string)
|
||||
if !ok || signingRegion == "" {
|
||||
return nil
|
||||
}
|
||||
return &signingRegion
|
||||
}
|
||||
|
||||
func getDisableDoubleEncoding(authScheme map[string]interface{}) *bool {
|
||||
disableDoubleEncoding, ok := authScheme["disableDoubleEncoding"].(bool)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &disableDoubleEncoding
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
smithy "github.com/aws/smithy-go"
|
||||
)
|
||||
|
||||
func TestV4(t *testing.T) {
|
||||
|
||||
propsV4 := smithy.Properties{}
|
||||
|
||||
propsV4.Set("authSchemes", interface{}([]interface{}{
|
||||
map[string]interface{}{
|
||||
"disableDoubleEncoding": true,
|
||||
"name": "sigv4",
|
||||
"signingName": "s3",
|
||||
"signingRegion": "us-west-2",
|
||||
},
|
||||
}))
|
||||
|
||||
result, err := GetAuthenticationSchemes(&propsV4)
|
||||
if err != nil {
|
||||
t.Fatalf("Did not expect error, got %v", err)
|
||||
}
|
||||
|
||||
_, ok := result[0].(AuthenticationScheme)
|
||||
if !ok {
|
||||
t.Fatalf("Did not get expected AuthenticationScheme. %v", result[0])
|
||||
}
|
||||
|
||||
v4Scheme, ok := result[0].(*AuthenticationSchemeV4)
|
||||
if !ok {
|
||||
t.Fatalf("Did not get expected AuthenticationSchemeV4. %v", result[0])
|
||||
}
|
||||
|
||||
if v4Scheme.Name != "sigv4" {
|
||||
t.Fatalf("Did not get expected AuthenticationSchemeV4 signer version name")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestV4A(t *testing.T) {
|
||||
|
||||
propsV4A := smithy.Properties{}
|
||||
|
||||
propsV4A.Set("authSchemes", []interface{}{
|
||||
map[string]interface{}{
|
||||
"disableDoubleEncoding": true,
|
||||
"name": "sigv4a",
|
||||
"signingName": "s3",
|
||||
"signingRegionSet": []string{"*"},
|
||||
},
|
||||
})
|
||||
|
||||
result, err := GetAuthenticationSchemes(&propsV4A)
|
||||
if err != nil {
|
||||
t.Fatalf("Did not expect error, got %v", err)
|
||||
}
|
||||
|
||||
_, ok := result[0].(AuthenticationScheme)
|
||||
if !ok {
|
||||
t.Fatalf("Did not get expected AuthenticationScheme. %v", result[0])
|
||||
}
|
||||
|
||||
v4AScheme, ok := result[0].(*AuthenticationSchemeV4A)
|
||||
if !ok {
|
||||
t.Fatalf("Did not get expected AuthenticationSchemeV4A. %v", result[0])
|
||||
}
|
||||
|
||||
if v4AScheme.Name != "sigv4a" {
|
||||
t.Fatalf("Did not get expected AuthenticationSchemeV4A signer version name")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestV4S3Express(t *testing.T) {
|
||||
props := smithy.Properties{}
|
||||
props.Set("authSchemes", []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": SigV4S3Express,
|
||||
"signingName": "s3",
|
||||
"signingRegion": "us-east-1",
|
||||
"disableDoubleEncoding": true,
|
||||
},
|
||||
})
|
||||
|
||||
result, err := GetAuthenticationSchemes(&props)
|
||||
if err != nil {
|
||||
t.Fatalf("Did not expect error, got %v", err)
|
||||
}
|
||||
|
||||
scheme, ok := result[0].(*AuthenticationSchemeV4)
|
||||
if !ok {
|
||||
t.Fatalf("Did not get expected AuthenticationSchemeV4. %v", result[0])
|
||||
}
|
||||
|
||||
if scheme.Name != SigV4S3Express {
|
||||
t.Fatalf("expected %s, got %s", SigV4S3Express, scheme.Name)
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package smithy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/smithy-go"
|
||||
"github.com/aws/smithy-go/auth"
|
||||
"github.com/aws/smithy-go/auth/bearer"
|
||||
)
|
||||
|
||||
// BearerTokenAdapter adapts smithy bearer.Token to smithy auth.Identity.
|
||||
type BearerTokenAdapter struct {
|
||||
Token bearer.Token
|
||||
}
|
||||
|
||||
var _ auth.Identity = (*BearerTokenAdapter)(nil)
|
||||
|
||||
// Expiration returns the time of expiration for the token.
|
||||
func (v *BearerTokenAdapter) Expiration() time.Time {
|
||||
return v.Token.Expires
|
||||
}
|
||||
|
||||
// BearerTokenProviderAdapter adapts smithy bearer.TokenProvider to smithy
|
||||
// auth.IdentityResolver.
|
||||
type BearerTokenProviderAdapter struct {
|
||||
Provider bearer.TokenProvider
|
||||
}
|
||||
|
||||
var _ (auth.IdentityResolver) = (*BearerTokenProviderAdapter)(nil)
|
||||
|
||||
// GetIdentity retrieves a bearer token using the underlying provider.
|
||||
func (v *BearerTokenProviderAdapter) GetIdentity(ctx context.Context, _ smithy.Properties) (
|
||||
auth.Identity, error,
|
||||
) {
|
||||
token, err := v.Provider.RetrieveBearerToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get token: %w", err)
|
||||
}
|
||||
|
||||
return &BearerTokenAdapter{Token: token}, nil
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package smithy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/smithy-go"
|
||||
"github.com/aws/smithy-go/auth"
|
||||
"github.com/aws/smithy-go/auth/bearer"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
)
|
||||
|
||||
// BearerTokenSignerAdapter adapts smithy bearer.Signer to smithy http
|
||||
// auth.Signer.
|
||||
type BearerTokenSignerAdapter struct {
|
||||
Signer bearer.Signer
|
||||
}
|
||||
|
||||
var _ (smithyhttp.Signer) = (*BearerTokenSignerAdapter)(nil)
|
||||
|
||||
// SignRequest signs the request with the provided bearer token.
|
||||
func (v *BearerTokenSignerAdapter) SignRequest(ctx context.Context, r *smithyhttp.Request, identity auth.Identity, _ smithy.Properties) error {
|
||||
ca, ok := identity.(*BearerTokenAdapter)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected identity type: %T", identity)
|
||||
}
|
||||
|
||||
signed, err := v.Signer.SignWithBearerToken(ctx, ca.Token, r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sign request: %w", err)
|
||||
}
|
||||
|
||||
*r = *signed.(*smithyhttp.Request)
|
||||
return nil
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package smithy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/smithy-go"
|
||||
"github.com/aws/smithy-go/auth"
|
||||
)
|
||||
|
||||
// CredentialsAdapter adapts aws.Credentials to auth.Identity.
|
||||
type CredentialsAdapter struct {
|
||||
Credentials aws.Credentials
|
||||
}
|
||||
|
||||
var _ auth.Identity = (*CredentialsAdapter)(nil)
|
||||
|
||||
// Expiration returns the time of expiration for the credentials.
|
||||
func (v *CredentialsAdapter) Expiration() time.Time {
|
||||
return v.Credentials.Expires
|
||||
}
|
||||
|
||||
// CredentialsProviderAdapter adapts aws.CredentialsProvider to auth.IdentityResolver.
|
||||
type CredentialsProviderAdapter struct {
|
||||
Provider aws.CredentialsProvider
|
||||
}
|
||||
|
||||
var _ (auth.IdentityResolver) = (*CredentialsProviderAdapter)(nil)
|
||||
|
||||
// GetIdentity retrieves AWS credentials using the underlying provider.
|
||||
func (v *CredentialsProviderAdapter) GetIdentity(ctx context.Context, _ smithy.Properties) (
|
||||
auth.Identity, error,
|
||||
) {
|
||||
if v.Provider == nil {
|
||||
return &CredentialsAdapter{Credentials: aws.Credentials{}}, nil
|
||||
}
|
||||
|
||||
creds, err := v.Provider.Retrieve(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get credentials: %w", err)
|
||||
}
|
||||
|
||||
return &CredentialsAdapter{Credentials: creds}, nil
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
// Package smithy adapts concrete AWS auth and signing types to the generic smithy versions.
|
||||
package smithy
|
||||
@@ -1,53 +0,0 @@
|
||||
package smithy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
"github.com/aws/smithy-go"
|
||||
"github.com/aws/smithy-go/auth"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
"github.com/versity/versitygw/aws/internal/sdk"
|
||||
)
|
||||
|
||||
// V4SignerAdapter adapts v4.HTTPSigner to smithy http.Signer.
|
||||
type V4SignerAdapter struct {
|
||||
Signer v4.HTTPSigner
|
||||
Logger logging.Logger
|
||||
LogSigning bool
|
||||
}
|
||||
|
||||
var _ (smithyhttp.Signer) = (*V4SignerAdapter)(nil)
|
||||
|
||||
// SignRequest signs the request with the provided identity.
|
||||
func (v *V4SignerAdapter) SignRequest(ctx context.Context, r *smithyhttp.Request, identity auth.Identity, props smithy.Properties) error {
|
||||
ca, ok := identity.(*CredentialsAdapter)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected identity type: %T", identity)
|
||||
}
|
||||
|
||||
name, ok := smithyhttp.GetSigV4SigningName(&props)
|
||||
if !ok {
|
||||
return fmt.Errorf("sigv4 signing name is required")
|
||||
}
|
||||
|
||||
region, ok := smithyhttp.GetSigV4SigningRegion(&props)
|
||||
if !ok {
|
||||
return fmt.Errorf("sigv4 signing region is required")
|
||||
}
|
||||
|
||||
hash := v4.GetPayloadHash(ctx)
|
||||
err := v.Signer.SignHTTP(ctx, ca.Credentials, r.Request, hash, name, region, sdk.NowTime(), func(o *v4.SignerOptions) {
|
||||
o.DisableURIPathEscaping, _ = smithyhttp.GetDisableDoubleEncoding(&props)
|
||||
|
||||
o.Logger = v.Logger
|
||||
o.LogSigning = v.LogSigning
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sign http: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,222 +0,0 @@
|
||||
package awstesting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Match is a testing helper to test for testing error by comparing expected
|
||||
// with a regular expression.
|
||||
func Match(t *testing.T, regex, expected string) {
|
||||
t.Helper()
|
||||
|
||||
if !regexp.MustCompile(regex).Match([]byte(expected)) {
|
||||
t.Errorf("%q\n\tdoes not match /%s/", expected, regex)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertURL verifies the expected URL is matches the actual.
|
||||
func AssertURL(t *testing.T, expect, actual string, msgAndArgs ...interface{}) bool {
|
||||
t.Helper()
|
||||
|
||||
expectURL, err := url.Parse(expect)
|
||||
if err != nil {
|
||||
t.Errorf(errMsg("unable to parse expected URL", err, msgAndArgs))
|
||||
return false
|
||||
}
|
||||
actualURL, err := url.Parse(actual)
|
||||
if err != nil {
|
||||
t.Errorf(errMsg("unable to parse actual URL", err, msgAndArgs))
|
||||
return false
|
||||
}
|
||||
|
||||
equal(t, expectURL.Host, actualURL.Host, msgAndArgs...)
|
||||
equal(t, expectURL.Scheme, actualURL.Scheme, msgAndArgs...)
|
||||
equal(t, expectURL.Path, actualURL.Path, msgAndArgs...)
|
||||
|
||||
return AssertQuery(t, expectURL.Query().Encode(), actualURL.Query().Encode(), msgAndArgs...)
|
||||
}
|
||||
|
||||
var queryMapKey = regexp.MustCompile(`(.*?)\.[0-9]+\.key`)
|
||||
|
||||
// AssertQuery verifies the expect HTTP query string matches the actual.
|
||||
func AssertQuery(t *testing.T, expect, actual string, msgAndArgs ...interface{}) bool {
|
||||
t.Helper()
|
||||
|
||||
expectQ, err := url.ParseQuery(expect)
|
||||
if err != nil {
|
||||
t.Errorf(errMsg("unable to parse expected Query", err, msgAndArgs))
|
||||
return false
|
||||
}
|
||||
actualQ, err := url.ParseQuery(actual)
|
||||
if err != nil {
|
||||
t.Errorf(errMsg("unable to parse actual Query", err, msgAndArgs))
|
||||
return false
|
||||
}
|
||||
|
||||
// Make sure the keys are the same
|
||||
if !equal(t, queryValueKeys(expectQ), queryValueKeys(actualQ), msgAndArgs...) {
|
||||
return false
|
||||
}
|
||||
|
||||
keys := map[string][]string{}
|
||||
for key, v := range expectQ {
|
||||
if queryMapKey.Match([]byte(key)) {
|
||||
submatch := queryMapKey.FindStringSubmatch(key)
|
||||
keys[submatch[1]] = append(keys[submatch[1]], v...)
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range keys {
|
||||
// clear all keys that have prefix
|
||||
for key := range expectQ {
|
||||
if strings.HasPrefix(key, k) {
|
||||
delete(expectQ, key)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(v)
|
||||
for i, value := range v {
|
||||
expectQ[fmt.Sprintf("%s.%d.key", k, i+1)] = []string{value}
|
||||
}
|
||||
}
|
||||
|
||||
for k, expectQVals := range expectQ {
|
||||
sort.Strings(expectQVals)
|
||||
actualQVals := actualQ[k]
|
||||
sort.Strings(actualQVals)
|
||||
if !equal(t, expectQVals, actualQVals, msgAndArgs...) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AssertJSON verifies that the expect json string matches the actual.
|
||||
func AssertJSON(t *testing.T, expect, actual string, msgAndArgs ...interface{}) bool {
|
||||
t.Helper()
|
||||
|
||||
expectVal := map[string]interface{}{}
|
||||
if err := json.Unmarshal([]byte(expect), &expectVal); err != nil {
|
||||
t.Errorf(errMsg("unable to parse expected JSON", err, msgAndArgs...))
|
||||
return false
|
||||
}
|
||||
|
||||
actualVal := map[string]interface{}{}
|
||||
if err := json.Unmarshal([]byte(actual), &actualVal); err != nil {
|
||||
t.Errorf(errMsg("unable to parse actual JSON", err, msgAndArgs...))
|
||||
return false
|
||||
}
|
||||
|
||||
return equal(t, expectVal, actualVal, msgAndArgs...)
|
||||
}
|
||||
|
||||
// AssertXML verifies that the expect xml string matches the actual.
|
||||
func AssertXML(t *testing.T, expect, actual string, container interface{}, msgAndArgs ...interface{}) bool {
|
||||
expectVal := container
|
||||
if err := xml.Unmarshal([]byte(expect), &expectVal); err != nil {
|
||||
t.Errorf(errMsg("unable to parse expected XML", err, msgAndArgs...))
|
||||
}
|
||||
|
||||
actualVal := container
|
||||
if err := xml.Unmarshal([]byte(actual), &actualVal); err != nil {
|
||||
t.Errorf(errMsg("unable to parse actual XML", err, msgAndArgs...))
|
||||
}
|
||||
return equal(t, expectVal, actualVal, msgAndArgs...)
|
||||
}
|
||||
|
||||
// DidPanic returns if the function paniced and returns true if the function paniced.
|
||||
func DidPanic(fn func()) (bool, interface{}) {
|
||||
var paniced bool
|
||||
var msg interface{}
|
||||
func() {
|
||||
defer func() {
|
||||
if msg = recover(); msg != nil {
|
||||
paniced = true
|
||||
}
|
||||
}()
|
||||
fn()
|
||||
}()
|
||||
|
||||
return paniced, msg
|
||||
}
|
||||
|
||||
// objectsAreEqual determines if two objects are considered equal.
|
||||
//
|
||||
// This function does no assertion of any kind.
|
||||
//
|
||||
// Based on github.com/stretchr/testify/assert.ObjectsAreEqual
|
||||
// Copied locally to prevent non-test build dependencies on testify
|
||||
func objectsAreEqual(expected, actual interface{}) bool {
|
||||
if expected == nil || actual == nil {
|
||||
return expected == actual
|
||||
}
|
||||
|
||||
return reflect.DeepEqual(expected, actual)
|
||||
}
|
||||
|
||||
// Equal asserts that two objects are equal.
|
||||
//
|
||||
// assert.Equal(t, 123, 123, "123 and 123 should be equal")
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
//
|
||||
// Based on github.com/stretchr/testify/assert.Equal
|
||||
// Copied locally to prevent non-test build dependencies on testify
|
||||
func equal(t *testing.T, expected, actual interface{}, msgAndArgs ...interface{}) bool {
|
||||
t.Helper()
|
||||
|
||||
if !objectsAreEqual(expected, actual) {
|
||||
t.Errorf("%s\n%s", messageFromMsgAndArgs(msgAndArgs),
|
||||
SprintExpectActual(expected, actual))
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func errMsg(baseMsg string, err error, msgAndArgs ...interface{}) string {
|
||||
message := messageFromMsgAndArgs(msgAndArgs)
|
||||
if message != "" {
|
||||
message += ", "
|
||||
}
|
||||
return fmt.Sprintf("%s%s, %v", message, baseMsg, err)
|
||||
}
|
||||
|
||||
// Based on github.com/stretchr/testify/assert.messageFromMsgAndArgs
|
||||
// Copied locally to prevent non-test build dependencies on testify
|
||||
func messageFromMsgAndArgs(msgAndArgs []interface{}) string {
|
||||
if len(msgAndArgs) == 0 || msgAndArgs == nil {
|
||||
return ""
|
||||
}
|
||||
if len(msgAndArgs) == 1 {
|
||||
return msgAndArgs[0].(string)
|
||||
}
|
||||
if len(msgAndArgs) > 1 {
|
||||
return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func queryValueKeys(v url.Values) []string {
|
||||
keys := make([]string, 0, len(v))
|
||||
for k := range v {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
// SprintExpectActual returns a string for test failure cases when the actual
|
||||
// value is not the same as the expected.
|
||||
func SprintExpectActual(expect, actual interface{}) string {
|
||||
return fmt.Sprintf("expect: %+v\nactual: %+v\n", expect, actual)
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
package awstesting_test
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"testing"
|
||||
|
||||
"github.com/versity/versitygw/aws/internal/awstesting"
|
||||
)
|
||||
|
||||
func TestAssertJSON(t *testing.T) {
|
||||
cases := []struct {
|
||||
e, a string
|
||||
asserts bool
|
||||
}{
|
||||
{
|
||||
e: `{"RecursiveStruct":{"RecursiveMap":{"foo":{"NoRecurse":"foo"},"bar":{"NoRecurse":"bar"}}}}`,
|
||||
a: `{"RecursiveStruct":{"RecursiveMap":{"bar":{"NoRecurse":"bar"},"foo":{"NoRecurse":"foo"}}}}`,
|
||||
asserts: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
mockT := &testing.T{}
|
||||
if awstesting.AssertJSON(mockT, c.e, c.a) != c.asserts {
|
||||
t.Error("Assert JSON result was not expected.", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssertXML(t *testing.T) {
|
||||
cases := []struct {
|
||||
e, a string
|
||||
asserts bool
|
||||
container struct {
|
||||
XMLName xml.Name `xml:"OperationRequest"`
|
||||
NS string `xml:"xmlns,attr"`
|
||||
RecursiveStruct struct {
|
||||
RecursiveMap struct {
|
||||
Entries []struct {
|
||||
XMLName xml.Name `xml:"entries"`
|
||||
Key string `xml:"key"`
|
||||
Value struct {
|
||||
XMLName xml.Name `xml:"value"`
|
||||
NoRecurse string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}{
|
||||
{
|
||||
e: `<OperationRequest xmlns="https://foo/"><RecursiveStruct xmlns="https://foo/"><RecursiveMap xmlns="https://foo/"><entry xmlns="https://foo/"><key xmlns="https://foo/">foo</key><value xmlns="https://foo/"><NoRecurse xmlns="https://foo/">foo</NoRecurse></value></entry><entry xmlns="https://foo/"><key xmlns="https://foo/">bar</key><value xmlns="https://foo/"><NoRecurse xmlns="https://foo/">bar</NoRecurse></value></entry></RecursiveMap></RecursiveStruct></OperationRequest>`,
|
||||
a: `<OperationRequest xmlns="https://foo/"><RecursiveStruct xmlns="https://foo/"><RecursiveMap xmlns="https://foo/"><entry xmlns="https://foo/"><key xmlns="https://foo/">bar</key><value xmlns="https://foo/"><NoRecurse xmlns="https://foo/">bar</NoRecurse></value></entry><entry xmlns="https://foo/"><key xmlns="https://foo/">foo</key><value xmlns="https://foo/"><NoRecurse xmlns="https://foo/">foo</NoRecurse></value></entry></RecursiveMap></RecursiveStruct></OperationRequest>`,
|
||||
asserts: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
// mockT := &testing.T{}
|
||||
if awstesting.AssertXML(t, c.e, c.a, c.container) != c.asserts {
|
||||
t.Error("Assert XML result was not expected.", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssertQuery(t *testing.T) {
|
||||
cases := []struct {
|
||||
e, a string
|
||||
asserts bool
|
||||
}{
|
||||
{
|
||||
e: `Action=OperationName&Version=2014-01-01&Foo=val1&Bar=val2`,
|
||||
a: `Action=OperationName&Version=2014-01-01&Foo=val2&Bar=val3`,
|
||||
asserts: false,
|
||||
},
|
||||
{
|
||||
e: `Action=OperationName&Version=2014-01-01&Foo=val1&Bar=val2`,
|
||||
a: `Action=OperationName&Version=2014-01-01&Foo=val1&Bar=val2`,
|
||||
asserts: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
mockT := &testing.T{}
|
||||
if awstesting.AssertQuery(mockT, c.e, c.a) != c.asserts {
|
||||
t.Error("Assert Query result was not expected.", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,291 +0,0 @@
|
||||
package awstesting
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// TLSBundleCA is the CA PEM
|
||||
TLSBundleCA []byte
|
||||
|
||||
// TLSBundleCert is the Server PEM
|
||||
TLSBundleCert []byte
|
||||
|
||||
// TLSBundleKey is the Server private key PEM
|
||||
TLSBundleKey []byte
|
||||
|
||||
// ClientTLSCert is the Client PEM
|
||||
ClientTLSCert []byte
|
||||
|
||||
// ClientTLSKey is the Client private key PEM
|
||||
ClientTLSKey []byte
|
||||
)
|
||||
|
||||
func init() {
|
||||
caPEM, _, caCert, caPrivKey, err := generateRootCA()
|
||||
if err != nil {
|
||||
panic("failed to generate testing root CA, " + err.Error())
|
||||
}
|
||||
TLSBundleCA = caPEM
|
||||
|
||||
serverCertPEM, serverCertPrivKeyPEM, err := generateLocalCert(caCert, caPrivKey)
|
||||
if err != nil {
|
||||
panic("failed to generate testing server cert, " + err.Error())
|
||||
}
|
||||
TLSBundleCert = serverCertPEM
|
||||
TLSBundleKey = serverCertPrivKeyPEM
|
||||
|
||||
clientCertPEM, clientCertPrivKeyPEM, err := generateLocalCert(caCert, caPrivKey)
|
||||
if err != nil {
|
||||
panic("failed to generate testing client cert, " + err.Error())
|
||||
}
|
||||
ClientTLSCert = clientCertPEM
|
||||
ClientTLSKey = clientCertPrivKeyPEM
|
||||
}
|
||||
|
||||
func generateRootCA() (
|
||||
caPEM, caPrivKeyPEM []byte, caCert *x509.Certificate, caPrivKey *rsa.PrivateKey, err error,
|
||||
) {
|
||||
caCert = &x509.Certificate{
|
||||
SerialNumber: big.NewInt(42),
|
||||
Subject: pkix.Name{
|
||||
Country: []string{"US"},
|
||||
Organization: []string{"AWS SDK for Go Test Certificate"},
|
||||
CommonName: "Test Root CA",
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Minute),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
// Create CA private and public key
|
||||
caPrivKey, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed generate CA RSA key, %w", err)
|
||||
}
|
||||
|
||||
// Create CA certificate
|
||||
caBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caPrivKey.PublicKey, caPrivKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed generate CA certificate, %w", err)
|
||||
}
|
||||
|
||||
// PEM encode CA certificate and private key
|
||||
var caPEMBuf bytes.Buffer
|
||||
pem.Encode(&caPEMBuf, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: caBytes,
|
||||
})
|
||||
|
||||
var caPrivKeyPEMBuf bytes.Buffer
|
||||
pem.Encode(&caPrivKeyPEMBuf, &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
|
||||
})
|
||||
|
||||
return caPEMBuf.Bytes(), caPrivKeyPEMBuf.Bytes(), caCert, caPrivKey, nil
|
||||
}
|
||||
|
||||
func generateLocalCert(parentCert *x509.Certificate, parentPrivKey *rsa.PrivateKey) (
|
||||
certPEM, certPrivKeyPEM []byte, err error,
|
||||
) {
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(42),
|
||||
Subject: pkix.Name{
|
||||
Country: []string{"US"},
|
||||
Organization: []string{"AWS SDK for Go Test Certificate"},
|
||||
CommonName: "Test Root CA",
|
||||
},
|
||||
IPAddresses: []net.IP{
|
||||
net.IPv4(127, 0, 0, 1),
|
||||
net.IPv6loopback,
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Minute),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
}
|
||||
|
||||
// Create server private and public key
|
||||
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate server RSA private key, %w", err)
|
||||
}
|
||||
|
||||
// Create server certificate
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, cert, parentCert, &certPrivKey.PublicKey, parentPrivKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate server certificate, %w", err)
|
||||
}
|
||||
|
||||
// PEM encode certificate and private key
|
||||
var certPEMBuf bytes.Buffer
|
||||
pem.Encode(&certPEMBuf, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certBytes,
|
||||
})
|
||||
|
||||
var certPrivKeyPEMBuf bytes.Buffer
|
||||
pem.Encode(&certPrivKeyPEMBuf, &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
|
||||
})
|
||||
|
||||
return certPEMBuf.Bytes(), certPrivKeyPEMBuf.Bytes(), nil
|
||||
}
|
||||
|
||||
// NewTLSClientCertServer creates a new HTTP test server initialize to require
|
||||
// HTTP clients authenticate with TLS client certificates.
|
||||
func NewTLSClientCertServer(handler http.Handler) (*httptest.Server, error) {
|
||||
server := httptest.NewUnstartedServer(handler)
|
||||
|
||||
if server.TLS == nil {
|
||||
server.TLS = &tls.Config{}
|
||||
}
|
||||
server.TLS.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
|
||||
if server.TLS.ClientCAs == nil {
|
||||
server.TLS.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
certPem := append(ClientTLSCert, ClientTLSKey...)
|
||||
if ok := server.TLS.ClientCAs.AppendCertsFromPEM(certPem); !ok {
|
||||
return nil, fmt.Errorf("failed to append client certs")
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// CreateClientTLSCertFiles returns a set of temporary files for the client
|
||||
// certificate and key files.
|
||||
func CreateClientTLSCertFiles() (cert, key string, err error) {
|
||||
cert, err = createTmpFile(ClientTLSCert)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
key, err = createTmpFile(ClientTLSKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return cert, key, nil
|
||||
}
|
||||
|
||||
func availableLocalAddr(ip string) (v string, err error) {
|
||||
l, err := net.Listen("tcp", ip+":0")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
closeErr := l.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
} else if closeErr != nil {
|
||||
err = fmt.Errorf("ip listener close error: %v, original error: %w", closeErr, err)
|
||||
}
|
||||
}()
|
||||
|
||||
return l.Addr().String(), nil
|
||||
}
|
||||
|
||||
// CreateTLSServer will create the TLS server on an open port using the
|
||||
// certificate and key. The address will be returned that the server is running on.
|
||||
func CreateTLSServer(cert, key string, mux *http.ServeMux) (string, error) {
|
||||
addr, err := availableLocalAddr("127.0.0.1")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if mux == nil {
|
||||
mux = http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {})
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := http.ListenAndServeTLS(addr, cert, key, mux); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i < 60; i++ {
|
||||
if _, err := http.Get("https://" + addr); err != nil && !strings.Contains(err.Error(), "connection refused") {
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
return "https://" + addr, nil
|
||||
}
|
||||
|
||||
// CreateTLSBundleFiles returns the temporary filenames for the certificate
|
||||
// key, and CA PEM content. These files should be deleted when no longer
|
||||
// needed. CleanupTLSBundleFiles can be used for this cleanup.
|
||||
func CreateTLSBundleFiles() (cert, key, ca string, err error) {
|
||||
cert, err = createTmpFile(TLSBundleCert)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
key, err = createTmpFile(TLSBundleKey)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
ca, err = createTmpFile(TLSBundleCA)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
return cert, key, ca, nil
|
||||
}
|
||||
|
||||
// CleanupTLSBundleFiles takes variadic list of files to be deleted.
|
||||
func CleanupTLSBundleFiles(files ...string) error {
|
||||
for _, file := range files {
|
||||
if err := os.Remove(file); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createTmpFile(b []byte) (string, error) {
|
||||
bundleFile, err := ioutil.TempFile(os.TempDir(), "aws-sdk-go-session-test")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, err = bundleFile.Write(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
defer bundleFile.Close()
|
||||
return bundleFile.Name(), nil
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package awstesting
|
||||
|
||||
// DiscardAt is an io.WriteAt that discards
|
||||
// the requested bytes to be written
|
||||
type DiscardAt struct{}
|
||||
|
||||
// WriteAt discards the given []byte slice and returns len(p) bytes
|
||||
// as having been written at the given offset. It will never return an error.
|
||||
func (d DiscardAt) WriteAt(p []byte, off int64) (n int, err error) {
|
||||
return len(p), nil
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package awstesting
|
||||
|
||||
// EndlessReader is an io.Reader that will always return
|
||||
// that bytes have been read.
|
||||
type EndlessReader struct{}
|
||||
|
||||
// Read will report that it has read len(p) bytes in p.
|
||||
// The content in the []byte will be unmodified.
|
||||
// This will never return an error.
|
||||
func (e EndlessReader) Read(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
# Based on docker-library's golang 1.6 alpine and wheezy docker files.
|
||||
# https://github.com/docker-library/golang/blob/master/1.6/alpine/Dockerfile
|
||||
# https://github.com/docker-library/golang/blob/master/1.6/wheezy/Dockerfile
|
||||
FROM buildpack-deps:buster-scm
|
||||
|
||||
ENV GOLANG_SRC_REPO_URL https://github.com/golang/go
|
||||
|
||||
# as of 1.20 Go 1.17 is required to bootstrap
|
||||
# see https://github.com/golang/go/issues/44505
|
||||
ENV GOLANG_BOOTSTRAP_URL https://go.dev/dl/go1.17.13.linux-amd64.tar.gz
|
||||
ENV GOLANG_BOOTSTRAP_SHA256 4cdd2bc664724dc7db94ad51b503512c5ae7220951cac568120f64f8e94399fc
|
||||
ENV GOLANG_BOOTSTRAP_PATH /usr/local/bootstrap
|
||||
|
||||
# gcc for cgo
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
g++ \
|
||||
gcc \
|
||||
libc6-dev \
|
||||
make \
|
||||
git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Setup the Bootstrap
|
||||
RUN mkdir -p "$GOLANG_BOOTSTRAP_PATH" \
|
||||
&& curl -fsSL "$GOLANG_BOOTSTRAP_URL" -o golang.tar.gz \
|
||||
&& echo "$GOLANG_BOOTSTRAP_SHA256 golang.tar.gz" | sha256sum -c - \
|
||||
&& tar -C "$GOLANG_BOOTSTRAP_PATH" -xzf golang.tar.gz \
|
||||
&& rm golang.tar.gz
|
||||
|
||||
# Get and build Go tip
|
||||
RUN export GOROOT_BOOTSTRAP=$GOLANG_BOOTSTRAP_PATH/go \
|
||||
&& git clone "$GOLANG_SRC_REPO_URL" /usr/local/go \
|
||||
&& cd /usr/local/go/src \
|
||||
&& ./make.bash \
|
||||
&& rm -rf "$GOLANG_BOOTSTRAP_PATH" /usr/local/go/pkg/bootstrap
|
||||
|
||||
# Build Go workspace and environment
|
||||
ENV GOPATH /go
|
||||
ENV PATH $GOPATH/bin:/usr/local/go/bin:$PATH
|
||||
RUN mkdir -p "$GOPATH/src" "$GOPATH/bin" \
|
||||
&& chmod -R 777 "$GOPATH"
|
||||
|
||||
WORKDIR $GOPATH
|
||||
@@ -1,16 +0,0 @@
|
||||
FROM aws-golang:tip
|
||||
|
||||
ENV GOPROXY=direct
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common \
|
||||
&& wget -O- https://apt.corretto.aws/corretto.key | apt-key add - \
|
||||
&& add-apt-repository 'deb https://apt.corretto.aws stable main' \
|
||||
&& apt-get update && apt-get install -y --no-install-recommends \
|
||||
vim \
|
||||
java-17-amazon-corretto-jdk \
|
||||
&& rm -rf /var/list/apt/lists/*
|
||||
|
||||
ADD . /go/src/github.com/aws/aws-sdk-go-v2
|
||||
WORKDIR /go/src/github.com/aws/aws-sdk-go-v2
|
||||
CMD ["make", "unit"]
|
||||
@@ -1,18 +0,0 @@
|
||||
ARG GO_VERSION
|
||||
FROM golang:${GO_VERSION}
|
||||
|
||||
ENV GOPROXY=direct
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common \
|
||||
&& wget -O- https://apt.corretto.aws/corretto.key | apt-key add - \
|
||||
&& add-apt-repository 'deb https://apt.corretto.aws stable main' \
|
||||
&& apt-get update && apt-get install -y --no-install-recommends \
|
||||
vim \
|
||||
java-17-amazon-corretto-jdk \
|
||||
&& rm -rf /var/list/apt/lists/*
|
||||
|
||||
ADD . /go/src/github.com/aws/aws-sdk-go-v2
|
||||
|
||||
WORKDIR /go/src/github.com/aws/aws-sdk-go-v2
|
||||
CMD ["make", "unit"]
|
||||
@@ -1,201 +0,0 @@
|
||||
package awstesting
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
)
|
||||
|
||||
// ZeroReader is a io.Reader which will always write zeros to the byte slice provided.
|
||||
type ZeroReader struct{}
|
||||
|
||||
// Read fills the provided byte slice with zeros returning the number of bytes written.
|
||||
func (r *ZeroReader) Read(b []byte) (int, error) {
|
||||
for i := 0; i < len(b); i++ {
|
||||
b[i] = 0
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// ReadCloser is a io.ReadCloser for unit testing.
|
||||
// Designed to test for leaks and whether a handle has
|
||||
// been closed
|
||||
type ReadCloser struct {
|
||||
Size int
|
||||
Closed bool
|
||||
set bool
|
||||
FillData func(bool, []byte, int, int)
|
||||
}
|
||||
|
||||
// Read will call FillData and fill it with whatever data needed.
|
||||
// Decrements the size until zero, then return io.EOF.
|
||||
func (r *ReadCloser) Read(b []byte) (int, error) {
|
||||
if r.Closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
delta := len(b)
|
||||
if delta > r.Size {
|
||||
delta = r.Size
|
||||
}
|
||||
r.Size -= delta
|
||||
|
||||
for i := 0; i < delta; i++ {
|
||||
b[i] = 'a'
|
||||
}
|
||||
|
||||
if r.FillData != nil {
|
||||
r.FillData(r.set, b, r.Size, delta)
|
||||
}
|
||||
r.set = true
|
||||
|
||||
if r.Size > 0 {
|
||||
return delta, nil
|
||||
}
|
||||
return delta, io.EOF
|
||||
}
|
||||
|
||||
// Close sets Closed to true and returns no error
|
||||
func (r *ReadCloser) Close() error {
|
||||
r.Closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// A FakeContext provides a simple stub implementation of a Context
|
||||
type FakeContext struct {
|
||||
Error error
|
||||
DoneCh chan struct{}
|
||||
}
|
||||
|
||||
// Deadline always will return not set
|
||||
func (c *FakeContext) Deadline() (deadline time.Time, ok bool) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
// Done returns a read channel for listening to the Done event
|
||||
func (c *FakeContext) Done() <-chan struct{} {
|
||||
return c.DoneCh
|
||||
}
|
||||
|
||||
// Err returns the error, is nil if not set.
|
||||
func (c *FakeContext) Err() error {
|
||||
return c.Error
|
||||
}
|
||||
|
||||
// Value ignores the Value and always returns nil
|
||||
func (c *FakeContext) Value(key interface{}) interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StashEnv stashes the current environment variables except variables listed in envToKeepx
|
||||
// Returns an function to pop out old environment
|
||||
func StashEnv(envToKeep ...string) []string {
|
||||
if runtime.GOOS == "windows" {
|
||||
envToKeep = append(envToKeep, "ComSpec")
|
||||
envToKeep = append(envToKeep, "SYSTEM32")
|
||||
envToKeep = append(envToKeep, "SYSTEMROOT")
|
||||
}
|
||||
envToKeep = append(envToKeep, "PATH", "HOME", "USERPROFILE")
|
||||
extraEnv := getEnvs(envToKeep)
|
||||
originalEnv := os.Environ()
|
||||
os.Clearenv() // clear env
|
||||
for key, val := range extraEnv {
|
||||
os.Setenv(key, val)
|
||||
}
|
||||
return originalEnv
|
||||
}
|
||||
|
||||
// PopEnv takes the list of the environment values and injects them into the
|
||||
// process's environment variable data. Clears any existing environment values
|
||||
// that may already exist.
|
||||
func PopEnv(env []string) {
|
||||
os.Clearenv()
|
||||
|
||||
for _, e := range env {
|
||||
p := strings.SplitN(e, "=", 2)
|
||||
k, v := p[0], ""
|
||||
if len(p) > 1 {
|
||||
v = p[1]
|
||||
}
|
||||
os.Setenv(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// MockCredentialsProvider is a type that can be used to mock out credentials
|
||||
// providers
|
||||
type MockCredentialsProvider struct {
|
||||
RetrieveFn func(ctx context.Context) (aws.Credentials, error)
|
||||
InvalidateFn func()
|
||||
}
|
||||
|
||||
// Retrieve calls the RetrieveFn
|
||||
func (p MockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
|
||||
return p.RetrieveFn(ctx)
|
||||
}
|
||||
|
||||
// Invalidate calls the InvalidateFn
|
||||
func (p MockCredentialsProvider) Invalidate() {
|
||||
p.InvalidateFn()
|
||||
}
|
||||
|
||||
func getEnvs(envs []string) map[string]string {
|
||||
extraEnvs := make(map[string]string)
|
||||
for _, env := range envs {
|
||||
if val, ok := os.LookupEnv(env); ok && len(val) > 0 {
|
||||
extraEnvs[env] = val
|
||||
}
|
||||
}
|
||||
return extraEnvs
|
||||
}
|
||||
|
||||
const (
|
||||
signaturePreambleSigV4 = "AWS4-HMAC-SHA256"
|
||||
signaturePreambleSigV4A = "AWS4-ECDSA-P256-SHA256"
|
||||
)
|
||||
|
||||
// SigV4Signature represents a parsed sigv4 or sigv4a signature.
|
||||
type SigV4Signature struct {
|
||||
Preamble string // e.g. AWS4-HMAC-SHA256, AWS4-ECDSA-P256-SHA256
|
||||
SigningName string // generally the service name e.g. "s3"
|
||||
SigningRegion string // for sigv4a this is the region-set header as-is
|
||||
SignedHeaders []string // list of signed headers
|
||||
Signature string // calculated signature
|
||||
}
|
||||
|
||||
// ParseSigV4Signature deconstructs a sigv4 or sigv4a signature from a set of
|
||||
// request headers.
|
||||
func ParseSigV4Signature(header http.Header) *SigV4Signature {
|
||||
auth := header.Get("Authorization")
|
||||
|
||||
preamble, after, _ := strings.Cut(auth, " ")
|
||||
credential, after, _ := strings.Cut(after, ", ")
|
||||
signedHeaders, signature, _ := strings.Cut(after, ", ")
|
||||
|
||||
credentialParts := strings.Split(credential, "/")
|
||||
|
||||
// sigv4 : AccessKeyID/DateString/SigningRegion/SigningName/SignatureID
|
||||
// sigv4a : AccessKeyID/DateString/SigningName/SignatureID, region set on
|
||||
// header
|
||||
var signingName, signingRegion string
|
||||
if preamble == signaturePreambleSigV4 {
|
||||
signingName = credentialParts[3]
|
||||
signingRegion = credentialParts[2]
|
||||
} else if preamble == signaturePreambleSigV4A {
|
||||
signingName = credentialParts[2]
|
||||
signingRegion = header.Get("X-Amz-Region-Set")
|
||||
}
|
||||
|
||||
return &SigV4Signature{
|
||||
Preamble: preamble,
|
||||
SigningName: signingName,
|
||||
SigningRegion: signingRegion,
|
||||
SignedHeaders: strings.Split(signedHeaders, ";"),
|
||||
Signature: signature,
|
||||
}
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package awstesting_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/versity/versitygw/aws/internal/awstesting"
|
||||
)
|
||||
|
||||
func TestReadCloserClose(t *testing.T) {
|
||||
rc := awstesting.ReadCloser{Size: 1}
|
||||
err := rc.Close()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expect nil, got %v", err)
|
||||
}
|
||||
if !rc.Closed {
|
||||
t.Errorf("expect closed, was not")
|
||||
}
|
||||
if e, a := rc.Size, 1; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCloserRead(t *testing.T) {
|
||||
rc := awstesting.ReadCloser{Size: 5}
|
||||
b := make([]byte, 2)
|
||||
|
||||
n, err := rc.Read(b)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expect nil, got %v", err)
|
||||
}
|
||||
if e, a := n, 2; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if rc.Closed {
|
||||
t.Errorf("expect not to be closed")
|
||||
}
|
||||
if e, a := rc.Size, 3; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
err = rc.Close()
|
||||
if err != nil {
|
||||
t.Errorf("expect nil, got %v", err)
|
||||
}
|
||||
n, err = rc.Read(b)
|
||||
if e, a := err, io.EOF; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := n, 0; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCloserReadAll(t *testing.T) {
|
||||
rc := awstesting.ReadCloser{Size: 5}
|
||||
b := make([]byte, 5)
|
||||
|
||||
n, err := rc.Read(b)
|
||||
|
||||
if e, a := err, io.EOF; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := n, 5; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if rc.Closed {
|
||||
t.Errorf("expect not to be closed")
|
||||
}
|
||||
if e, a := rc.Size, 0; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package sdk
|
||||
|
||||
// Invalidator provides access to a type's invalidate method to make it
|
||||
// invalidate it cache.
|
||||
//
|
||||
// e.g aws.SafeCredentialsProvider's Invalidate method.
|
||||
type Invalidator interface {
|
||||
Invalidate()
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
NowTime = time.Now
|
||||
Sleep = time.Sleep
|
||||
SleepWithContext = sleepWithContext
|
||||
}
|
||||
|
||||
// NowTime is a value for getting the current time. This value can be overridden
|
||||
// for testing mocking out current time.
|
||||
var NowTime func() time.Time
|
||||
|
||||
// Sleep is a value for sleeping for a duration. This value can be overridden
|
||||
// for testing and mocking out sleep duration.
|
||||
var Sleep func(time.Duration)
|
||||
|
||||
// SleepWithContext will wait for the timer duration to expire, or the context
|
||||
// is canceled. Which ever happens first. If the context is canceled the Context's
|
||||
// error will be returned.
|
||||
//
|
||||
// This value can be overridden for testing and mocking out sleep duration.
|
||||
var SleepWithContext func(context.Context, time.Duration) error
|
||||
|
||||
// sleepWithContext will wait for the timer duration to expire, or the context
|
||||
// is canceled. Which ever happens first. If the context is canceled the
|
||||
// Context's error will be returned.
|
||||
func sleepWithContext(ctx context.Context, dur time.Duration) error {
|
||||
t := time.NewTimer(dur)
|
||||
defer t.Stop()
|
||||
|
||||
select {
|
||||
case <-t.C:
|
||||
break
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// noOpSleepWithContext does nothing, returns immediately.
|
||||
func noOpSleepWithContext(context.Context, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func noOpSleep(time.Duration) {}
|
||||
|
||||
// TestingUseNopSleep is a utility for disabling sleep across the SDK for
|
||||
// testing.
|
||||
func TestingUseNopSleep() func() {
|
||||
SleepWithContext = noOpSleepWithContext
|
||||
Sleep = noOpSleep
|
||||
|
||||
return func() {
|
||||
SleepWithContext = sleepWithContext
|
||||
Sleep = time.Sleep
|
||||
}
|
||||
}
|
||||
|
||||
// TestingUseReferenceTime is a utility for swapping the time function across the SDK to return a specific reference time
|
||||
// for testing purposes.
|
||||
func TestingUseReferenceTime(referenceTime time.Time) func() {
|
||||
NowTime = func() time.Time {
|
||||
return referenceTime
|
||||
}
|
||||
return func() {
|
||||
NowTime = time.Now
|
||||
}
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSleepWithContext(t *testing.T) {
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
defer cancelFn()
|
||||
|
||||
err := sleepWithContext(ctx, 1*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Errorf("expect context to not be canceled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSleepWithContext_Canceled(t *testing.T) {
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
cancelFn()
|
||||
|
||||
err := sleepWithContext(ctx, 10*time.Second)
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, did not get one")
|
||||
}
|
||||
|
||||
if e, a := "context canceled", err.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("expect %v error, got %v", e, a)
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package strings
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// HasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
|
||||
// under Unicode case-folding.
|
||||
func HasPrefixFold(s, prefix string) bool {
|
||||
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
package strings
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHasPrefixFold(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
prefix string
|
||||
}
|
||||
tests := map[string]struct {
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
"empty strings and prefix": {
|
||||
args: args{
|
||||
s: "",
|
||||
prefix: "",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
"strings starts with prefix": {
|
||||
args: args{
|
||||
s: "some string",
|
||||
prefix: "some",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
"prefix longer then string": {
|
||||
args: args{
|
||||
s: "some",
|
||||
prefix: "some string",
|
||||
},
|
||||
},
|
||||
"equal length string and prefix": {
|
||||
args: args{
|
||||
s: "short string",
|
||||
prefix: "short string",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
"different cases": {
|
||||
args: args{
|
||||
s: "ShOrT StRING",
|
||||
prefix: "short",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
"empty prefix not empty string": {
|
||||
args: args{
|
||||
s: "ShOrT StRING",
|
||||
prefix: "",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
"mixed-case prefixes": {
|
||||
args: args{
|
||||
s: "SoMe String",
|
||||
prefix: "sOme",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if got := HasPrefixFold(tt.args.s, tt.args.prefix); got != tt.want {
|
||||
t.Errorf("HasPrefixFold() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHasPrefixFold(b *testing.B) {
|
||||
HasPrefixFold("SoME string", "sOmE")
|
||||
}
|
||||
|
||||
func BenchmarkHasPrefix(b *testing.B) {
|
||||
strings.HasPrefix(strings.ToLower("SoME string"), strings.ToLower("sOmE"))
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
sdkstrings "github.com/versity/versitygw/aws/internal/strings"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Rules houses a set of Rule needed for validation of a
|
||||
@@ -61,7 +61,7 @@ type Patterns []string
|
||||
// been found
|
||||
func (p Patterns) IsValid(value string) bool {
|
||||
for _, pattern := range p {
|
||||
if sdkstrings.HasPrefixFold(value, pattern) {
|
||||
if hasPrefixFold(value, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -80,3 +80,9 @@ func (r InclusiveRules) IsValid(value string) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// hasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
|
||||
// under Unicode case-folding.
|
||||
func hasPrefixFold(s, prefix string) bool {
|
||||
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
|
||||
}
|
||||
|
||||
@@ -1,443 +0,0 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
|
||||
"github.com/aws/aws-sdk-go-v2/aws/middleware/private/metrics"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
internalauth "github.com/versity/versitygw/aws/internal/auth"
|
||||
"github.com/versity/versitygw/aws/internal/sdk"
|
||||
v4Internal "github.com/versity/versitygw/aws/signer/internal/v4"
|
||||
)
|
||||
|
||||
const computePayloadHashMiddlewareID = "ComputePayloadHash"
|
||||
|
||||
// HashComputationError indicates an error occurred while computing the signing hash
|
||||
type HashComputationError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// Error is the error message
|
||||
func (e *HashComputationError) Error() string {
|
||||
return fmt.Sprintf("failed to compute payload hash: %v", e.Err)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error if one is set
|
||||
func (e *HashComputationError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// SigningError indicates an error condition occurred while performing SigV4 signing
|
||||
type SigningError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *SigningError) Error() string {
|
||||
return fmt.Sprintf("failed to sign request: %v", e.Err)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error cause
|
||||
func (e *SigningError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// UseDynamicPayloadSigningMiddleware swaps the compute payload sha256 middleware with a resolver middleware that
|
||||
// switches between unsigned and signed payload based on TLS state for request.
|
||||
// This middleware should not be used for AWS APIs that do not support unsigned payload signing auth.
|
||||
// By default, SDK uses this middleware for known AWS APIs that support such TLS based auth selection .
|
||||
//
|
||||
// Usage example -
|
||||
// S3 PutObject API allows unsigned payload signing auth usage when TLS is enabled, and uses this middleware to
|
||||
// dynamically switch between unsigned and signed payload based on TLS state for request.
|
||||
func UseDynamicPayloadSigningMiddleware(stack *middleware.Stack) error {
|
||||
_, err := stack.Finalize.Swap(computePayloadHashMiddlewareID, &dynamicPayloadSigningMiddleware{})
|
||||
return err
|
||||
}
|
||||
|
||||
// dynamicPayloadSigningMiddleware dynamically resolves the middleware that computes and set payload sha256 middleware.
|
||||
type dynamicPayloadSigningMiddleware struct {
|
||||
}
|
||||
|
||||
// ID returns the resolver identifier
|
||||
func (m *dynamicPayloadSigningMiddleware) ID() string {
|
||||
return computePayloadHashMiddlewareID
|
||||
}
|
||||
|
||||
// HandleFinalize delegates SHA256 computation according to whether the request
|
||||
// is TLS-enabled.
|
||||
func (m *dynamicPayloadSigningMiddleware) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
req, ok := in.Request.(*smithyhttp.Request)
|
||||
if !ok {
|
||||
return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
|
||||
}
|
||||
|
||||
if req.IsHTTPS() {
|
||||
return (&UnsignedPayload{}).HandleFinalize(ctx, in, next)
|
||||
}
|
||||
return (&ComputePayloadSHA256{}).HandleFinalize(ctx, in, next)
|
||||
}
|
||||
|
||||
// UnsignedPayload sets the SigV4 request payload hash to unsigned.
|
||||
//
|
||||
// Will not set the Unsigned Payload magic SHA value, if a SHA has already been
|
||||
// stored in the context. (e.g. application pre-computed SHA256 before making
|
||||
// API call).
|
||||
//
|
||||
// This middleware does not check the X-Amz-Content-Sha256 header, if that
|
||||
// header is serialized a middleware must translate it into the context.
|
||||
type UnsignedPayload struct{}
|
||||
|
||||
// AddUnsignedPayloadMiddleware adds unsignedPayload to the operation
|
||||
// middleware stack
|
||||
func AddUnsignedPayloadMiddleware(stack *middleware.Stack) error {
|
||||
return stack.Finalize.Insert(&UnsignedPayload{}, "ResolveEndpointV2", middleware.After)
|
||||
}
|
||||
|
||||
// ID returns the unsignedPayload identifier
|
||||
func (m *UnsignedPayload) ID() string {
|
||||
return computePayloadHashMiddlewareID
|
||||
}
|
||||
|
||||
// HandleFinalize sets the payload hash magic value to the unsigned sentinel.
|
||||
func (m *UnsignedPayload) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
if GetPayloadHash(ctx) == "" {
|
||||
ctx = SetPayloadHash(ctx, v4Internal.UnsignedPayload)
|
||||
}
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
// ComputePayloadSHA256 computes SHA256 payload hash to sign.
|
||||
//
|
||||
// Will not set the Unsigned Payload magic SHA value, if a SHA has already been
|
||||
// stored in the context. (e.g. application pre-computed SHA256 before making
|
||||
// API call).
|
||||
//
|
||||
// This middleware does not check the X-Amz-Content-Sha256 header, if that
|
||||
// header is serialized a middleware must translate it into the context.
|
||||
type ComputePayloadSHA256 struct{}
|
||||
|
||||
// AddComputePayloadSHA256Middleware adds computePayloadSHA256 to the
|
||||
// operation middleware stack
|
||||
func AddComputePayloadSHA256Middleware(stack *middleware.Stack) error {
|
||||
return stack.Finalize.Insert(&ComputePayloadSHA256{}, "ResolveEndpointV2", middleware.After)
|
||||
}
|
||||
|
||||
// RemoveComputePayloadSHA256Middleware removes computePayloadSHA256 from the
|
||||
// operation middleware stack
|
||||
func RemoveComputePayloadSHA256Middleware(stack *middleware.Stack) error {
|
||||
_, err := stack.Finalize.Remove(computePayloadHashMiddlewareID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ID is the middleware name
|
||||
func (m *ComputePayloadSHA256) ID() string {
|
||||
return computePayloadHashMiddlewareID
|
||||
}
|
||||
|
||||
// HandleFinalize computes the payload hash for the request, storing it to the
|
||||
// context. This is a no-op if a caller has previously set that value.
|
||||
func (m *ComputePayloadSHA256) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
if GetPayloadHash(ctx) != "" {
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
req, ok := in.Request.(*smithyhttp.Request)
|
||||
if !ok {
|
||||
return out, metadata, &HashComputationError{
|
||||
Err: fmt.Errorf("unexpected request middleware type %T", in.Request),
|
||||
}
|
||||
}
|
||||
|
||||
hash := sha256.New()
|
||||
if stream := req.GetStream(); stream != nil {
|
||||
_, err = io.Copy(hash, stream)
|
||||
if err != nil {
|
||||
return out, metadata, &HashComputationError{
|
||||
Err: fmt.Errorf("failed to compute payload hash, %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
if err := req.RewindStream(); err != nil {
|
||||
return out, metadata, &HashComputationError{
|
||||
Err: fmt.Errorf("failed to seek body to start, %w", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx = SetPayloadHash(ctx, hex.EncodeToString(hash.Sum(nil)))
|
||||
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
// SwapComputePayloadSHA256ForUnsignedPayloadMiddleware replaces the
|
||||
// ComputePayloadSHA256 middleware with the UnsignedPayload middleware.
|
||||
//
|
||||
// Use this to disable computing the Payload SHA256 checksum and instead use
|
||||
// UNSIGNED-PAYLOAD for the SHA256 value.
|
||||
func SwapComputePayloadSHA256ForUnsignedPayloadMiddleware(stack *middleware.Stack) error {
|
||||
_, err := stack.Finalize.Swap(computePayloadHashMiddlewareID, &UnsignedPayload{})
|
||||
return err
|
||||
}
|
||||
|
||||
// ContentSHA256Header sets the X-Amz-Content-Sha256 header value to
|
||||
// the Payload hash stored in the context.
|
||||
type ContentSHA256Header struct{}
|
||||
|
||||
// AddContentSHA256HeaderMiddleware adds ContentSHA256Header to the
|
||||
// operation middleware stack
|
||||
func AddContentSHA256HeaderMiddleware(stack *middleware.Stack) error {
|
||||
return stack.Finalize.Insert(&ContentSHA256Header{}, computePayloadHashMiddlewareID, middleware.After)
|
||||
}
|
||||
|
||||
// RemoveContentSHA256HeaderMiddleware removes contentSHA256Header middleware
|
||||
// from the operation middleware stack
|
||||
func RemoveContentSHA256HeaderMiddleware(stack *middleware.Stack) error {
|
||||
_, err := stack.Finalize.Remove((*ContentSHA256Header)(nil).ID())
|
||||
return err
|
||||
}
|
||||
|
||||
// ID returns the ContentSHA256HeaderMiddleware identifier
|
||||
func (m *ContentSHA256Header) ID() string {
|
||||
return "SigV4ContentSHA256Header"
|
||||
}
|
||||
|
||||
// HandleFinalize sets the X-Amz-Content-Sha256 header value to the Payload hash
|
||||
// stored in the context.
|
||||
func (m *ContentSHA256Header) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
req, ok := in.Request.(*smithyhttp.Request)
|
||||
if !ok {
|
||||
return out, metadata, &HashComputationError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
|
||||
}
|
||||
|
||||
req.Header.Set(v4Internal.ContentSHAKey, GetPayloadHash(ctx))
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
// SignHTTPRequestMiddlewareOptions is the configuration options for
|
||||
// [SignHTTPRequestMiddleware].
|
||||
//
|
||||
// Deprecated: [SignHTTPRequestMiddleware] is deprecated.
|
||||
type SignHTTPRequestMiddlewareOptions struct {
|
||||
CredentialsProvider aws.CredentialsProvider
|
||||
Signer HTTPSigner
|
||||
LogSigning bool
|
||||
}
|
||||
|
||||
// SignHTTPRequestMiddleware is a `FinalizeMiddleware` implementation for SigV4
|
||||
// HTTP Signing.
|
||||
//
|
||||
// Deprecated: AWS service clients no longer use this middleware. Signing as an
|
||||
// SDK operation is now performed through an internal per-service middleware
|
||||
// which opaquely selects and uses the signer from the resolved auth scheme.
|
||||
type SignHTTPRequestMiddleware struct {
|
||||
credentialsProvider aws.CredentialsProvider
|
||||
signer HTTPSigner
|
||||
logSigning bool
|
||||
}
|
||||
|
||||
// NewSignHTTPRequestMiddleware constructs a [SignHTTPRequestMiddleware] using
|
||||
// the given [Signer] for signing requests.
|
||||
//
|
||||
// Deprecated: SignHTTPRequestMiddleware is deprecated.
|
||||
func NewSignHTTPRequestMiddleware(options SignHTTPRequestMiddlewareOptions) *SignHTTPRequestMiddleware {
|
||||
return &SignHTTPRequestMiddleware{
|
||||
credentialsProvider: options.CredentialsProvider,
|
||||
signer: options.Signer,
|
||||
logSigning: options.LogSigning,
|
||||
}
|
||||
}
|
||||
|
||||
// ID is the SignHTTPRequestMiddleware identifier.
|
||||
//
|
||||
// Deprecated: SignHTTPRequestMiddleware is deprecated.
|
||||
func (s *SignHTTPRequestMiddleware) ID() string {
|
||||
return "Signing"
|
||||
}
|
||||
|
||||
// HandleFinalize will take the provided input and sign the request using the
|
||||
// SigV4 authentication scheme.
|
||||
//
|
||||
// Deprecated: SignHTTPRequestMiddleware is deprecated.
|
||||
func (s *SignHTTPRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
if !haveCredentialProvider(s.credentialsProvider) {
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
req, ok := in.Request.(*smithyhttp.Request)
|
||||
if !ok {
|
||||
return out, metadata, &SigningError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
|
||||
}
|
||||
|
||||
signingName, signingRegion := awsmiddleware.GetSigningName(ctx), awsmiddleware.GetSigningRegion(ctx)
|
||||
payloadHash := GetPayloadHash(ctx)
|
||||
if len(payloadHash) == 0 {
|
||||
return out, metadata, &SigningError{Err: fmt.Errorf("computed payload hash missing from context")}
|
||||
}
|
||||
|
||||
mctx := metrics.Context(ctx)
|
||||
|
||||
if mctx != nil {
|
||||
if attempt, err := mctx.Data().LatestAttempt(); err == nil {
|
||||
attempt.CredentialFetchStartTime = sdk.NowTime()
|
||||
}
|
||||
}
|
||||
|
||||
credentials, err := s.credentialsProvider.Retrieve(ctx)
|
||||
|
||||
if mctx != nil {
|
||||
if attempt, err := mctx.Data().LatestAttempt(); err == nil {
|
||||
attempt.CredentialFetchEndTime = sdk.NowTime()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return out, metadata, &SigningError{Err: fmt.Errorf("failed to retrieve credentials: %w", err)}
|
||||
}
|
||||
|
||||
signerOptions := []func(o *SignerOptions){
|
||||
func(o *SignerOptions) {
|
||||
o.Logger = middleware.GetLogger(ctx)
|
||||
o.LogSigning = s.logSigning
|
||||
},
|
||||
}
|
||||
|
||||
// existing DisableURIPathEscaping is equivalent in purpose
|
||||
// to authentication scheme property DisableDoubleEncoding
|
||||
disableDoubleEncoding, overridden := internalauth.GetDisableDoubleEncoding(ctx)
|
||||
if overridden {
|
||||
signerOptions = append(signerOptions, func(o *SignerOptions) {
|
||||
o.DisableURIPathEscaping = disableDoubleEncoding
|
||||
})
|
||||
}
|
||||
|
||||
if mctx != nil {
|
||||
if attempt, err := mctx.Data().LatestAttempt(); err == nil {
|
||||
attempt.SignStartTime = sdk.NowTime()
|
||||
}
|
||||
}
|
||||
|
||||
err = s.signer.SignHTTP(ctx, credentials, req.Request, payloadHash, signingName, signingRegion, sdk.NowTime(), signerOptions...)
|
||||
|
||||
if mctx != nil {
|
||||
if attempt, err := mctx.Data().LatestAttempt(); err == nil {
|
||||
attempt.SignEndTime = sdk.NowTime()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return out, metadata, &SigningError{Err: fmt.Errorf("failed to sign http request, %w", err)}
|
||||
}
|
||||
|
||||
ctx = awsmiddleware.SetSigningCredentials(ctx, credentials)
|
||||
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
// StreamingEventsPayload signs input event stream messages.
|
||||
type StreamingEventsPayload struct{}
|
||||
|
||||
// AddStreamingEventsPayload adds the streamingEventsPayload middleware to the stack.
|
||||
func AddStreamingEventsPayload(stack *middleware.Stack) error {
|
||||
return stack.Finalize.Add(&StreamingEventsPayload{}, middleware.Before)
|
||||
}
|
||||
|
||||
// ID identifies the middleware.
|
||||
func (s *StreamingEventsPayload) ID() string {
|
||||
return computePayloadHashMiddlewareID
|
||||
}
|
||||
|
||||
// HandleFinalize marks the input stream to be signed with SigV4.
|
||||
func (s *StreamingEventsPayload) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
contentSHA := GetPayloadHash(ctx)
|
||||
if len(contentSHA) == 0 {
|
||||
contentSHA = v4Internal.StreamingEventsPayload
|
||||
}
|
||||
|
||||
ctx = SetPayloadHash(ctx, contentSHA)
|
||||
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
// GetSignedRequestSignature attempts to extract the signature of the request.
|
||||
// Returning an error if the request is unsigned, or unable to extract the
|
||||
// signature.
|
||||
func GetSignedRequestSignature(r *http.Request) ([]byte, error) {
|
||||
const authHeaderSignatureElem = "Signature="
|
||||
|
||||
if auth := r.Header.Get(authorizationHeader); len(auth) != 0 {
|
||||
ps := strings.Split(auth, ", ")
|
||||
for _, p := range ps {
|
||||
if idx := strings.Index(p, authHeaderSignatureElem); idx >= 0 {
|
||||
sig := p[len(authHeaderSignatureElem):]
|
||||
if len(sig) == 0 {
|
||||
return nil, fmt.Errorf("invalid request signature authorization header")
|
||||
}
|
||||
return hex.DecodeString(sig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sig := r.URL.Query().Get("X-Amz-Signature"); len(sig) != 0 {
|
||||
return hex.DecodeString(sig)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("request not signed")
|
||||
}
|
||||
|
||||
func haveCredentialProvider(p aws.CredentialsProvider) bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return !aws.IsCredentialsProvider(p, (*aws.AnonymousCredentials)(nil))
|
||||
}
|
||||
|
||||
type payloadHashKey struct{}
|
||||
|
||||
// GetPayloadHash retrieves the payload hash to use for signing
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func GetPayloadHash(ctx context.Context) (v string) {
|
||||
v, _ = middleware.GetStackValue(ctx, payloadHashKey{}).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// SetPayloadHash sets the payload hash to be used for signing the request
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func SetPayloadHash(ctx context.Context, hash string) context.Context {
|
||||
return middleware.WithStackValue(ctx, payloadHashKey{}, hash)
|
||||
}
|
||||
@@ -1,415 +0,0 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/versity/versitygw/aws/internal/awstesting/unit"
|
||||
)
|
||||
|
||||
func TestComputePayloadHashMiddleware(t *testing.T) {
|
||||
cases := []struct {
|
||||
content io.Reader
|
||||
expectedHash string
|
||||
expectedErr interface{}
|
||||
}{
|
||||
0: {
|
||||
content: func() io.Reader {
|
||||
br := bytes.NewReader([]byte("some content"))
|
||||
return br
|
||||
}(),
|
||||
expectedHash: "290f493c44f5d63d06b374d0a5abd292fae38b92cab2fae5efefe1b0e9347f56",
|
||||
},
|
||||
1: {
|
||||
content: func() io.Reader {
|
||||
return &nonSeeker{}
|
||||
}(),
|
||||
expectedErr: &HashComputationError{},
|
||||
},
|
||||
2: {
|
||||
content: func() io.Reader {
|
||||
return &semiSeekable{}
|
||||
}(),
|
||||
expectedErr: &HashComputationError{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range cases {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
c := &ComputePayloadSHA256{}
|
||||
|
||||
next := middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
|
||||
value := GetPayloadHash(ctx)
|
||||
if len(value) == 0 {
|
||||
t.Fatalf("expected payload hash value to be on context")
|
||||
}
|
||||
if e, a := tt.expectedHash, value; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
|
||||
return out, metadata, err
|
||||
})
|
||||
|
||||
stream, err := smithyhttp.NewStackRequest().(*smithyhttp.Request).SetStream(tt.content)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
_, _, err = c.HandleFinalize(context.Background(), middleware.FinalizeInput{Request: stream}, next)
|
||||
if err != nil && tt.expectedErr == nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
} else if err != nil && tt.expectedErr != nil {
|
||||
e, a := tt.expectedErr, err
|
||||
if !errors.As(a, &e) {
|
||||
t.Errorf("expected error type %T, got %T", e, a)
|
||||
}
|
||||
} else if err == nil && tt.expectedErr != nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type httpSignerFunc func(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(*SignerOptions)) error
|
||||
|
||||
func (f httpSignerFunc) SignHTTP(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(*SignerOptions)) error {
|
||||
return f(ctx, credentials, r, payloadHash, service, region, signingTime, optFns...)
|
||||
}
|
||||
|
||||
func TestSignHTTPRequestMiddleware(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
creds aws.CredentialsProvider
|
||||
hash string
|
||||
logSigning bool
|
||||
expectedErr interface{}
|
||||
}{
|
||||
"success": {
|
||||
creds: unit.StubCredentialsProvider{},
|
||||
hash: "0123456789abcdef",
|
||||
},
|
||||
"error": {
|
||||
creds: unit.StubCredentialsProvider{},
|
||||
hash: "",
|
||||
expectedErr: &SigningError{},
|
||||
},
|
||||
"anonymous creds": {
|
||||
creds: aws.AnonymousCredentials{},
|
||||
},
|
||||
"nil creds": {
|
||||
creds: nil,
|
||||
},
|
||||
"with log signing": {
|
||||
creds: unit.StubCredentialsProvider{},
|
||||
hash: "0123456789abcdef",
|
||||
logSigning: true,
|
||||
},
|
||||
}
|
||||
|
||||
const (
|
||||
signingName = "serviceId"
|
||||
signingRegion = "regionName"
|
||||
)
|
||||
|
||||
for name, tt := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
c := &SignHTTPRequestMiddleware{
|
||||
credentialsProvider: tt.creds,
|
||||
signer: httpSignerFunc(
|
||||
func(ctx context.Context,
|
||||
credentials aws.Credentials, r *http.Request, payloadHash string,
|
||||
service string, region string, signingTime time.Time,
|
||||
optFns ...func(*SignerOptions),
|
||||
) error {
|
||||
var options SignerOptions
|
||||
for _, fn := range optFns {
|
||||
fn(&options)
|
||||
}
|
||||
if options.Logger == nil {
|
||||
t.Errorf("expect logger, got none")
|
||||
}
|
||||
if options.LogSigning {
|
||||
options.Logger.Logf(logging.Debug, t.Name())
|
||||
}
|
||||
|
||||
expectCreds, _ := unit.StubCredentialsProvider{}.Retrieve(context.Background())
|
||||
if e, a := expectCreds, credentials; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := tt.hash, payloadHash; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := signingName, service; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := signingRegion, region; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
logSigning: tt.logSigning,
|
||||
}
|
||||
|
||||
next := middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
|
||||
return out, metadata, err
|
||||
})
|
||||
|
||||
ctx := awsmiddleware.SetSigningRegion(
|
||||
awsmiddleware.SetSigningName(context.Background(), signingName),
|
||||
signingRegion)
|
||||
|
||||
var loggerBuf bytes.Buffer
|
||||
logger := logging.NewStandardLogger(&loggerBuf)
|
||||
ctx = middleware.SetLogger(ctx, logger)
|
||||
|
||||
if len(tt.hash) != 0 {
|
||||
ctx = SetPayloadHash(ctx, tt.hash)
|
||||
}
|
||||
|
||||
_, _, err := c.HandleFinalize(ctx, middleware.FinalizeInput{
|
||||
Request: &smithyhttp.Request{Request: &http.Request{}},
|
||||
}, next)
|
||||
if err != nil && tt.expectedErr == nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
} else if err != nil && tt.expectedErr != nil {
|
||||
e, a := tt.expectedErr, err
|
||||
if !errors.As(a, &e) {
|
||||
t.Errorf("expected error type %T, got %T", e, a)
|
||||
}
|
||||
} else if err == nil && tt.expectedErr != nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
|
||||
if tt.logSigning {
|
||||
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
|
||||
t.Errorf("expect %v logged in %v", e, a)
|
||||
}
|
||||
} else {
|
||||
if loggerBuf.Len() != 0 {
|
||||
t.Errorf("expect no log, got %v", loggerBuf.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwapComputePayloadSHA256ForUnsignedPayloadMiddleware(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
InitStep func(*middleware.Stack) error
|
||||
Mutator func(*middleware.Stack) error
|
||||
ExpectErr string
|
||||
ExpectIDs []string
|
||||
}{
|
||||
"swap in place": {
|
||||
InitStep: func(s *middleware.Stack) (err error) {
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("before", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = AddComputePayloadSHA256Middleware(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("after", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Mutator: SwapComputePayloadSHA256ForUnsignedPayloadMiddleware,
|
||||
ExpectIDs: []string{
|
||||
"ResolveEndpointV2",
|
||||
computePayloadHashMiddlewareID, // should snap to after resolve endpoint
|
||||
"before",
|
||||
"after",
|
||||
},
|
||||
},
|
||||
|
||||
"already unsigned payload exists": {
|
||||
InitStep: func(s *middleware.Stack) (err error) {
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("before", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = AddUnsignedPayloadMiddleware(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("after", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Mutator: SwapComputePayloadSHA256ForUnsignedPayloadMiddleware,
|
||||
ExpectIDs: []string{
|
||||
"ResolveEndpointV2",
|
||||
computePayloadHashMiddlewareID,
|
||||
"before",
|
||||
"after",
|
||||
},
|
||||
},
|
||||
|
||||
"no compute payload": {
|
||||
InitStep: func(s *middleware.Stack) (err error) {
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("before", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("after", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Mutator: SwapComputePayloadSHA256ForUnsignedPayloadMiddleware,
|
||||
ExpectErr: "not found, " + computePayloadHashMiddlewareID,
|
||||
},
|
||||
}
|
||||
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
stack := middleware.NewStack(t.Name(), smithyhttp.NewStackRequest)
|
||||
stack.Finalize.Add(&nopResolveEndpoint{}, middleware.After)
|
||||
if err := c.InitStep(stack); err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
err := c.Mutator(stack)
|
||||
if len(c.ExpectErr) != 0 {
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, got none")
|
||||
}
|
||||
if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
|
||||
t.Fatalf("expect error to contain %v, got %v", e, a)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(c.ExpectIDs, stack.Finalize.List()); len(diff) != 0 {
|
||||
t.Errorf("expect match\n%v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseDynamicPayloadSigningMiddleware(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
content io.Reader
|
||||
url string
|
||||
expectedHash string
|
||||
expectedErr interface{}
|
||||
}{
|
||||
"TLS disabled": {
|
||||
content: func() io.Reader {
|
||||
br := bytes.NewReader([]byte("some content"))
|
||||
return br
|
||||
}(),
|
||||
url: "http://localhost.com/",
|
||||
expectedHash: "290f493c44f5d63d06b374d0a5abd292fae38b92cab2fae5efefe1b0e9347f56",
|
||||
},
|
||||
"TLS enabled": {
|
||||
content: func() io.Reader {
|
||||
br := bytes.NewReader([]byte("some content"))
|
||||
return br
|
||||
}(),
|
||||
url: "https://localhost.com/",
|
||||
expectedHash: "UNSIGNED-PAYLOAD",
|
||||
},
|
||||
}
|
||||
|
||||
for name, tt := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
c := &dynamicPayloadSigningMiddleware{}
|
||||
|
||||
next := middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
|
||||
value := GetPayloadHash(ctx)
|
||||
if len(value) == 0 {
|
||||
t.Fatalf("expected payload hash value to be on context")
|
||||
}
|
||||
if e, a := tt.expectedHash, value; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
|
||||
return out, metadata, err
|
||||
})
|
||||
|
||||
req := smithyhttp.NewStackRequest().(*smithyhttp.Request)
|
||||
req.URL, _ = url.Parse(tt.url)
|
||||
stream, err := req.SetStream(tt.content)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
_, _, err = c.HandleFinalize(context.Background(), middleware.FinalizeInput{Request: stream}, next)
|
||||
if err != nil && tt.expectedErr == nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
} else if err != nil && tt.expectedErr != nil {
|
||||
e, a := tt.expectedErr, err
|
||||
if !errors.As(a, &e) {
|
||||
t.Errorf("expected error type %T, got %T", e, a)
|
||||
}
|
||||
} else if err == nil && tt.expectedErr != nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type nonSeeker struct{}
|
||||
|
||||
func (nonSeeker) Read(p []byte) (n int, err error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
type semiSeekable struct {
|
||||
hasSeeked bool
|
||||
}
|
||||
|
||||
func (s *semiSeekable) Seek(offset int64, whence int) (int64, error) {
|
||||
if !s.hasSeeked {
|
||||
s.hasSeeked = true
|
||||
return 0, nil
|
||||
}
|
||||
return 0, fmt.Errorf("io seek error")
|
||||
}
|
||||
|
||||
func (*semiSeekable) Read(p []byte) (n int, err error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
type nopResolveEndpoint struct{}
|
||||
|
||||
func (*nopResolveEndpoint) ID() string { return "ResolveEndpointV2" }
|
||||
|
||||
func (*nopResolveEndpoint) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
return out, metadata, err
|
||||
}
|
||||
|
||||
var (
|
||||
_ middleware.FinalizeMiddleware = &UnsignedPayload{}
|
||||
_ middleware.FinalizeMiddleware = &ComputePayloadSHA256{}
|
||||
_ middleware.FinalizeMiddleware = &ContentSHA256Header{}
|
||||
_ middleware.FinalizeMiddleware = &SignHTTPRequestMiddleware{}
|
||||
)
|
||||
@@ -1,127 +0,0 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyHTTP "github.com/aws/smithy-go/transport/http"
|
||||
"github.com/versity/versitygw/aws/internal/sdk"
|
||||
)
|
||||
|
||||
// HTTPPresigner is an interface to a SigV4 signer that can sign create a
|
||||
// presigned URL for a HTTP requests.
|
||||
type HTTPPresigner interface {
|
||||
PresignHTTP(
|
||||
ctx context.Context, credentials aws.Credentials, r *http.Request,
|
||||
payloadHash string, service string, region string, signingTime time.Time,
|
||||
optFns ...func(*SignerOptions),
|
||||
) (url string, signedHeader http.Header, err error)
|
||||
}
|
||||
|
||||
// PresignedHTTPRequest provides the URL and signed headers that are included
|
||||
// in the presigned URL.
|
||||
type PresignedHTTPRequest struct {
|
||||
URL string
|
||||
Method string
|
||||
SignedHeader http.Header
|
||||
}
|
||||
|
||||
// PresignHTTPRequestMiddlewareOptions is the options for the PresignHTTPRequestMiddleware middleware.
|
||||
type PresignHTTPRequestMiddlewareOptions struct {
|
||||
CredentialsProvider aws.CredentialsProvider
|
||||
Presigner HTTPPresigner
|
||||
LogSigning bool
|
||||
}
|
||||
|
||||
// PresignHTTPRequestMiddleware provides the Finalize middleware for creating a
|
||||
// presigned URL for an HTTP request.
|
||||
//
|
||||
// Will short circuit the middleware stack and not forward onto the next
|
||||
// Finalize handler.
|
||||
type PresignHTTPRequestMiddleware struct {
|
||||
credentialsProvider aws.CredentialsProvider
|
||||
presigner HTTPPresigner
|
||||
logSigning bool
|
||||
}
|
||||
|
||||
// NewPresignHTTPRequestMiddleware returns a new PresignHTTPRequestMiddleware
|
||||
// initialized with the presigner.
|
||||
func NewPresignHTTPRequestMiddleware(options PresignHTTPRequestMiddlewareOptions) *PresignHTTPRequestMiddleware {
|
||||
return &PresignHTTPRequestMiddleware{
|
||||
credentialsProvider: options.CredentialsProvider,
|
||||
presigner: options.Presigner,
|
||||
logSigning: options.LogSigning,
|
||||
}
|
||||
}
|
||||
|
||||
// ID provides the middleware ID.
|
||||
func (*PresignHTTPRequestMiddleware) ID() string { return "PresignHTTPRequest" }
|
||||
|
||||
// HandleFinalize will take the provided input and create a presigned url for
|
||||
// the http request using the SigV4 presign authentication scheme.
|
||||
//
|
||||
// Since the signed request is not a valid HTTP request
|
||||
func (s *PresignHTTPRequestMiddleware) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
req, ok := in.Request.(*smithyHTTP.Request)
|
||||
if !ok {
|
||||
return out, metadata, &SigningError{
|
||||
Err: fmt.Errorf("unexpected request middleware type %T", in.Request),
|
||||
}
|
||||
}
|
||||
|
||||
httpReq := req.Build(ctx)
|
||||
if !haveCredentialProvider(s.credentialsProvider) {
|
||||
out.Result = &PresignedHTTPRequest{
|
||||
URL: httpReq.URL.String(),
|
||||
Method: httpReq.Method,
|
||||
SignedHeader: http.Header{},
|
||||
}
|
||||
|
||||
return out, metadata, nil
|
||||
}
|
||||
|
||||
signingName := awsmiddleware.GetSigningName(ctx)
|
||||
signingRegion := awsmiddleware.GetSigningRegion(ctx)
|
||||
payloadHash := GetPayloadHash(ctx)
|
||||
if len(payloadHash) == 0 {
|
||||
return out, metadata, &SigningError{
|
||||
Err: fmt.Errorf("computed payload hash missing from context"),
|
||||
}
|
||||
}
|
||||
|
||||
credentials, err := s.credentialsProvider.Retrieve(ctx)
|
||||
if err != nil {
|
||||
return out, metadata, &SigningError{
|
||||
Err: fmt.Errorf("failed to retrieve credentials: %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
u, h, err := s.presigner.PresignHTTP(ctx, credentials,
|
||||
httpReq, payloadHash, signingName, signingRegion, sdk.NowTime(),
|
||||
func(o *SignerOptions) {
|
||||
o.Logger = middleware.GetLogger(ctx)
|
||||
o.LogSigning = s.logSigning
|
||||
})
|
||||
if err != nil {
|
||||
return out, metadata, &SigningError{
|
||||
Err: fmt.Errorf("failed to sign http request, %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
out.Result = &PresignedHTTPRequest{
|
||||
URL: u,
|
||||
Method: httpReq.Method,
|
||||
SignedHeader: h,
|
||||
}
|
||||
|
||||
return out, metadata, nil
|
||||
}
|
||||
@@ -1,224 +0,0 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/versity/versitygw/aws/internal/awstesting/unit"
|
||||
)
|
||||
|
||||
type httpPresignerFunc func(
|
||||
ctx context.Context, credentials aws.Credentials, r *http.Request,
|
||||
payloadHash string, service string, region string, signingTime time.Time,
|
||||
optFns ...func(*SignerOptions),
|
||||
) (url string, signedHeader http.Header, err error)
|
||||
|
||||
func (f httpPresignerFunc) PresignHTTP(
|
||||
ctx context.Context, credentials aws.Credentials, r *http.Request,
|
||||
payloadHash string, service string, region string, signingTime time.Time,
|
||||
optFns ...func(*SignerOptions),
|
||||
) (
|
||||
url string, signedHeader http.Header, err error,
|
||||
) {
|
||||
return f(ctx, credentials, r, payloadHash, service, region, signingTime, optFns...)
|
||||
}
|
||||
|
||||
func TestPresignHTTPRequestMiddleware(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
Request *http.Request
|
||||
Creds aws.CredentialsProvider
|
||||
PayloadHash string
|
||||
LogSigning bool
|
||||
ExpectResult *PresignedHTTPRequest
|
||||
ExpectErr string
|
||||
}{
|
||||
"success": {
|
||||
Request: &http.Request{
|
||||
URL: func() *url.URL {
|
||||
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||||
return u
|
||||
}(),
|
||||
Header: http.Header{},
|
||||
},
|
||||
Creds: unit.StubCredentialsProvider{},
|
||||
PayloadHash: "0123456789abcdef",
|
||||
ExpectResult: &PresignedHTTPRequest{
|
||||
URL: "https://example.aws/path?query=foo",
|
||||
SignedHeader: http.Header{},
|
||||
},
|
||||
},
|
||||
"error": {
|
||||
Request: func() *http.Request {
|
||||
return &http.Request{}
|
||||
}(),
|
||||
Creds: unit.StubCredentialsProvider{},
|
||||
PayloadHash: "",
|
||||
ExpectErr: "failed to sign request",
|
||||
},
|
||||
"anonymous creds": {
|
||||
Request: &http.Request{
|
||||
URL: func() *url.URL {
|
||||
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||||
return u
|
||||
}(),
|
||||
Header: http.Header{},
|
||||
},
|
||||
Creds: unit.StubCredentialsProvider{},
|
||||
PayloadHash: "",
|
||||
ExpectErr: "failed to sign request",
|
||||
ExpectResult: &PresignedHTTPRequest{
|
||||
URL: "https://example.aws/path?query=foo",
|
||||
SignedHeader: http.Header{},
|
||||
},
|
||||
},
|
||||
"nil creds": {
|
||||
Request: &http.Request{
|
||||
URL: func() *url.URL {
|
||||
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||||
return u
|
||||
}(),
|
||||
Header: http.Header{},
|
||||
},
|
||||
Creds: nil,
|
||||
ExpectResult: &PresignedHTTPRequest{
|
||||
URL: "https://example.aws/path?query=foo",
|
||||
SignedHeader: http.Header{},
|
||||
},
|
||||
},
|
||||
"with log signing": {
|
||||
Request: &http.Request{
|
||||
URL: func() *url.URL {
|
||||
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||||
return u
|
||||
}(),
|
||||
Header: http.Header{},
|
||||
},
|
||||
Creds: unit.StubCredentialsProvider{},
|
||||
PayloadHash: "0123456789abcdef",
|
||||
ExpectResult: &PresignedHTTPRequest{
|
||||
URL: "https://example.aws/path?query=foo",
|
||||
SignedHeader: http.Header{},
|
||||
},
|
||||
|
||||
LogSigning: true,
|
||||
},
|
||||
}
|
||||
|
||||
const (
|
||||
signingName = "serviceId"
|
||||
signingRegion = "regionName"
|
||||
)
|
||||
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
m := &PresignHTTPRequestMiddleware{
|
||||
credentialsProvider: c.Creds,
|
||||
|
||||
presigner: httpPresignerFunc(func(
|
||||
ctx context.Context, credentials aws.Credentials, r *http.Request,
|
||||
payloadHash string, service string, region string, signingTime time.Time,
|
||||
optFns ...func(*SignerOptions),
|
||||
) (url string, signedHeader http.Header, err error) {
|
||||
var options SignerOptions
|
||||
for _, fn := range optFns {
|
||||
fn(&options)
|
||||
}
|
||||
if options.Logger == nil {
|
||||
t.Errorf("expect logger, got none")
|
||||
}
|
||||
if options.LogSigning {
|
||||
options.Logger.Logf(logging.Debug, t.Name())
|
||||
}
|
||||
|
||||
if !haveCredentialProvider(c.Creds) {
|
||||
t.Errorf("expect presigner not to be called for not credentials provider")
|
||||
}
|
||||
|
||||
expectCreds, _ := unit.StubCredentialsProvider{}.Retrieve(context.Background())
|
||||
if e, a := expectCreds, credentials; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := c.PayloadHash, payloadHash; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := signingName, service; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := signingRegion, region; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
|
||||
return c.ExpectResult.URL, c.ExpectResult.SignedHeader, nil
|
||||
}),
|
||||
logSigning: c.LogSigning,
|
||||
}
|
||||
|
||||
next := middleware.FinalizeHandlerFunc(
|
||||
func(ctx context.Context, in middleware.FinalizeInput) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
t.Errorf("expect next handler not to be called")
|
||||
return out, metadata, err
|
||||
})
|
||||
|
||||
ctx := awsmiddleware.SetSigningRegion(
|
||||
awsmiddleware.SetSigningName(context.Background(), signingName),
|
||||
signingRegion)
|
||||
|
||||
var loggerBuf bytes.Buffer
|
||||
logger := logging.NewStandardLogger(&loggerBuf)
|
||||
ctx = middleware.SetLogger(ctx, logger)
|
||||
|
||||
if len(c.PayloadHash) != 0 {
|
||||
ctx = SetPayloadHash(ctx, c.PayloadHash)
|
||||
}
|
||||
|
||||
result, _, err := m.HandleFinalize(ctx, middleware.FinalizeInput{
|
||||
Request: &smithyhttp.Request{
|
||||
Request: c.Request,
|
||||
},
|
||||
}, next)
|
||||
if len(c.ExpectErr) != 0 {
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, got none")
|
||||
}
|
||||
if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
|
||||
t.Fatalf("expect error to contain %v, got %v", e, a)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(c.ExpectResult, result.Result); len(diff) != 0 {
|
||||
t.Errorf("expect result match\n%v", diff)
|
||||
}
|
||||
|
||||
if c.LogSigning {
|
||||
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
|
||||
t.Errorf("expect %v logged in %v", e, a)
|
||||
}
|
||||
} else {
|
||||
if loggerBuf.Len() != 0 {
|
||||
t.Errorf("expect no log, got %v", loggerBuf.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
_ middleware.FinalizeMiddleware = &PresignHTTPRequestMiddleware{}
|
||||
)
|
||||
@@ -1,87 +0,0 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4Internal "github.com/versity/versitygw/aws/signer/internal/v4"
|
||||
)
|
||||
|
||||
// EventStreamSigner is an AWS EventStream protocol signer.
|
||||
type EventStreamSigner interface {
|
||||
GetSignature(ctx context.Context, headers, payload []byte, signingTime time.Time, optFns ...func(*StreamSignerOptions)) ([]byte, error)
|
||||
}
|
||||
|
||||
// StreamSignerOptions is the configuration options for StreamSigner.
|
||||
type StreamSignerOptions struct{}
|
||||
|
||||
// StreamSigner implements Signature Version 4 (SigV4) signing of event stream encoded payloads.
|
||||
type StreamSigner struct {
|
||||
options StreamSignerOptions
|
||||
|
||||
credentials aws.Credentials
|
||||
service string
|
||||
region string
|
||||
|
||||
prevSignature []byte
|
||||
|
||||
signingKeyDeriver *v4Internal.SigningKeyDeriver
|
||||
}
|
||||
|
||||
// NewStreamSigner returns a new AWS EventStream protocol signer.
|
||||
func NewStreamSigner(credentials aws.Credentials, service, region string, seedSignature []byte, optFns ...func(*StreamSignerOptions)) *StreamSigner {
|
||||
o := StreamSignerOptions{}
|
||||
|
||||
for _, fn := range optFns {
|
||||
fn(&o)
|
||||
}
|
||||
|
||||
return &StreamSigner{
|
||||
options: o,
|
||||
credentials: credentials,
|
||||
service: service,
|
||||
region: region,
|
||||
signingKeyDeriver: v4Internal.NewSigningKeyDeriver(),
|
||||
prevSignature: seedSignature,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSignature signs the provided header and payload bytes.
|
||||
func (s *StreamSigner) GetSignature(ctx context.Context, headers, payload []byte, signingTime time.Time, optFns ...func(*StreamSignerOptions)) ([]byte, error) {
|
||||
options := s.options
|
||||
|
||||
for _, fn := range optFns {
|
||||
fn(&options)
|
||||
}
|
||||
|
||||
prevSignature := s.prevSignature
|
||||
|
||||
st := v4Internal.NewSigningTime(signingTime)
|
||||
|
||||
sigKey := s.signingKeyDeriver.DeriveKey(s.credentials, s.service, s.region, st)
|
||||
|
||||
scope := v4Internal.BuildCredentialScope(st, s.region, s.service)
|
||||
|
||||
stringToSign := s.buildEventStreamStringToSign(headers, payload, prevSignature, scope, &st)
|
||||
|
||||
signature := v4Internal.HMACSHA256(sigKey, []byte(stringToSign))
|
||||
s.prevSignature = signature
|
||||
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
func (s *StreamSigner) buildEventStreamStringToSign(headers, payload, previousSignature []byte, credentialScope string, signingTime *v4Internal.SigningTime) string {
|
||||
hash := sha256.New()
|
||||
return strings.Join([]string{
|
||||
"AWS4-HMAC-SHA256-PAYLOAD",
|
||||
signingTime.TimeFormat(),
|
||||
credentialScope,
|
||||
hex.EncodeToString(previousSignature),
|
||||
hex.EncodeToString(makeHash(hash, headers)),
|
||||
hex.EncodeToString(makeHash(hash, payload)),
|
||||
}, "\n")
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -208,7 +207,7 @@ func TestBuildCanonicalRequest(t *testing.T) {
|
||||
|
||||
func TestSigner_SignHTTP_NoReplaceRequestBody(t *testing.T) {
|
||||
req, bodyHash := buildRequest("dynamodb", "us-east-1", "{}")
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
||||
req.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
||||
|
||||
s := NewSigner()
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
checks = []
|
||||
Reference in New Issue
Block a user