mirror of
https://github.com/versity/versitygw.git
synced 2026-01-04 19:13:57 +00:00
fix: hoist aws-sdk-go-v2 signer into veristygw for custom modifications
This commit is contained in:
202
aws/LICENSE.txt
Normal file
202
aws/LICENSE.txt
Normal file
@@ -0,0 +1,202 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
3
aws/NOTICE.txt
Normal file
3
aws/NOTICE.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
AWS SDK for Go
|
||||
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Copyright 2014-2015 Stripe, Inc.
|
||||
11
aws/README.md
Normal file
11
aws/README.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# AWS SDK Go v2
|
||||
|
||||
This directory contains code from the [AWS SDK Go v2](https://github.com/aws/aws-sdk-go-v2) repository, modified in accordance with the Apache 2.0 License.
|
||||
|
||||
## Description
|
||||
|
||||
The AWS SDK Go v2 is a collection of libraries and tools that enable developers to build applications that integrate with various AWS services. This directory and below contains modified code from the original repository, tailored to suit versitygw specific requirements.
|
||||
|
||||
## License
|
||||
|
||||
The code in this directory is licensed under the Apache 2.0 License. Please refer to the [LICENSE](./LICENSE) file for more information.
|
||||
45
aws/internal/auth/auth.go
Normal file
45
aws/internal/auth/auth.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/aws/smithy-go/auth"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
)
|
||||
|
||||
// HTTPAuthScheme is the SDK's internal implementation of smithyhttp.AuthScheme
|
||||
// for pre-existing implementations where the signer was added to client
|
||||
// config. SDK clients will key off of this type and ensure per-operation
|
||||
// updates to those signers persist on the scheme itself.
|
||||
type HTTPAuthScheme struct {
|
||||
schemeID string
|
||||
signer smithyhttp.Signer
|
||||
}
|
||||
|
||||
var _ smithyhttp.AuthScheme = (*HTTPAuthScheme)(nil)
|
||||
|
||||
// NewHTTPAuthScheme returns an auth scheme instance with the given config.
|
||||
func NewHTTPAuthScheme(schemeID string, signer smithyhttp.Signer) *HTTPAuthScheme {
|
||||
return &HTTPAuthScheme{
|
||||
schemeID: schemeID,
|
||||
signer: signer,
|
||||
}
|
||||
}
|
||||
|
||||
// SchemeID identifies the auth scheme.
|
||||
func (s *HTTPAuthScheme) SchemeID() string {
|
||||
return s.schemeID
|
||||
}
|
||||
|
||||
// IdentityResolver gets the identity resolver for the auth scheme.
|
||||
func (s *HTTPAuthScheme) IdentityResolver(o auth.IdentityResolverOptions) auth.IdentityResolver {
|
||||
return o.GetIdentityResolver(s.schemeID)
|
||||
}
|
||||
|
||||
// Signer gets the signer for the auth scheme.
|
||||
func (s *HTTPAuthScheme) Signer() smithyhttp.Signer {
|
||||
return s.signer
|
||||
}
|
||||
|
||||
// WithSigner returns a new instance of the auth scheme with the updated signer.
|
||||
func (s *HTTPAuthScheme) WithSigner(signer smithyhttp.Signer) *HTTPAuthScheme {
|
||||
return NewHTTPAuthScheme(s.schemeID, signer)
|
||||
}
|
||||
191
aws/internal/auth/scheme.go
Normal file
191
aws/internal/auth/scheme.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
smithy "github.com/aws/smithy-go"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
)
|
||||
|
||||
// SigV4 is a constant representing
|
||||
// Authentication Scheme Signature Version 4
|
||||
const SigV4 = "sigv4"
|
||||
|
||||
// SigV4A is a constant representing
|
||||
// Authentication Scheme Signature Version 4A
|
||||
const SigV4A = "sigv4a"
|
||||
|
||||
// SigV4S3Express identifies the S3 S3Express auth scheme.
|
||||
const SigV4S3Express = "sigv4-s3express"
|
||||
|
||||
// None is a constant representing the
|
||||
// None Authentication Scheme
|
||||
const None = "none"
|
||||
|
||||
// SupportedSchemes is a data structure
|
||||
// that indicates the list of supported AWS
|
||||
// authentication schemes
|
||||
var SupportedSchemes = map[string]bool{
|
||||
SigV4: true,
|
||||
SigV4A: true,
|
||||
SigV4S3Express: true,
|
||||
None: true,
|
||||
}
|
||||
|
||||
// AuthenticationScheme is a representation of
|
||||
// AWS authentication schemes
|
||||
type AuthenticationScheme interface {
|
||||
isAuthenticationScheme()
|
||||
}
|
||||
|
||||
// AuthenticationSchemeV4 is a AWS SigV4 representation
|
||||
type AuthenticationSchemeV4 struct {
|
||||
Name string
|
||||
SigningName *string
|
||||
SigningRegion *string
|
||||
DisableDoubleEncoding *bool
|
||||
}
|
||||
|
||||
func (a *AuthenticationSchemeV4) isAuthenticationScheme() {}
|
||||
|
||||
// AuthenticationSchemeV4A is a AWS SigV4A representation
|
||||
type AuthenticationSchemeV4A struct {
|
||||
Name string
|
||||
SigningName *string
|
||||
SigningRegionSet []string
|
||||
DisableDoubleEncoding *bool
|
||||
}
|
||||
|
||||
func (a *AuthenticationSchemeV4A) isAuthenticationScheme() {}
|
||||
|
||||
// AuthenticationSchemeNone is a representation for the none auth scheme
|
||||
type AuthenticationSchemeNone struct{}
|
||||
|
||||
func (a *AuthenticationSchemeNone) isAuthenticationScheme() {}
|
||||
|
||||
// NoAuthenticationSchemesFoundError is used in signaling
|
||||
// that no authentication schemes have been specified.
|
||||
type NoAuthenticationSchemesFoundError struct{}
|
||||
|
||||
func (e *NoAuthenticationSchemesFoundError) Error() string {
|
||||
return fmt.Sprint("No authentication schemes specified.")
|
||||
}
|
||||
|
||||
// UnSupportedAuthenticationSchemeSpecifiedError is used in
|
||||
// signaling that only unsupported authentication schemes
|
||||
// were specified.
|
||||
type UnSupportedAuthenticationSchemeSpecifiedError struct {
|
||||
UnsupportedSchemes []string
|
||||
}
|
||||
|
||||
func (e *UnSupportedAuthenticationSchemeSpecifiedError) Error() string {
|
||||
return fmt.Sprint("Unsupported authentication scheme specified.")
|
||||
}
|
||||
|
||||
// GetAuthenticationSchemes extracts the relevant authentication scheme data
|
||||
// into a custom strongly typed Go data structure.
|
||||
func GetAuthenticationSchemes(p *smithy.Properties) ([]AuthenticationScheme, error) {
|
||||
var result []AuthenticationScheme
|
||||
if !p.Has("authSchemes") {
|
||||
return nil, &NoAuthenticationSchemesFoundError{}
|
||||
}
|
||||
|
||||
authSchemes, _ := p.Get("authSchemes").([]interface{})
|
||||
|
||||
var unsupportedSchemes []string
|
||||
for _, scheme := range authSchemes {
|
||||
authScheme, _ := scheme.(map[string]interface{})
|
||||
|
||||
version := authScheme["name"].(string)
|
||||
switch version {
|
||||
case SigV4, SigV4S3Express:
|
||||
v4Scheme := AuthenticationSchemeV4{
|
||||
Name: version,
|
||||
SigningName: getSigningName(authScheme),
|
||||
SigningRegion: getSigningRegion(authScheme),
|
||||
DisableDoubleEncoding: getDisableDoubleEncoding(authScheme),
|
||||
}
|
||||
result = append(result, AuthenticationScheme(&v4Scheme))
|
||||
case SigV4A:
|
||||
v4aScheme := AuthenticationSchemeV4A{
|
||||
Name: SigV4A,
|
||||
SigningName: getSigningName(authScheme),
|
||||
SigningRegionSet: getSigningRegionSet(authScheme),
|
||||
DisableDoubleEncoding: getDisableDoubleEncoding(authScheme),
|
||||
}
|
||||
result = append(result, AuthenticationScheme(&v4aScheme))
|
||||
case None:
|
||||
noneScheme := AuthenticationSchemeNone{}
|
||||
result = append(result, AuthenticationScheme(&noneScheme))
|
||||
default:
|
||||
unsupportedSchemes = append(unsupportedSchemes, authScheme["name"].(string))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, &UnSupportedAuthenticationSchemeSpecifiedError{
|
||||
UnsupportedSchemes: unsupportedSchemes,
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type disableDoubleEncoding struct{}
|
||||
|
||||
// SetDisableDoubleEncoding sets or modifies the disable double encoding option
|
||||
// on the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func SetDisableDoubleEncoding(ctx context.Context, value bool) context.Context {
|
||||
return middleware.WithStackValue(ctx, disableDoubleEncoding{}, value)
|
||||
}
|
||||
|
||||
// GetDisableDoubleEncoding retrieves the disable double encoding option
|
||||
// from the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func GetDisableDoubleEncoding(ctx context.Context) (value bool, ok bool) {
|
||||
value, ok = middleware.GetStackValue(ctx, disableDoubleEncoding{}).(bool)
|
||||
return value, ok
|
||||
}
|
||||
|
||||
func getSigningName(authScheme map[string]interface{}) *string {
|
||||
signingName, ok := authScheme["signingName"].(string)
|
||||
if !ok || signingName == "" {
|
||||
return nil
|
||||
}
|
||||
return &signingName
|
||||
}
|
||||
|
||||
func getSigningRegionSet(authScheme map[string]interface{}) []string {
|
||||
untypedSigningRegionSet, ok := authScheme["signingRegionSet"].([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
signingRegionSet := []string{}
|
||||
for _, item := range untypedSigningRegionSet {
|
||||
signingRegionSet = append(signingRegionSet, item.(string))
|
||||
}
|
||||
return signingRegionSet
|
||||
}
|
||||
|
||||
func getSigningRegion(authScheme map[string]interface{}) *string {
|
||||
signingRegion, ok := authScheme["signingRegion"].(string)
|
||||
if !ok || signingRegion == "" {
|
||||
return nil
|
||||
}
|
||||
return &signingRegion
|
||||
}
|
||||
|
||||
func getDisableDoubleEncoding(authScheme map[string]interface{}) *bool {
|
||||
disableDoubleEncoding, ok := authScheme["disableDoubleEncoding"].(bool)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &disableDoubleEncoding
|
||||
}
|
||||
101
aws/internal/auth/scheme_test.go
Normal file
101
aws/internal/auth/scheme_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
smithy "github.com/aws/smithy-go"
|
||||
)
|
||||
|
||||
func TestV4(t *testing.T) {
|
||||
|
||||
propsV4 := smithy.Properties{}
|
||||
|
||||
propsV4.Set("authSchemes", interface{}([]interface{}{
|
||||
map[string]interface{}{
|
||||
"disableDoubleEncoding": true,
|
||||
"name": "sigv4",
|
||||
"signingName": "s3",
|
||||
"signingRegion": "us-west-2",
|
||||
},
|
||||
}))
|
||||
|
||||
result, err := GetAuthenticationSchemes(&propsV4)
|
||||
if err != nil {
|
||||
t.Fatalf("Did not expect error, got %v", err)
|
||||
}
|
||||
|
||||
_, ok := result[0].(AuthenticationScheme)
|
||||
if !ok {
|
||||
t.Fatalf("Did not get expected AuthenticationScheme. %v", result[0])
|
||||
}
|
||||
|
||||
v4Scheme, ok := result[0].(*AuthenticationSchemeV4)
|
||||
if !ok {
|
||||
t.Fatalf("Did not get expected AuthenticationSchemeV4. %v", result[0])
|
||||
}
|
||||
|
||||
if v4Scheme.Name != "sigv4" {
|
||||
t.Fatalf("Did not get expected AuthenticationSchemeV4 signer version name")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestV4A(t *testing.T) {
|
||||
|
||||
propsV4A := smithy.Properties{}
|
||||
|
||||
propsV4A.Set("authSchemes", []interface{}{
|
||||
map[string]interface{}{
|
||||
"disableDoubleEncoding": true,
|
||||
"name": "sigv4a",
|
||||
"signingName": "s3",
|
||||
"signingRegionSet": []string{"*"},
|
||||
},
|
||||
})
|
||||
|
||||
result, err := GetAuthenticationSchemes(&propsV4A)
|
||||
if err != nil {
|
||||
t.Fatalf("Did not expect error, got %v", err)
|
||||
}
|
||||
|
||||
_, ok := result[0].(AuthenticationScheme)
|
||||
if !ok {
|
||||
t.Fatalf("Did not get expected AuthenticationScheme. %v", result[0])
|
||||
}
|
||||
|
||||
v4AScheme, ok := result[0].(*AuthenticationSchemeV4A)
|
||||
if !ok {
|
||||
t.Fatalf("Did not get expected AuthenticationSchemeV4A. %v", result[0])
|
||||
}
|
||||
|
||||
if v4AScheme.Name != "sigv4a" {
|
||||
t.Fatalf("Did not get expected AuthenticationSchemeV4A signer version name")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestV4S3Express(t *testing.T) {
|
||||
props := smithy.Properties{}
|
||||
props.Set("authSchemes", []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": SigV4S3Express,
|
||||
"signingName": "s3",
|
||||
"signingRegion": "us-east-1",
|
||||
"disableDoubleEncoding": true,
|
||||
},
|
||||
})
|
||||
|
||||
result, err := GetAuthenticationSchemes(&props)
|
||||
if err != nil {
|
||||
t.Fatalf("Did not expect error, got %v", err)
|
||||
}
|
||||
|
||||
scheme, ok := result[0].(*AuthenticationSchemeV4)
|
||||
if !ok {
|
||||
t.Fatalf("Did not get expected AuthenticationSchemeV4. %v", result[0])
|
||||
}
|
||||
|
||||
if scheme.Name != SigV4S3Express {
|
||||
t.Fatalf("expected %s, got %s", SigV4S3Express, scheme.Name)
|
||||
}
|
||||
}
|
||||
43
aws/internal/auth/smithy/bearer_token_adapter.go
Normal file
43
aws/internal/auth/smithy/bearer_token_adapter.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package smithy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/smithy-go"
|
||||
"github.com/aws/smithy-go/auth"
|
||||
"github.com/aws/smithy-go/auth/bearer"
|
||||
)
|
||||
|
||||
// BearerTokenAdapter adapts smithy bearer.Token to smithy auth.Identity.
|
||||
type BearerTokenAdapter struct {
|
||||
Token bearer.Token
|
||||
}
|
||||
|
||||
var _ auth.Identity = (*BearerTokenAdapter)(nil)
|
||||
|
||||
// Expiration returns the time of expiration for the token.
|
||||
func (v *BearerTokenAdapter) Expiration() time.Time {
|
||||
return v.Token.Expires
|
||||
}
|
||||
|
||||
// BearerTokenProviderAdapter adapts smithy bearer.TokenProvider to smithy
|
||||
// auth.IdentityResolver.
|
||||
type BearerTokenProviderAdapter struct {
|
||||
Provider bearer.TokenProvider
|
||||
}
|
||||
|
||||
var _ (auth.IdentityResolver) = (*BearerTokenProviderAdapter)(nil)
|
||||
|
||||
// GetIdentity retrieves a bearer token using the underlying provider.
|
||||
func (v *BearerTokenProviderAdapter) GetIdentity(ctx context.Context, _ smithy.Properties) (
|
||||
auth.Identity, error,
|
||||
) {
|
||||
token, err := v.Provider.RetrieveBearerToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get token: %w", err)
|
||||
}
|
||||
|
||||
return &BearerTokenAdapter{Token: token}, nil
|
||||
}
|
||||
35
aws/internal/auth/smithy/bearer_token_signer_adapter.go
Normal file
35
aws/internal/auth/smithy/bearer_token_signer_adapter.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package smithy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/smithy-go"
|
||||
"github.com/aws/smithy-go/auth"
|
||||
"github.com/aws/smithy-go/auth/bearer"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
)
|
||||
|
||||
// BearerTokenSignerAdapter adapts smithy bearer.Signer to smithy http
|
||||
// auth.Signer.
|
||||
type BearerTokenSignerAdapter struct {
|
||||
Signer bearer.Signer
|
||||
}
|
||||
|
||||
var _ (smithyhttp.Signer) = (*BearerTokenSignerAdapter)(nil)
|
||||
|
||||
// SignRequest signs the request with the provided bearer token.
|
||||
func (v *BearerTokenSignerAdapter) SignRequest(ctx context.Context, r *smithyhttp.Request, identity auth.Identity, _ smithy.Properties) error {
|
||||
ca, ok := identity.(*BearerTokenAdapter)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected identity type: %T", identity)
|
||||
}
|
||||
|
||||
signed, err := v.Signer.SignWithBearerToken(ctx, ca.Token, r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sign request: %w", err)
|
||||
}
|
||||
|
||||
*r = *signed.(*smithyhttp.Request)
|
||||
return nil
|
||||
}
|
||||
46
aws/internal/auth/smithy/credentials_adapter.go
Normal file
46
aws/internal/auth/smithy/credentials_adapter.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package smithy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/smithy-go"
|
||||
"github.com/aws/smithy-go/auth"
|
||||
)
|
||||
|
||||
// CredentialsAdapter adapts aws.Credentials to auth.Identity.
|
||||
type CredentialsAdapter struct {
|
||||
Credentials aws.Credentials
|
||||
}
|
||||
|
||||
var _ auth.Identity = (*CredentialsAdapter)(nil)
|
||||
|
||||
// Expiration returns the time of expiration for the credentials.
|
||||
func (v *CredentialsAdapter) Expiration() time.Time {
|
||||
return v.Credentials.Expires
|
||||
}
|
||||
|
||||
// CredentialsProviderAdapter adapts aws.CredentialsProvider to auth.IdentityResolver.
|
||||
type CredentialsProviderAdapter struct {
|
||||
Provider aws.CredentialsProvider
|
||||
}
|
||||
|
||||
var _ (auth.IdentityResolver) = (*CredentialsProviderAdapter)(nil)
|
||||
|
||||
// GetIdentity retrieves AWS credentials using the underlying provider.
|
||||
func (v *CredentialsProviderAdapter) GetIdentity(ctx context.Context, _ smithy.Properties) (
|
||||
auth.Identity, error,
|
||||
) {
|
||||
if v.Provider == nil {
|
||||
return &CredentialsAdapter{Credentials: aws.Credentials{}}, nil
|
||||
}
|
||||
|
||||
creds, err := v.Provider.Retrieve(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get credentials: %w", err)
|
||||
}
|
||||
|
||||
return &CredentialsAdapter{Credentials: creds}, nil
|
||||
}
|
||||
2
aws/internal/auth/smithy/smithy.go
Normal file
2
aws/internal/auth/smithy/smithy.go
Normal file
@@ -0,0 +1,2 @@
|
||||
// Package smithy adapts concrete AWS auth and signing types to the generic smithy versions.
|
||||
package smithy
|
||||
53
aws/internal/auth/smithy/v4signer_adapter.go
Normal file
53
aws/internal/auth/smithy/v4signer_adapter.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package smithy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
"github.com/aws/smithy-go"
|
||||
"github.com/aws/smithy-go/auth"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
"github.com/versity/versitygw/aws/internal/sdk"
|
||||
)
|
||||
|
||||
// V4SignerAdapter adapts v4.HTTPSigner to smithy http.Signer.
|
||||
type V4SignerAdapter struct {
|
||||
Signer v4.HTTPSigner
|
||||
Logger logging.Logger
|
||||
LogSigning bool
|
||||
}
|
||||
|
||||
var _ (smithyhttp.Signer) = (*V4SignerAdapter)(nil)
|
||||
|
||||
// SignRequest signs the request with the provided identity.
|
||||
func (v *V4SignerAdapter) SignRequest(ctx context.Context, r *smithyhttp.Request, identity auth.Identity, props smithy.Properties) error {
|
||||
ca, ok := identity.(*CredentialsAdapter)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected identity type: %T", identity)
|
||||
}
|
||||
|
||||
name, ok := smithyhttp.GetSigV4SigningName(&props)
|
||||
if !ok {
|
||||
return fmt.Errorf("sigv4 signing name is required")
|
||||
}
|
||||
|
||||
region, ok := smithyhttp.GetSigV4SigningRegion(&props)
|
||||
if !ok {
|
||||
return fmt.Errorf("sigv4 signing region is required")
|
||||
}
|
||||
|
||||
hash := v4.GetPayloadHash(ctx)
|
||||
err := v.Signer.SignHTTP(ctx, ca.Credentials, r.Request, hash, name, region, sdk.NowTime(), func(o *v4.SignerOptions) {
|
||||
o.DisableURIPathEscaping, _ = smithyhttp.GetDisableDoubleEncoding(&props)
|
||||
|
||||
o.Logger = v.Logger
|
||||
o.LogSigning = v.LogSigning
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sign http: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
222
aws/internal/awstesting/assert.go
Normal file
222
aws/internal/awstesting/assert.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package awstesting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Match is a testing helper to test for testing error by comparing expected
|
||||
// with a regular expression.
|
||||
func Match(t *testing.T, regex, expected string) {
|
||||
t.Helper()
|
||||
|
||||
if !regexp.MustCompile(regex).Match([]byte(expected)) {
|
||||
t.Errorf("%q\n\tdoes not match /%s/", expected, regex)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertURL verifies the expected URL is matches the actual.
|
||||
func AssertURL(t *testing.T, expect, actual string, msgAndArgs ...interface{}) bool {
|
||||
t.Helper()
|
||||
|
||||
expectURL, err := url.Parse(expect)
|
||||
if err != nil {
|
||||
t.Errorf(errMsg("unable to parse expected URL", err, msgAndArgs))
|
||||
return false
|
||||
}
|
||||
actualURL, err := url.Parse(actual)
|
||||
if err != nil {
|
||||
t.Errorf(errMsg("unable to parse actual URL", err, msgAndArgs))
|
||||
return false
|
||||
}
|
||||
|
||||
equal(t, expectURL.Host, actualURL.Host, msgAndArgs...)
|
||||
equal(t, expectURL.Scheme, actualURL.Scheme, msgAndArgs...)
|
||||
equal(t, expectURL.Path, actualURL.Path, msgAndArgs...)
|
||||
|
||||
return AssertQuery(t, expectURL.Query().Encode(), actualURL.Query().Encode(), msgAndArgs...)
|
||||
}
|
||||
|
||||
var queryMapKey = regexp.MustCompile(`(.*?)\.[0-9]+\.key`)
|
||||
|
||||
// AssertQuery verifies the expect HTTP query string matches the actual.
|
||||
func AssertQuery(t *testing.T, expect, actual string, msgAndArgs ...interface{}) bool {
|
||||
t.Helper()
|
||||
|
||||
expectQ, err := url.ParseQuery(expect)
|
||||
if err != nil {
|
||||
t.Errorf(errMsg("unable to parse expected Query", err, msgAndArgs))
|
||||
return false
|
||||
}
|
||||
actualQ, err := url.ParseQuery(actual)
|
||||
if err != nil {
|
||||
t.Errorf(errMsg("unable to parse actual Query", err, msgAndArgs))
|
||||
return false
|
||||
}
|
||||
|
||||
// Make sure the keys are the same
|
||||
if !equal(t, queryValueKeys(expectQ), queryValueKeys(actualQ), msgAndArgs...) {
|
||||
return false
|
||||
}
|
||||
|
||||
keys := map[string][]string{}
|
||||
for key, v := range expectQ {
|
||||
if queryMapKey.Match([]byte(key)) {
|
||||
submatch := queryMapKey.FindStringSubmatch(key)
|
||||
keys[submatch[1]] = append(keys[submatch[1]], v...)
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range keys {
|
||||
// clear all keys that have prefix
|
||||
for key := range expectQ {
|
||||
if strings.HasPrefix(key, k) {
|
||||
delete(expectQ, key)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(v)
|
||||
for i, value := range v {
|
||||
expectQ[fmt.Sprintf("%s.%d.key", k, i+1)] = []string{value}
|
||||
}
|
||||
}
|
||||
|
||||
for k, expectQVals := range expectQ {
|
||||
sort.Strings(expectQVals)
|
||||
actualQVals := actualQ[k]
|
||||
sort.Strings(actualQVals)
|
||||
if !equal(t, expectQVals, actualQVals, msgAndArgs...) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AssertJSON verifies that the expect json string matches the actual.
|
||||
func AssertJSON(t *testing.T, expect, actual string, msgAndArgs ...interface{}) bool {
|
||||
t.Helper()
|
||||
|
||||
expectVal := map[string]interface{}{}
|
||||
if err := json.Unmarshal([]byte(expect), &expectVal); err != nil {
|
||||
t.Errorf(errMsg("unable to parse expected JSON", err, msgAndArgs...))
|
||||
return false
|
||||
}
|
||||
|
||||
actualVal := map[string]interface{}{}
|
||||
if err := json.Unmarshal([]byte(actual), &actualVal); err != nil {
|
||||
t.Errorf(errMsg("unable to parse actual JSON", err, msgAndArgs...))
|
||||
return false
|
||||
}
|
||||
|
||||
return equal(t, expectVal, actualVal, msgAndArgs...)
|
||||
}
|
||||
|
||||
// AssertXML verifies that the expect xml string matches the actual.
|
||||
func AssertXML(t *testing.T, expect, actual string, container interface{}, msgAndArgs ...interface{}) bool {
|
||||
expectVal := container
|
||||
if err := xml.Unmarshal([]byte(expect), &expectVal); err != nil {
|
||||
t.Errorf(errMsg("unable to parse expected XML", err, msgAndArgs...))
|
||||
}
|
||||
|
||||
actualVal := container
|
||||
if err := xml.Unmarshal([]byte(actual), &actualVal); err != nil {
|
||||
t.Errorf(errMsg("unable to parse actual XML", err, msgAndArgs...))
|
||||
}
|
||||
return equal(t, expectVal, actualVal, msgAndArgs...)
|
||||
}
|
||||
|
||||
// DidPanic returns if the function paniced and returns true if the function paniced.
|
||||
func DidPanic(fn func()) (bool, interface{}) {
|
||||
var paniced bool
|
||||
var msg interface{}
|
||||
func() {
|
||||
defer func() {
|
||||
if msg = recover(); msg != nil {
|
||||
paniced = true
|
||||
}
|
||||
}()
|
||||
fn()
|
||||
}()
|
||||
|
||||
return paniced, msg
|
||||
}
|
||||
|
||||
// objectsAreEqual determines if two objects are considered equal.
|
||||
//
|
||||
// This function does no assertion of any kind.
|
||||
//
|
||||
// Based on github.com/stretchr/testify/assert.ObjectsAreEqual
|
||||
// Copied locally to prevent non-test build dependencies on testify
|
||||
func objectsAreEqual(expected, actual interface{}) bool {
|
||||
if expected == nil || actual == nil {
|
||||
return expected == actual
|
||||
}
|
||||
|
||||
return reflect.DeepEqual(expected, actual)
|
||||
}
|
||||
|
||||
// Equal asserts that two objects are equal.
|
||||
//
|
||||
// assert.Equal(t, 123, 123, "123 and 123 should be equal")
|
||||
//
|
||||
// Returns whether the assertion was successful (true) or not (false).
|
||||
//
|
||||
// Based on github.com/stretchr/testify/assert.Equal
|
||||
// Copied locally to prevent non-test build dependencies on testify
|
||||
func equal(t *testing.T, expected, actual interface{}, msgAndArgs ...interface{}) bool {
|
||||
t.Helper()
|
||||
|
||||
if !objectsAreEqual(expected, actual) {
|
||||
t.Errorf("%s\n%s", messageFromMsgAndArgs(msgAndArgs),
|
||||
SprintExpectActual(expected, actual))
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func errMsg(baseMsg string, err error, msgAndArgs ...interface{}) string {
|
||||
message := messageFromMsgAndArgs(msgAndArgs)
|
||||
if message != "" {
|
||||
message += ", "
|
||||
}
|
||||
return fmt.Sprintf("%s%s, %v", message, baseMsg, err)
|
||||
}
|
||||
|
||||
// Based on github.com/stretchr/testify/assert.messageFromMsgAndArgs
|
||||
// Copied locally to prevent non-test build dependencies on testify
|
||||
func messageFromMsgAndArgs(msgAndArgs []interface{}) string {
|
||||
if len(msgAndArgs) == 0 || msgAndArgs == nil {
|
||||
return ""
|
||||
}
|
||||
if len(msgAndArgs) == 1 {
|
||||
return msgAndArgs[0].(string)
|
||||
}
|
||||
if len(msgAndArgs) > 1 {
|
||||
return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func queryValueKeys(v url.Values) []string {
|
||||
keys := make([]string, 0, len(v))
|
||||
for k := range v {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
// SprintExpectActual returns a string for test failure cases when the actual
|
||||
// value is not the same as the expected.
|
||||
func SprintExpectActual(expect, actual interface{}) string {
|
||||
return fmt.Sprintf("expect: %+v\nactual: %+v\n", expect, actual)
|
||||
}
|
||||
89
aws/internal/awstesting/assert_test.go
Normal file
89
aws/internal/awstesting/assert_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package awstesting_test
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"testing"
|
||||
|
||||
"github.com/versity/versitygw/aws/internal/awstesting"
|
||||
)
|
||||
|
||||
func TestAssertJSON(t *testing.T) {
|
||||
cases := []struct {
|
||||
e, a string
|
||||
asserts bool
|
||||
}{
|
||||
{
|
||||
e: `{"RecursiveStruct":{"RecursiveMap":{"foo":{"NoRecurse":"foo"},"bar":{"NoRecurse":"bar"}}}}`,
|
||||
a: `{"RecursiveStruct":{"RecursiveMap":{"bar":{"NoRecurse":"bar"},"foo":{"NoRecurse":"foo"}}}}`,
|
||||
asserts: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
mockT := &testing.T{}
|
||||
if awstesting.AssertJSON(mockT, c.e, c.a) != c.asserts {
|
||||
t.Error("Assert JSON result was not expected.", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssertXML(t *testing.T) {
|
||||
cases := []struct {
|
||||
e, a string
|
||||
asserts bool
|
||||
container struct {
|
||||
XMLName xml.Name `xml:"OperationRequest"`
|
||||
NS string `xml:"xmlns,attr"`
|
||||
RecursiveStruct struct {
|
||||
RecursiveMap struct {
|
||||
Entries []struct {
|
||||
XMLName xml.Name `xml:"entries"`
|
||||
Key string `xml:"key"`
|
||||
Value struct {
|
||||
XMLName xml.Name `xml:"value"`
|
||||
NoRecurse string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}{
|
||||
{
|
||||
e: `<OperationRequest xmlns="https://foo/"><RecursiveStruct xmlns="https://foo/"><RecursiveMap xmlns="https://foo/"><entry xmlns="https://foo/"><key xmlns="https://foo/">foo</key><value xmlns="https://foo/"><NoRecurse xmlns="https://foo/">foo</NoRecurse></value></entry><entry xmlns="https://foo/"><key xmlns="https://foo/">bar</key><value xmlns="https://foo/"><NoRecurse xmlns="https://foo/">bar</NoRecurse></value></entry></RecursiveMap></RecursiveStruct></OperationRequest>`,
|
||||
a: `<OperationRequest xmlns="https://foo/"><RecursiveStruct xmlns="https://foo/"><RecursiveMap xmlns="https://foo/"><entry xmlns="https://foo/"><key xmlns="https://foo/">bar</key><value xmlns="https://foo/"><NoRecurse xmlns="https://foo/">bar</NoRecurse></value></entry><entry xmlns="https://foo/"><key xmlns="https://foo/">foo</key><value xmlns="https://foo/"><NoRecurse xmlns="https://foo/">foo</NoRecurse></value></entry></RecursiveMap></RecursiveStruct></OperationRequest>`,
|
||||
asserts: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
// mockT := &testing.T{}
|
||||
if awstesting.AssertXML(t, c.e, c.a, c.container) != c.asserts {
|
||||
t.Error("Assert XML result was not expected.", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssertQuery(t *testing.T) {
|
||||
cases := []struct {
|
||||
e, a string
|
||||
asserts bool
|
||||
}{
|
||||
{
|
||||
e: `Action=OperationName&Version=2014-01-01&Foo=val1&Bar=val2`,
|
||||
a: `Action=OperationName&Version=2014-01-01&Foo=val2&Bar=val3`,
|
||||
asserts: false,
|
||||
},
|
||||
{
|
||||
e: `Action=OperationName&Version=2014-01-01&Foo=val1&Bar=val2`,
|
||||
a: `Action=OperationName&Version=2014-01-01&Foo=val1&Bar=val2`,
|
||||
asserts: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, c := range cases {
|
||||
mockT := &testing.T{}
|
||||
if awstesting.AssertQuery(mockT, c.e, c.a) != c.asserts {
|
||||
t.Error("Assert Query result was not expected.", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
291
aws/internal/awstesting/certificate_utils.go
Normal file
291
aws/internal/awstesting/certificate_utils.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package awstesting
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// TLSBundleCA is the CA PEM
|
||||
TLSBundleCA []byte
|
||||
|
||||
// TLSBundleCert is the Server PEM
|
||||
TLSBundleCert []byte
|
||||
|
||||
// TLSBundleKey is the Server private key PEM
|
||||
TLSBundleKey []byte
|
||||
|
||||
// ClientTLSCert is the Client PEM
|
||||
ClientTLSCert []byte
|
||||
|
||||
// ClientTLSKey is the Client private key PEM
|
||||
ClientTLSKey []byte
|
||||
)
|
||||
|
||||
func init() {
|
||||
caPEM, _, caCert, caPrivKey, err := generateRootCA()
|
||||
if err != nil {
|
||||
panic("failed to generate testing root CA, " + err.Error())
|
||||
}
|
||||
TLSBundleCA = caPEM
|
||||
|
||||
serverCertPEM, serverCertPrivKeyPEM, err := generateLocalCert(caCert, caPrivKey)
|
||||
if err != nil {
|
||||
panic("failed to generate testing server cert, " + err.Error())
|
||||
}
|
||||
TLSBundleCert = serverCertPEM
|
||||
TLSBundleKey = serverCertPrivKeyPEM
|
||||
|
||||
clientCertPEM, clientCertPrivKeyPEM, err := generateLocalCert(caCert, caPrivKey)
|
||||
if err != nil {
|
||||
panic("failed to generate testing client cert, " + err.Error())
|
||||
}
|
||||
ClientTLSCert = clientCertPEM
|
||||
ClientTLSKey = clientCertPrivKeyPEM
|
||||
}
|
||||
|
||||
func generateRootCA() (
|
||||
caPEM, caPrivKeyPEM []byte, caCert *x509.Certificate, caPrivKey *rsa.PrivateKey, err error,
|
||||
) {
|
||||
caCert = &x509.Certificate{
|
||||
SerialNumber: big.NewInt(42),
|
||||
Subject: pkix.Name{
|
||||
Country: []string{"US"},
|
||||
Organization: []string{"AWS SDK for Go Test Certificate"},
|
||||
CommonName: "Test Root CA",
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Minute),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
// Create CA private and public key
|
||||
caPrivKey, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed generate CA RSA key, %w", err)
|
||||
}
|
||||
|
||||
// Create CA certificate
|
||||
caBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caPrivKey.PublicKey, caPrivKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed generate CA certificate, %w", err)
|
||||
}
|
||||
|
||||
// PEM encode CA certificate and private key
|
||||
var caPEMBuf bytes.Buffer
|
||||
pem.Encode(&caPEMBuf, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: caBytes,
|
||||
})
|
||||
|
||||
var caPrivKeyPEMBuf bytes.Buffer
|
||||
pem.Encode(&caPrivKeyPEMBuf, &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
|
||||
})
|
||||
|
||||
return caPEMBuf.Bytes(), caPrivKeyPEMBuf.Bytes(), caCert, caPrivKey, nil
|
||||
}
|
||||
|
||||
func generateLocalCert(parentCert *x509.Certificate, parentPrivKey *rsa.PrivateKey) (
|
||||
certPEM, certPrivKeyPEM []byte, err error,
|
||||
) {
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(42),
|
||||
Subject: pkix.Name{
|
||||
Country: []string{"US"},
|
||||
Organization: []string{"AWS SDK for Go Test Certificate"},
|
||||
CommonName: "Test Root CA",
|
||||
},
|
||||
IPAddresses: []net.IP{
|
||||
net.IPv4(127, 0, 0, 1),
|
||||
net.IPv6loopback,
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Minute),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
}
|
||||
|
||||
// Create server private and public key
|
||||
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate server RSA private key, %w", err)
|
||||
}
|
||||
|
||||
// Create server certificate
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, cert, parentCert, &certPrivKey.PublicKey, parentPrivKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate server certificate, %w", err)
|
||||
}
|
||||
|
||||
// PEM encode certificate and private key
|
||||
var certPEMBuf bytes.Buffer
|
||||
pem.Encode(&certPEMBuf, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certBytes,
|
||||
})
|
||||
|
||||
var certPrivKeyPEMBuf bytes.Buffer
|
||||
pem.Encode(&certPrivKeyPEMBuf, &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
|
||||
})
|
||||
|
||||
return certPEMBuf.Bytes(), certPrivKeyPEMBuf.Bytes(), nil
|
||||
}
|
||||
|
||||
// NewTLSClientCertServer creates a new HTTP test server initialize to require
|
||||
// HTTP clients authenticate with TLS client certificates.
|
||||
func NewTLSClientCertServer(handler http.Handler) (*httptest.Server, error) {
|
||||
server := httptest.NewUnstartedServer(handler)
|
||||
|
||||
if server.TLS == nil {
|
||||
server.TLS = &tls.Config{}
|
||||
}
|
||||
server.TLS.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
|
||||
if server.TLS.ClientCAs == nil {
|
||||
server.TLS.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
certPem := append(ClientTLSCert, ClientTLSKey...)
|
||||
if ok := server.TLS.ClientCAs.AppendCertsFromPEM(certPem); !ok {
|
||||
return nil, fmt.Errorf("failed to append client certs")
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// CreateClientTLSCertFiles returns a set of temporary files for the client
|
||||
// certificate and key files.
|
||||
func CreateClientTLSCertFiles() (cert, key string, err error) {
|
||||
cert, err = createTmpFile(ClientTLSCert)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
key, err = createTmpFile(ClientTLSKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return cert, key, nil
|
||||
}
|
||||
|
||||
func availableLocalAddr(ip string) (v string, err error) {
|
||||
l, err := net.Listen("tcp", ip+":0")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
closeErr := l.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
} else if closeErr != nil {
|
||||
err = fmt.Errorf("ip listener close error: %v, original error: %w", closeErr, err)
|
||||
}
|
||||
}()
|
||||
|
||||
return l.Addr().String(), nil
|
||||
}
|
||||
|
||||
// CreateTLSServer will create the TLS server on an open port using the
|
||||
// certificate and key. The address will be returned that the server is running on.
|
||||
func CreateTLSServer(cert, key string, mux *http.ServeMux) (string, error) {
|
||||
addr, err := availableLocalAddr("127.0.0.1")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if mux == nil {
|
||||
mux = http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {})
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := http.ListenAndServeTLS(addr, cert, key, mux); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i < 60; i++ {
|
||||
if _, err := http.Get("https://" + addr); err != nil && !strings.Contains(err.Error(), "connection refused") {
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
return "https://" + addr, nil
|
||||
}
|
||||
|
||||
// CreateTLSBundleFiles returns the temporary filenames for the certificate
|
||||
// key, and CA PEM content. These files should be deleted when no longer
|
||||
// needed. CleanupTLSBundleFiles can be used for this cleanup.
|
||||
func CreateTLSBundleFiles() (cert, key, ca string, err error) {
|
||||
cert, err = createTmpFile(TLSBundleCert)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
key, err = createTmpFile(TLSBundleKey)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
ca, err = createTmpFile(TLSBundleCA)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
return cert, key, ca, nil
|
||||
}
|
||||
|
||||
// CleanupTLSBundleFiles takes variadic list of files to be deleted.
|
||||
func CleanupTLSBundleFiles(files ...string) error {
|
||||
for _, file := range files {
|
||||
if err := os.Remove(file); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createTmpFile(b []byte) (string, error) {
|
||||
bundleFile, err := ioutil.TempFile(os.TempDir(), "aws-sdk-go-session-test")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, err = bundleFile.Write(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
defer bundleFile.Close()
|
||||
return bundleFile.Name(), nil
|
||||
}
|
||||
11
aws/internal/awstesting/discard.go
Normal file
11
aws/internal/awstesting/discard.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package awstesting
|
||||
|
||||
// DiscardAt is an io.WriteAt that discards
|
||||
// the requested bytes to be written
|
||||
type DiscardAt struct{}
|
||||
|
||||
// WriteAt discards the given []byte slice and returns len(p) bytes
|
||||
// as having been written at the given offset. It will never return an error.
|
||||
func (d DiscardAt) WriteAt(p []byte, off int64) (n int, err error) {
|
||||
return len(p), nil
|
||||
}
|
||||
12
aws/internal/awstesting/endless_reader.go
Normal file
12
aws/internal/awstesting/endless_reader.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package awstesting
|
||||
|
||||
// EndlessReader is an io.Reader that will always return
|
||||
// that bytes have been read.
|
||||
type EndlessReader struct{}
|
||||
|
||||
// Read will report that it has read len(p) bytes in p.
|
||||
// The content in the []byte will be unmodified.
|
||||
// This will never return an error.
|
||||
func (e EndlessReader) Read(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
43
aws/internal/awstesting/sandbox/Dockerfile.golang-tip
Normal file
43
aws/internal/awstesting/sandbox/Dockerfile.golang-tip
Normal file
@@ -0,0 +1,43 @@
|
||||
# Based on docker-library's golang 1.6 alpine and wheezy docker files.
|
||||
# https://github.com/docker-library/golang/blob/master/1.6/alpine/Dockerfile
|
||||
# https://github.com/docker-library/golang/blob/master/1.6/wheezy/Dockerfile
|
||||
FROM buildpack-deps:buster-scm
|
||||
|
||||
ENV GOLANG_SRC_REPO_URL https://github.com/golang/go
|
||||
|
||||
# as of 1.20 Go 1.17 is required to bootstrap
|
||||
# see https://github.com/golang/go/issues/44505
|
||||
ENV GOLANG_BOOTSTRAP_URL https://go.dev/dl/go1.17.13.linux-amd64.tar.gz
|
||||
ENV GOLANG_BOOTSTRAP_SHA256 4cdd2bc664724dc7db94ad51b503512c5ae7220951cac568120f64f8e94399fc
|
||||
ENV GOLANG_BOOTSTRAP_PATH /usr/local/bootstrap
|
||||
|
||||
# gcc for cgo
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
g++ \
|
||||
gcc \
|
||||
libc6-dev \
|
||||
make \
|
||||
git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Setup the Bootstrap
|
||||
RUN mkdir -p "$GOLANG_BOOTSTRAP_PATH" \
|
||||
&& curl -fsSL "$GOLANG_BOOTSTRAP_URL" -o golang.tar.gz \
|
||||
&& echo "$GOLANG_BOOTSTRAP_SHA256 golang.tar.gz" | sha256sum -c - \
|
||||
&& tar -C "$GOLANG_BOOTSTRAP_PATH" -xzf golang.tar.gz \
|
||||
&& rm golang.tar.gz
|
||||
|
||||
# Get and build Go tip
|
||||
RUN export GOROOT_BOOTSTRAP=$GOLANG_BOOTSTRAP_PATH/go \
|
||||
&& git clone "$GOLANG_SRC_REPO_URL" /usr/local/go \
|
||||
&& cd /usr/local/go/src \
|
||||
&& ./make.bash \
|
||||
&& rm -rf "$GOLANG_BOOTSTRAP_PATH" /usr/local/go/pkg/bootstrap
|
||||
|
||||
# Build Go workspace and environment
|
||||
ENV GOPATH /go
|
||||
ENV PATH $GOPATH/bin:/usr/local/go/bin:$PATH
|
||||
RUN mkdir -p "$GOPATH/src" "$GOPATH/bin" \
|
||||
&& chmod -R 777 "$GOPATH"
|
||||
|
||||
WORKDIR $GOPATH
|
||||
16
aws/internal/awstesting/sandbox/Dockerfile.test.gotip
Normal file
16
aws/internal/awstesting/sandbox/Dockerfile.test.gotip
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM aws-golang:tip
|
||||
|
||||
ENV GOPROXY=direct
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common \
|
||||
&& wget -O- https://apt.corretto.aws/corretto.key | apt-key add - \
|
||||
&& add-apt-repository 'deb https://apt.corretto.aws stable main' \
|
||||
&& apt-get update && apt-get install -y --no-install-recommends \
|
||||
vim \
|
||||
java-17-amazon-corretto-jdk \
|
||||
&& rm -rf /var/list/apt/lists/*
|
||||
|
||||
ADD . /go/src/github.com/aws/aws-sdk-go-v2
|
||||
WORKDIR /go/src/github.com/aws/aws-sdk-go-v2
|
||||
CMD ["make", "unit"]
|
||||
18
aws/internal/awstesting/sandbox/Dockerfile.test.goversion
Normal file
18
aws/internal/awstesting/sandbox/Dockerfile.test.goversion
Normal file
@@ -0,0 +1,18 @@
|
||||
ARG GO_VERSION
|
||||
FROM golang:${GO_VERSION}
|
||||
|
||||
ENV GOPROXY=direct
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common \
|
||||
&& wget -O- https://apt.corretto.aws/corretto.key | apt-key add - \
|
||||
&& add-apt-repository 'deb https://apt.corretto.aws stable main' \
|
||||
&& apt-get update && apt-get install -y --no-install-recommends \
|
||||
vim \
|
||||
java-17-amazon-corretto-jdk \
|
||||
&& rm -rf /var/list/apt/lists/*
|
||||
|
||||
ADD . /go/src/github.com/aws/aws-sdk-go-v2
|
||||
|
||||
WORKDIR /go/src/github.com/aws/aws-sdk-go-v2
|
||||
CMD ["make", "unit"]
|
||||
61
aws/internal/awstesting/unit/unit.go
Normal file
61
aws/internal/awstesting/unit/unit.go
Normal file
@@ -0,0 +1,61 @@
|
||||
// Package unit performs initialization and validation for unit tests
|
||||
package unit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
)
|
||||
|
||||
func init() {
|
||||
config = aws.Config{}
|
||||
config.Region = "mock-region"
|
||||
config.Credentials = StubCredentialsProvider{}
|
||||
}
|
||||
|
||||
// StubCredentialsProvider provides a stub credential provider that returns
|
||||
// static credentials that never expire.
|
||||
type StubCredentialsProvider struct{}
|
||||
|
||||
// Retrieve satisfies the CredentialsProvider interface. Returns stub
|
||||
// credential value, and never error.
|
||||
func (StubCredentialsProvider) Retrieve(context.Context) (aws.Credentials, error) {
|
||||
return aws.Credentials{
|
||||
AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "SESSION",
|
||||
Source: "unit test credentials",
|
||||
}, nil
|
||||
}
|
||||
|
||||
var config aws.Config
|
||||
|
||||
// Config returns a copy of the mock configuration for unit tests.
|
||||
func Config() aws.Config { return config.Copy() }
|
||||
|
||||
// RSAPrivateKey is used for testing functionality that requires some
|
||||
// sort of private key. Taken from crypto/rsa/rsa_test.go
|
||||
//
|
||||
// Credit to golang 1.11
|
||||
var RSAPrivateKey = &rsa.PrivateKey{
|
||||
PublicKey: rsa.PublicKey{
|
||||
N: fromBase10("14314132931241006650998084889274020608918049032671858325988396851334124245188214251956198731333464217832226406088020736932173064754214329009979944037640912127943488972644697423190955557435910767690712778463524983667852819010259499695177313115447116110358524558307947613422897787329221478860907963827160223559690523660574329011927531289655711860504630573766609239332569210831325633840174683944553667352219670930408593321661375473885147973879086994006440025257225431977751512374815915392249179976902953721486040787792801849818254465486633791826766873076617116727073077821584676715609985777563958286637185868165868520557"),
|
||||
E: 3,
|
||||
},
|
||||
D: fromBase10("9542755287494004433998723259516013739278699355114572217325597900889416163458809501304132487555642811888150937392013824621448709836142886006653296025093941418628992648429798282127303704957273845127141852309016655778568546006839666463451542076964744073572349705538631742281931858219480985907271975884773482372966847639853897890615456605598071088189838676728836833012254065983259638538107719766738032720239892094196108713378822882383694456030043492571063441943847195939549773271694647657549658603365629458610273821292232646334717612674519997533901052790334279661754176490593041941863932308687197618671528035670452762731"),
|
||||
Primes: []*big.Int{
|
||||
fromBase10("130903255182996722426771613606077755295583329135067340152947172868415809027537376306193179624298874215608270802054347609836776473930072411958753044562214537013874103802006369634761074377213995983876788718033850153719421695468704276694983032644416930879093914927146648402139231293035971427838068945045019075433"),
|
||||
fromBase10("109348945610485453577574767652527472924289229538286649661240938988020367005475727988253438647560958573506159449538793540472829815903949343191091817779240101054552748665267574271163617694640513549693841337820602726596756351006149518830932261246698766355347898158548465400674856021497190430791824869615170301029"),
|
||||
},
|
||||
}
|
||||
|
||||
// Taken from crypto/rsa/rsa_test.go
|
||||
//
|
||||
// Credit to golang 1.11
|
||||
func fromBase10(base10 string) *big.Int {
|
||||
i, ok := new(big.Int).SetString(base10, 10)
|
||||
if !ok {
|
||||
panic("bad number: " + base10)
|
||||
}
|
||||
return i
|
||||
}
|
||||
201
aws/internal/awstesting/util.go
Normal file
201
aws/internal/awstesting/util.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package awstesting
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
)
|
||||
|
||||
// ZeroReader is a io.Reader which will always write zeros to the byte slice provided.
|
||||
type ZeroReader struct{}
|
||||
|
||||
// Read fills the provided byte slice with zeros returning the number of bytes written.
|
||||
func (r *ZeroReader) Read(b []byte) (int, error) {
|
||||
for i := 0; i < len(b); i++ {
|
||||
b[i] = 0
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// ReadCloser is a io.ReadCloser for unit testing.
|
||||
// Designed to test for leaks and whether a handle has
|
||||
// been closed
|
||||
type ReadCloser struct {
|
||||
Size int
|
||||
Closed bool
|
||||
set bool
|
||||
FillData func(bool, []byte, int, int)
|
||||
}
|
||||
|
||||
// Read will call FillData and fill it with whatever data needed.
|
||||
// Decrements the size until zero, then return io.EOF.
|
||||
func (r *ReadCloser) Read(b []byte) (int, error) {
|
||||
if r.Closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
delta := len(b)
|
||||
if delta > r.Size {
|
||||
delta = r.Size
|
||||
}
|
||||
r.Size -= delta
|
||||
|
||||
for i := 0; i < delta; i++ {
|
||||
b[i] = 'a'
|
||||
}
|
||||
|
||||
if r.FillData != nil {
|
||||
r.FillData(r.set, b, r.Size, delta)
|
||||
}
|
||||
r.set = true
|
||||
|
||||
if r.Size > 0 {
|
||||
return delta, nil
|
||||
}
|
||||
return delta, io.EOF
|
||||
}
|
||||
|
||||
// Close sets Closed to true and returns no error
|
||||
func (r *ReadCloser) Close() error {
|
||||
r.Closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// A FakeContext provides a simple stub implementation of a Context
|
||||
type FakeContext struct {
|
||||
Error error
|
||||
DoneCh chan struct{}
|
||||
}
|
||||
|
||||
// Deadline always will return not set
|
||||
func (c *FakeContext) Deadline() (deadline time.Time, ok bool) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
// Done returns a read channel for listening to the Done event
|
||||
func (c *FakeContext) Done() <-chan struct{} {
|
||||
return c.DoneCh
|
||||
}
|
||||
|
||||
// Err returns the error, is nil if not set.
|
||||
func (c *FakeContext) Err() error {
|
||||
return c.Error
|
||||
}
|
||||
|
||||
// Value ignores the Value and always returns nil
|
||||
func (c *FakeContext) Value(key interface{}) interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StashEnv stashes the current environment variables except variables listed in envToKeepx
|
||||
// Returns an function to pop out old environment
|
||||
func StashEnv(envToKeep ...string) []string {
|
||||
if runtime.GOOS == "windows" {
|
||||
envToKeep = append(envToKeep, "ComSpec")
|
||||
envToKeep = append(envToKeep, "SYSTEM32")
|
||||
envToKeep = append(envToKeep, "SYSTEMROOT")
|
||||
}
|
||||
envToKeep = append(envToKeep, "PATH", "HOME", "USERPROFILE")
|
||||
extraEnv := getEnvs(envToKeep)
|
||||
originalEnv := os.Environ()
|
||||
os.Clearenv() // clear env
|
||||
for key, val := range extraEnv {
|
||||
os.Setenv(key, val)
|
||||
}
|
||||
return originalEnv
|
||||
}
|
||||
|
||||
// PopEnv takes the list of the environment values and injects them into the
|
||||
// process's environment variable data. Clears any existing environment values
|
||||
// that may already exist.
|
||||
func PopEnv(env []string) {
|
||||
os.Clearenv()
|
||||
|
||||
for _, e := range env {
|
||||
p := strings.SplitN(e, "=", 2)
|
||||
k, v := p[0], ""
|
||||
if len(p) > 1 {
|
||||
v = p[1]
|
||||
}
|
||||
os.Setenv(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// MockCredentialsProvider is a type that can be used to mock out credentials
|
||||
// providers
|
||||
type MockCredentialsProvider struct {
|
||||
RetrieveFn func(ctx context.Context) (aws.Credentials, error)
|
||||
InvalidateFn func()
|
||||
}
|
||||
|
||||
// Retrieve calls the RetrieveFn
|
||||
func (p MockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
|
||||
return p.RetrieveFn(ctx)
|
||||
}
|
||||
|
||||
// Invalidate calls the InvalidateFn
|
||||
func (p MockCredentialsProvider) Invalidate() {
|
||||
p.InvalidateFn()
|
||||
}
|
||||
|
||||
func getEnvs(envs []string) map[string]string {
|
||||
extraEnvs := make(map[string]string)
|
||||
for _, env := range envs {
|
||||
if val, ok := os.LookupEnv(env); ok && len(val) > 0 {
|
||||
extraEnvs[env] = val
|
||||
}
|
||||
}
|
||||
return extraEnvs
|
||||
}
|
||||
|
||||
const (
|
||||
signaturePreambleSigV4 = "AWS4-HMAC-SHA256"
|
||||
signaturePreambleSigV4A = "AWS4-ECDSA-P256-SHA256"
|
||||
)
|
||||
|
||||
// SigV4Signature represents a parsed sigv4 or sigv4a signature.
|
||||
type SigV4Signature struct {
|
||||
Preamble string // e.g. AWS4-HMAC-SHA256, AWS4-ECDSA-P256-SHA256
|
||||
SigningName string // generally the service name e.g. "s3"
|
||||
SigningRegion string // for sigv4a this is the region-set header as-is
|
||||
SignedHeaders []string // list of signed headers
|
||||
Signature string // calculated signature
|
||||
}
|
||||
|
||||
// ParseSigV4Signature deconstructs a sigv4 or sigv4a signature from a set of
|
||||
// request headers.
|
||||
func ParseSigV4Signature(header http.Header) *SigV4Signature {
|
||||
auth := header.Get("Authorization")
|
||||
|
||||
preamble, after, _ := strings.Cut(auth, " ")
|
||||
credential, after, _ := strings.Cut(after, ", ")
|
||||
signedHeaders, signature, _ := strings.Cut(after, ", ")
|
||||
|
||||
credentialParts := strings.Split(credential, "/")
|
||||
|
||||
// sigv4 : AccessKeyID/DateString/SigningRegion/SigningName/SignatureID
|
||||
// sigv4a : AccessKeyID/DateString/SigningName/SignatureID, region set on
|
||||
// header
|
||||
var signingName, signingRegion string
|
||||
if preamble == signaturePreambleSigV4 {
|
||||
signingName = credentialParts[3]
|
||||
signingRegion = credentialParts[2]
|
||||
} else if preamble == signaturePreambleSigV4A {
|
||||
signingName = credentialParts[2]
|
||||
signingRegion = header.Get("X-Amz-Region-Set")
|
||||
}
|
||||
|
||||
return &SigV4Signature{
|
||||
Preamble: preamble,
|
||||
SigningName: signingName,
|
||||
SigningRegion: signingRegion,
|
||||
SignedHeaders: strings.Split(signedHeaders, ";"),
|
||||
Signature: signature,
|
||||
}
|
||||
}
|
||||
75
aws/internal/awstesting/util_test.go
Normal file
75
aws/internal/awstesting/util_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package awstesting_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/versity/versitygw/aws/internal/awstesting"
|
||||
)
|
||||
|
||||
func TestReadCloserClose(t *testing.T) {
|
||||
rc := awstesting.ReadCloser{Size: 1}
|
||||
err := rc.Close()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expect nil, got %v", err)
|
||||
}
|
||||
if !rc.Closed {
|
||||
t.Errorf("expect closed, was not")
|
||||
}
|
||||
if e, a := rc.Size, 1; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCloserRead(t *testing.T) {
|
||||
rc := awstesting.ReadCloser{Size: 5}
|
||||
b := make([]byte, 2)
|
||||
|
||||
n, err := rc.Read(b)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expect nil, got %v", err)
|
||||
}
|
||||
if e, a := n, 2; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if rc.Closed {
|
||||
t.Errorf("expect not to be closed")
|
||||
}
|
||||
if e, a := rc.Size, 3; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
err = rc.Close()
|
||||
if err != nil {
|
||||
t.Errorf("expect nil, got %v", err)
|
||||
}
|
||||
n, err = rc.Read(b)
|
||||
if e, a := err, io.EOF; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := n, 0; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCloserReadAll(t *testing.T) {
|
||||
rc := awstesting.ReadCloser{Size: 5}
|
||||
b := make([]byte, 5)
|
||||
|
||||
n, err := rc.Read(b)
|
||||
|
||||
if e, a := err, io.EOF; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := n, 5; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if rc.Closed {
|
||||
t.Errorf("expect not to be closed")
|
||||
}
|
||||
if e, a := rc.Size, 0; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
9
aws/internal/sdk/interfaces.go
Normal file
9
aws/internal/sdk/interfaces.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package sdk
|
||||
|
||||
// Invalidator provides access to a type's invalidate method to make it
|
||||
// invalidate it cache.
|
||||
//
|
||||
// e.g aws.SafeCredentialsProvider's Invalidate method.
|
||||
type Invalidator interface {
|
||||
Invalidate()
|
||||
}
|
||||
74
aws/internal/sdk/time.go
Normal file
74
aws/internal/sdk/time.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
NowTime = time.Now
|
||||
Sleep = time.Sleep
|
||||
SleepWithContext = sleepWithContext
|
||||
}
|
||||
|
||||
// NowTime is a value for getting the current time. This value can be overridden
|
||||
// for testing mocking out current time.
|
||||
var NowTime func() time.Time
|
||||
|
||||
// Sleep is a value for sleeping for a duration. This value can be overridden
|
||||
// for testing and mocking out sleep duration.
|
||||
var Sleep func(time.Duration)
|
||||
|
||||
// SleepWithContext will wait for the timer duration to expire, or the context
|
||||
// is canceled. Which ever happens first. If the context is canceled the Context's
|
||||
// error will be returned.
|
||||
//
|
||||
// This value can be overridden for testing and mocking out sleep duration.
|
||||
var SleepWithContext func(context.Context, time.Duration) error
|
||||
|
||||
// sleepWithContext will wait for the timer duration to expire, or the context
|
||||
// is canceled. Which ever happens first. If the context is canceled the
|
||||
// Context's error will be returned.
|
||||
func sleepWithContext(ctx context.Context, dur time.Duration) error {
|
||||
t := time.NewTimer(dur)
|
||||
defer t.Stop()
|
||||
|
||||
select {
|
||||
case <-t.C:
|
||||
break
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// noOpSleepWithContext does nothing, returns immediately.
|
||||
func noOpSleepWithContext(context.Context, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func noOpSleep(time.Duration) {}
|
||||
|
||||
// TestingUseNopSleep is a utility for disabling sleep across the SDK for
|
||||
// testing.
|
||||
func TestingUseNopSleep() func() {
|
||||
SleepWithContext = noOpSleepWithContext
|
||||
Sleep = noOpSleep
|
||||
|
||||
return func() {
|
||||
SleepWithContext = sleepWithContext
|
||||
Sleep = time.Sleep
|
||||
}
|
||||
}
|
||||
|
||||
// TestingUseReferenceTime is a utility for swapping the time function across the SDK to return a specific reference time
|
||||
// for testing purposes.
|
||||
func TestingUseReferenceTime(referenceTime time.Time) func() {
|
||||
NowTime = func() time.Time {
|
||||
return referenceTime
|
||||
}
|
||||
return func() {
|
||||
NowTime = time.Now
|
||||
}
|
||||
}
|
||||
32
aws/internal/sdk/time_test.go
Normal file
32
aws/internal/sdk/time_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSleepWithContext(t *testing.T) {
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
defer cancelFn()
|
||||
|
||||
err := sleepWithContext(ctx, 1*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Errorf("expect context to not be canceled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSleepWithContext_Canceled(t *testing.T) {
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
cancelFn()
|
||||
|
||||
err := sleepWithContext(ctx, 10*time.Second)
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, did not get one")
|
||||
}
|
||||
|
||||
if e, a := "context canceled", err.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("expect %v error, got %v", e, a)
|
||||
}
|
||||
}
|
||||
11
aws/internal/strings/strings.go
Normal file
11
aws/internal/strings/strings.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package strings
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// HasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
|
||||
// under Unicode case-folding.
|
||||
func HasPrefixFold(s, prefix string) bool {
|
||||
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
|
||||
}
|
||||
81
aws/internal/strings/strings_test.go
Normal file
81
aws/internal/strings/strings_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package strings
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHasPrefixFold(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
prefix string
|
||||
}
|
||||
tests := map[string]struct {
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
"empty strings and prefix": {
|
||||
args: args{
|
||||
s: "",
|
||||
prefix: "",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
"strings starts with prefix": {
|
||||
args: args{
|
||||
s: "some string",
|
||||
prefix: "some",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
"prefix longer then string": {
|
||||
args: args{
|
||||
s: "some",
|
||||
prefix: "some string",
|
||||
},
|
||||
},
|
||||
"equal length string and prefix": {
|
||||
args: args{
|
||||
s: "short string",
|
||||
prefix: "short string",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
"different cases": {
|
||||
args: args{
|
||||
s: "ShOrT StRING",
|
||||
prefix: "short",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
"empty prefix not empty string": {
|
||||
args: args{
|
||||
s: "ShOrT StRING",
|
||||
prefix: "",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
"mixed-case prefixes": {
|
||||
args: args{
|
||||
s: "SoMe String",
|
||||
prefix: "sOme",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if got := HasPrefixFold(tt.args.s, tt.args.prefix); got != tt.want {
|
||||
t.Errorf("HasPrefixFold() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHasPrefixFold(b *testing.B) {
|
||||
HasPrefixFold("SoME string", "sOmE")
|
||||
}
|
||||
|
||||
func BenchmarkHasPrefix(b *testing.B) {
|
||||
strings.HasPrefix(strings.ToLower("SoME string"), strings.ToLower("sOmE"))
|
||||
}
|
||||
115
aws/signer/internal/v4/cache.go
Normal file
115
aws/signer/internal/v4/cache.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
)
|
||||
|
||||
func lookupKey(service, region string) string {
|
||||
var s strings.Builder
|
||||
s.Grow(len(region) + len(service) + 3)
|
||||
s.WriteString(region)
|
||||
s.WriteRune('/')
|
||||
s.WriteString(service)
|
||||
return s.String()
|
||||
}
|
||||
|
||||
type derivedKey struct {
|
||||
AccessKey string
|
||||
Date time.Time
|
||||
Credential []byte
|
||||
}
|
||||
|
||||
type derivedKeyCache struct {
|
||||
values map[string]derivedKey
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func newDerivedKeyCache() derivedKeyCache {
|
||||
return derivedKeyCache{
|
||||
values: make(map[string]derivedKey),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *derivedKeyCache) Get(credentials aws.Credentials, service, region string, signingTime SigningTime) []byte {
|
||||
key := lookupKey(service, region)
|
||||
s.mutex.RLock()
|
||||
if cred, ok := s.get(key, credentials, signingTime.Time); ok {
|
||||
s.mutex.RUnlock()
|
||||
return cred
|
||||
}
|
||||
s.mutex.RUnlock()
|
||||
|
||||
s.mutex.Lock()
|
||||
if cred, ok := s.get(key, credentials, signingTime.Time); ok {
|
||||
s.mutex.Unlock()
|
||||
return cred
|
||||
}
|
||||
cred := deriveKey(credentials.SecretAccessKey, service, region, signingTime)
|
||||
entry := derivedKey{
|
||||
AccessKey: credentials.AccessKeyID,
|
||||
Date: signingTime.Time,
|
||||
Credential: cred,
|
||||
}
|
||||
s.values[key] = entry
|
||||
s.mutex.Unlock()
|
||||
|
||||
return cred
|
||||
}
|
||||
|
||||
func (s *derivedKeyCache) get(key string, credentials aws.Credentials, signingTime time.Time) ([]byte, bool) {
|
||||
cacheEntry, ok := s.retrieveFromCache(key)
|
||||
if ok && cacheEntry.AccessKey == credentials.AccessKeyID && isSameDay(signingTime, cacheEntry.Date) {
|
||||
return cacheEntry.Credential, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *derivedKeyCache) retrieveFromCache(key string) (derivedKey, bool) {
|
||||
if v, ok := s.values[key]; ok {
|
||||
return v, true
|
||||
}
|
||||
return derivedKey{}, false
|
||||
}
|
||||
|
||||
// SigningKeyDeriver derives a signing key from a set of credentials
|
||||
type SigningKeyDeriver struct {
|
||||
cache derivedKeyCache
|
||||
}
|
||||
|
||||
// NewSigningKeyDeriver returns a new SigningKeyDeriver
|
||||
func NewSigningKeyDeriver() *SigningKeyDeriver {
|
||||
return &SigningKeyDeriver{
|
||||
cache: newDerivedKeyCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// DeriveKey returns a derived signing key from the given credentials to be used with SigV4 signing.
|
||||
func (k *SigningKeyDeriver) DeriveKey(credential aws.Credentials, service, region string, signingTime SigningTime) []byte {
|
||||
return k.cache.Get(credential, service, region, signingTime)
|
||||
}
|
||||
|
||||
func deriveKey(secret, service, region string, t SigningTime) []byte {
|
||||
hmacDate := HMACSHA256([]byte("AWS4"+secret), []byte(t.ShortTimeFormat()))
|
||||
hmacRegion := HMACSHA256(hmacDate, []byte(region))
|
||||
hmacService := HMACSHA256(hmacRegion, []byte(service))
|
||||
return HMACSHA256(hmacService, []byte("aws4_request"))
|
||||
}
|
||||
|
||||
func isSameDay(x, y time.Time) bool {
|
||||
xYear, xMonth, xDay := x.Date()
|
||||
yYear, yMonth, yDay := y.Date()
|
||||
|
||||
if xYear != yYear {
|
||||
return false
|
||||
}
|
||||
|
||||
if xMonth != yMonth {
|
||||
return false
|
||||
}
|
||||
|
||||
return xDay == yDay
|
||||
}
|
||||
40
aws/signer/internal/v4/const.go
Normal file
40
aws/signer/internal/v4/const.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package v4
|
||||
|
||||
// Signature Version 4 (SigV4) Constants
|
||||
const (
|
||||
// EmptyStringSHA256 is the hex encoded sha256 value of an empty string
|
||||
EmptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
|
||||
|
||||
// UnsignedPayload indicates that the request payload body is unsigned
|
||||
UnsignedPayload = "UNSIGNED-PAYLOAD"
|
||||
|
||||
// AmzAlgorithmKey indicates the signing algorithm
|
||||
AmzAlgorithmKey = "X-Amz-Algorithm"
|
||||
|
||||
// AmzSecurityTokenKey indicates the security token to be used with temporary credentials
|
||||
AmzSecurityTokenKey = "X-Amz-Security-Token"
|
||||
|
||||
// AmzDateKey is the UTC timestamp for the request in the format YYYYMMDD'T'HHMMSS'Z'
|
||||
AmzDateKey = "X-Amz-Date"
|
||||
|
||||
// AmzCredentialKey is the access key ID and credential scope
|
||||
AmzCredentialKey = "X-Amz-Credential"
|
||||
|
||||
// AmzSignedHeadersKey is the set of headers signed for the request
|
||||
AmzSignedHeadersKey = "X-Amz-SignedHeaders"
|
||||
|
||||
// AmzSignatureKey is the query parameter to store the SigV4 signature
|
||||
AmzSignatureKey = "X-Amz-Signature"
|
||||
|
||||
// TimeFormat is the time format to be used in the X-Amz-Date header or query parameter
|
||||
TimeFormat = "20060102T150405Z"
|
||||
|
||||
// ShortTimeFormat is the shorten time format used in the credential scope
|
||||
ShortTimeFormat = "20060102"
|
||||
|
||||
// ContentSHAKey is the SHA256 of request body
|
||||
ContentSHAKey = "X-Amz-Content-Sha256"
|
||||
|
||||
// StreamingEventsPayload indicates that the request payload body is a signed event stream.
|
||||
StreamingEventsPayload = "STREAMING-AWS4-HMAC-SHA256-EVENTS"
|
||||
)
|
||||
82
aws/signer/internal/v4/header_rules.go
Normal file
82
aws/signer/internal/v4/header_rules.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
sdkstrings "github.com/versity/versitygw/aws/internal/strings"
|
||||
)
|
||||
|
||||
// Rules houses a set of Rule needed for validation of a
|
||||
// string value
|
||||
type Rules []Rule
|
||||
|
||||
// Rule interface allows for more flexible rules and just simply
|
||||
// checks whether or not a value adheres to that Rule
|
||||
type Rule interface {
|
||||
IsValid(value string) bool
|
||||
}
|
||||
|
||||
// IsValid will iterate through all rules and see if any rules
|
||||
// apply to the value and supports nested rules
|
||||
func (r Rules) IsValid(value string) bool {
|
||||
for _, rule := range r {
|
||||
if rule.IsValid(value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MapRule generic Rule for maps
|
||||
type MapRule map[string]struct{}
|
||||
|
||||
// IsValid for the map Rule satisfies whether it exists in the map
|
||||
func (m MapRule) IsValid(value string) bool {
|
||||
_, ok := m[value]
|
||||
return ok
|
||||
}
|
||||
|
||||
// AllowList is a generic Rule for include listing
|
||||
type AllowList struct {
|
||||
Rule
|
||||
}
|
||||
|
||||
// IsValid for AllowList checks if the value is within the AllowList
|
||||
func (w AllowList) IsValid(value string) bool {
|
||||
return w.Rule.IsValid(value)
|
||||
}
|
||||
|
||||
// ExcludeList is a generic Rule for exclude listing
|
||||
type ExcludeList struct {
|
||||
Rule
|
||||
}
|
||||
|
||||
// IsValid for AllowList checks if the value is within the AllowList
|
||||
func (b ExcludeList) IsValid(value string) bool {
|
||||
return !b.Rule.IsValid(value)
|
||||
}
|
||||
|
||||
// Patterns is a list of strings to match against
|
||||
type Patterns []string
|
||||
|
||||
// IsValid for Patterns checks each pattern and returns if a match has
|
||||
// been found
|
||||
func (p Patterns) IsValid(value string) bool {
|
||||
for _, pattern := range p {
|
||||
if sdkstrings.HasPrefixFold(value, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// InclusiveRules rules allow for rules to depend on one another
|
||||
type InclusiveRules []Rule
|
||||
|
||||
// IsValid will return true if all rules are true
|
||||
func (r InclusiveRules) IsValid(value string) bool {
|
||||
for _, rule := range r {
|
||||
if !rule.IsValid(value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
71
aws/signer/internal/v4/headers.go
Normal file
71
aws/signer/internal/v4/headers.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package v4
|
||||
|
||||
// IgnoredHeaders is a list of headers that are ignored during signing
|
||||
var IgnoredHeaders = Rules{
|
||||
ExcludeList{
|
||||
MapRule{
|
||||
"Authorization": struct{}{},
|
||||
"User-Agent": struct{}{},
|
||||
"X-Amzn-Trace-Id": struct{}{},
|
||||
"Expect": struct{}{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// RequiredSignedHeaders is a allow list for Build canonical headers.
|
||||
var RequiredSignedHeaders = Rules{
|
||||
AllowList{
|
||||
MapRule{
|
||||
"Cache-Control": struct{}{},
|
||||
"Content-Disposition": struct{}{},
|
||||
"Content-Encoding": struct{}{},
|
||||
"Content-Language": struct{}{},
|
||||
"Content-Md5": struct{}{},
|
||||
"Content-Type": struct{}{},
|
||||
"Expires": struct{}{},
|
||||
"If-Match": struct{}{},
|
||||
"If-Modified-Since": struct{}{},
|
||||
"If-None-Match": struct{}{},
|
||||
"If-Unmodified-Since": struct{}{},
|
||||
"Range": struct{}{},
|
||||
"X-Amz-Acl": struct{}{},
|
||||
"X-Amz-Copy-Source": struct{}{},
|
||||
"X-Amz-Copy-Source-If-Match": struct{}{},
|
||||
"X-Amz-Copy-Source-If-Modified-Since": struct{}{},
|
||||
"X-Amz-Copy-Source-If-None-Match": struct{}{},
|
||||
"X-Amz-Copy-Source-If-Unmodified-Since": struct{}{},
|
||||
"X-Amz-Copy-Source-Range": struct{}{},
|
||||
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": struct{}{},
|
||||
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": struct{}{},
|
||||
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
|
||||
"X-Amz-Expected-Bucket-Owner": struct{}{},
|
||||
"X-Amz-Grant-Full-control": struct{}{},
|
||||
"X-Amz-Grant-Read": struct{}{},
|
||||
"X-Amz-Grant-Read-Acp": struct{}{},
|
||||
"X-Amz-Grant-Write": struct{}{},
|
||||
"X-Amz-Grant-Write-Acp": struct{}{},
|
||||
"X-Amz-Metadata-Directive": struct{}{},
|
||||
"X-Amz-Mfa": struct{}{},
|
||||
"X-Amz-Request-Payer": struct{}{},
|
||||
"X-Amz-Server-Side-Encryption": struct{}{},
|
||||
"X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": struct{}{},
|
||||
"X-Amz-Server-Side-Encryption-Context": struct{}{},
|
||||
"X-Amz-Server-Side-Encryption-Customer-Algorithm": struct{}{},
|
||||
"X-Amz-Server-Side-Encryption-Customer-Key": struct{}{},
|
||||
"X-Amz-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
|
||||
"X-Amz-Storage-Class": struct{}{},
|
||||
"X-Amz-Website-Redirect-Location": struct{}{},
|
||||
"X-Amz-Content-Sha256": struct{}{},
|
||||
"X-Amz-Tagging": struct{}{},
|
||||
},
|
||||
},
|
||||
Patterns{"X-Amz-Object-Lock-"},
|
||||
Patterns{"X-Amz-Meta-"},
|
||||
}
|
||||
|
||||
// AllowedQueryHoisting is a allowed list for Build query headers. The boolean value
|
||||
// represents whether or not it is a pattern.
|
||||
var AllowedQueryHoisting = InclusiveRules{
|
||||
ExcludeList{RequiredSignedHeaders},
|
||||
Patterns{"X-Amz-"},
|
||||
}
|
||||
63
aws/signer/internal/v4/headers_test.go
Normal file
63
aws/signer/internal/v4/headers_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package v4
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAllowedQueryHoisting(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
Header string
|
||||
ExpectHoist bool
|
||||
}{
|
||||
"object-lock": {
|
||||
Header: "X-Amz-Object-Lock-Mode",
|
||||
ExpectHoist: false,
|
||||
},
|
||||
"s3 metadata": {
|
||||
Header: "X-Amz-Meta-SomeName",
|
||||
ExpectHoist: false,
|
||||
},
|
||||
"another header": {
|
||||
Header: "X-Amz-SomeOtherHeader",
|
||||
ExpectHoist: true,
|
||||
},
|
||||
"non X-AMZ header": {
|
||||
Header: "X-SomeOtherHeader",
|
||||
ExpectHoist: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if e, a := c.ExpectHoist, AllowedQueryHoisting.IsValid(c.Header); e != a {
|
||||
t.Errorf("expect hoist %v, was %v", e, a)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIgnoredHeaders(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
Header string
|
||||
ExpectIgnored bool
|
||||
}{
|
||||
"expect": {
|
||||
Header: "Expect",
|
||||
ExpectIgnored: true,
|
||||
},
|
||||
"authorization": {
|
||||
Header: "Authorization",
|
||||
ExpectIgnored: true,
|
||||
},
|
||||
"X-AMZ header": {
|
||||
Header: "X-Amz-Content-Sha256",
|
||||
ExpectIgnored: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if e, a := c.ExpectIgnored, IgnoredHeaders.IsValid(c.Header); e == a {
|
||||
t.Errorf("expect ignored %v, was %v", e, a)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
13
aws/signer/internal/v4/hmac.go
Normal file
13
aws/signer/internal/v4/hmac.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
)
|
||||
|
||||
// HMACSHA256 computes a HMAC-SHA256 of data given the provided key.
|
||||
func HMACSHA256(key []byte, data []byte) []byte {
|
||||
hash := hmac.New(sha256.New, key)
|
||||
hash.Write(data)
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
75
aws/signer/internal/v4/host.go
Normal file
75
aws/signer/internal/v4/host.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SanitizeHostForHeader removes default port from host and updates request.Host
|
||||
func SanitizeHostForHeader(r *http.Request) {
|
||||
host := getHost(r)
|
||||
port := portOnly(host)
|
||||
if port != "" && isDefaultPort(r.URL.Scheme, port) {
|
||||
r.Host = stripPort(host)
|
||||
}
|
||||
}
|
||||
|
||||
// Returns host from request
|
||||
func getHost(r *http.Request) string {
|
||||
if r.Host != "" {
|
||||
return r.Host
|
||||
}
|
||||
|
||||
return r.URL.Host
|
||||
}
|
||||
|
||||
// Hostname returns u.Host, without any port number.
|
||||
//
|
||||
// If Host is an IPv6 literal with a port number, Hostname returns the
|
||||
// IPv6 literal without the square brackets. IPv6 literals may include
|
||||
// a zone identifier.
|
||||
//
|
||||
// Copied from the Go 1.8 standard library (net/url)
|
||||
func stripPort(hostport string) string {
|
||||
colon := strings.IndexByte(hostport, ':')
|
||||
if colon == -1 {
|
||||
return hostport
|
||||
}
|
||||
if i := strings.IndexByte(hostport, ']'); i != -1 {
|
||||
return strings.TrimPrefix(hostport[:i], "[")
|
||||
}
|
||||
return hostport[:colon]
|
||||
}
|
||||
|
||||
// Port returns the port part of u.Host, without the leading colon.
|
||||
// If u.Host doesn't contain a port, Port returns an empty string.
|
||||
//
|
||||
// Copied from the Go 1.8 standard library (net/url)
|
||||
func portOnly(hostport string) string {
|
||||
colon := strings.IndexByte(hostport, ':')
|
||||
if colon == -1 {
|
||||
return ""
|
||||
}
|
||||
if i := strings.Index(hostport, "]:"); i != -1 {
|
||||
return hostport[i+len("]:"):]
|
||||
}
|
||||
if strings.Contains(hostport, "]") {
|
||||
return ""
|
||||
}
|
||||
return hostport[colon+len(":"):]
|
||||
}
|
||||
|
||||
// Returns true if the specified URI is using the standard port
|
||||
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
|
||||
func isDefaultPort(scheme, port string) bool {
|
||||
if port == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
lowerCaseScheme := strings.ToLower(scheme)
|
||||
if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
13
aws/signer/internal/v4/scope.go
Normal file
13
aws/signer/internal/v4/scope.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package v4
|
||||
|
||||
import "strings"
|
||||
|
||||
// BuildCredentialScope builds the Signature Version 4 (SigV4) signing scope
|
||||
func BuildCredentialScope(signingTime SigningTime, region, service string) string {
|
||||
return strings.Join([]string{
|
||||
signingTime.ShortTimeFormat(),
|
||||
region,
|
||||
service,
|
||||
"aws4_request",
|
||||
}, "/")
|
||||
}
|
||||
36
aws/signer/internal/v4/time.go
Normal file
36
aws/signer/internal/v4/time.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package v4
|
||||
|
||||
import "time"
|
||||
|
||||
// SigningTime provides a wrapper around a time.Time which provides cached values for SigV4 signing.
|
||||
type SigningTime struct {
|
||||
time.Time
|
||||
timeFormat string
|
||||
shortTimeFormat string
|
||||
}
|
||||
|
||||
// NewSigningTime creates a new SigningTime given a time.Time
|
||||
func NewSigningTime(t time.Time) SigningTime {
|
||||
return SigningTime{
|
||||
Time: t,
|
||||
}
|
||||
}
|
||||
|
||||
// TimeFormat provides a time formatted in the X-Amz-Date format.
|
||||
func (m *SigningTime) TimeFormat() string {
|
||||
return m.format(&m.timeFormat, TimeFormat)
|
||||
}
|
||||
|
||||
// ShortTimeFormat provides a time formatted of 20060102.
|
||||
func (m *SigningTime) ShortTimeFormat() string {
|
||||
return m.format(&m.shortTimeFormat, ShortTimeFormat)
|
||||
}
|
||||
|
||||
func (m *SigningTime) format(target *string, format string) string {
|
||||
if len(*target) > 0 {
|
||||
return *target
|
||||
}
|
||||
v := m.Time.Format(format)
|
||||
*target = v
|
||||
return v
|
||||
}
|
||||
80
aws/signer/internal/v4/util.go
Normal file
80
aws/signer/internal/v4/util.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const doubleSpace = " "
|
||||
|
||||
// StripExcessSpaces will rewrite the passed in slice's string values to not
|
||||
// contain multiple side-by-side spaces.
|
||||
func StripExcessSpaces(str string) string {
|
||||
var j, k, l, m, spaces int
|
||||
// Trim trailing spaces
|
||||
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
|
||||
}
|
||||
|
||||
// Trim leading spaces
|
||||
for k = 0; k < j && str[k] == ' '; k++ {
|
||||
}
|
||||
str = str[k : j+1]
|
||||
|
||||
// Strip multiple spaces.
|
||||
j = strings.Index(str, doubleSpace)
|
||||
if j < 0 {
|
||||
return str
|
||||
}
|
||||
|
||||
buf := []byte(str)
|
||||
for k, m, l = j, j, len(buf); k < l; k++ {
|
||||
if buf[k] == ' ' {
|
||||
if spaces == 0 {
|
||||
// First space.
|
||||
buf[m] = buf[k]
|
||||
m++
|
||||
}
|
||||
spaces++
|
||||
} else {
|
||||
// End of multiple spaces.
|
||||
spaces = 0
|
||||
buf[m] = buf[k]
|
||||
m++
|
||||
}
|
||||
}
|
||||
|
||||
return string(buf[:m])
|
||||
}
|
||||
|
||||
// GetURIPath returns the escaped URI component from the provided URL.
|
||||
func GetURIPath(u *url.URL) string {
|
||||
var uriPath string
|
||||
|
||||
if len(u.Opaque) > 0 {
|
||||
const schemeSep, pathSep, queryStart = "//", "/", "?"
|
||||
|
||||
opaque := u.Opaque
|
||||
// Cut off the query string if present.
|
||||
if idx := strings.Index(opaque, queryStart); idx >= 0 {
|
||||
opaque = opaque[:idx]
|
||||
}
|
||||
|
||||
// Cutout the scheme separator if present.
|
||||
if strings.Index(opaque, schemeSep) == 0 {
|
||||
opaque = opaque[len(schemeSep):]
|
||||
}
|
||||
|
||||
// capture URI path starting with first path separator.
|
||||
if idx := strings.Index(opaque, pathSep); idx >= 0 {
|
||||
uriPath = opaque[idx:]
|
||||
}
|
||||
} else {
|
||||
uriPath = u.EscapedPath()
|
||||
}
|
||||
|
||||
if len(uriPath) == 0 {
|
||||
uriPath = "/"
|
||||
}
|
||||
|
||||
return uriPath
|
||||
}
|
||||
158
aws/signer/internal/v4/util_test.go
Normal file
158
aws/signer/internal/v4/util_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func lazyURLParse(v string) func() (*url.URL, error) {
|
||||
return func() (*url.URL, error) {
|
||||
return url.Parse(v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetURIPath(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
getURL func() (*url.URL, error)
|
||||
expect string
|
||||
}{
|
||||
// Cases
|
||||
"with scheme": {
|
||||
getURL: lazyURLParse("https://localhost:9000"),
|
||||
expect: "/",
|
||||
},
|
||||
"no port, with scheme": {
|
||||
getURL: lazyURLParse("https://localhost"),
|
||||
expect: "/",
|
||||
},
|
||||
"without scheme": {
|
||||
getURL: lazyURLParse("localhost:9000"),
|
||||
expect: "/",
|
||||
},
|
||||
"without scheme, with path": {
|
||||
getURL: lazyURLParse("localhost:9000/abc123"),
|
||||
expect: "/abc123",
|
||||
},
|
||||
"without scheme, with separator": {
|
||||
getURL: lazyURLParse("//localhost:9000"),
|
||||
expect: "/",
|
||||
},
|
||||
"no port, without scheme, with separator": {
|
||||
getURL: lazyURLParse("//localhost"),
|
||||
expect: "/",
|
||||
},
|
||||
"without scheme, with separator, with path": {
|
||||
getURL: lazyURLParse("//localhost:9000/abc123"),
|
||||
expect: "/abc123",
|
||||
},
|
||||
"no port, without scheme, with separator, with path": {
|
||||
getURL: lazyURLParse("//localhost/abc123"),
|
||||
expect: "/abc123",
|
||||
},
|
||||
"opaque with query string": {
|
||||
getURL: lazyURLParse("localhost:9000/abc123?efg=456"),
|
||||
expect: "/abc123",
|
||||
},
|
||||
"failing test": {
|
||||
getURL: func() (*url.URL, error) {
|
||||
endpoint := "https://service.region.amazonaws.com"
|
||||
req, _ := http.NewRequest("POST", endpoint, nil)
|
||||
u := req.URL
|
||||
|
||||
u.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
|
||||
|
||||
query := u.Query()
|
||||
query.Set("some-query-key", "value")
|
||||
u.RawQuery = query.Encode()
|
||||
|
||||
return u, nil
|
||||
},
|
||||
expect: "/bucket/key-._~,!@#$%^&*()",
|
||||
},
|
||||
}
|
||||
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
u, err := c.getURL()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get URL, %v", err)
|
||||
}
|
||||
|
||||
actual := GetURIPath(u)
|
||||
if e, a := c.expect, actual; e != a {
|
||||
t.Errorf("expect %v path, got %v", e, a)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripExcessHeaders(t *testing.T) {
|
||||
vals := []string{
|
||||
"",
|
||||
"123",
|
||||
"1 2 3",
|
||||
"1 2 3 ",
|
||||
" 1 2 3",
|
||||
"1 2 3",
|
||||
"1 23",
|
||||
"1 2 3",
|
||||
"1 2 ",
|
||||
" 1 2 ",
|
||||
"12 3",
|
||||
"12 3 1",
|
||||
"12 3 1",
|
||||
"12 3 1abc123",
|
||||
}
|
||||
|
||||
expected := []string{
|
||||
"",
|
||||
"123",
|
||||
"1 2 3",
|
||||
"1 2 3",
|
||||
"1 2 3",
|
||||
"1 2 3",
|
||||
"1 23",
|
||||
"1 2 3",
|
||||
"1 2",
|
||||
"1 2",
|
||||
"12 3",
|
||||
"12 3 1",
|
||||
"12 3 1",
|
||||
"12 3 1abc123",
|
||||
}
|
||||
|
||||
for i := 0; i < len(vals); i++ {
|
||||
r := StripExcessSpaces(vals[i])
|
||||
if e, a := expected[i], r; e != a {
|
||||
t.Errorf("%d, expect %v, got %v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var stripExcessSpaceCases = []string{
|
||||
`AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`,
|
||||
`123 321 123 321`,
|
||||
` 123 321 123 321 `,
|
||||
` 123 321 123 321 `,
|
||||
"123",
|
||||
"1 2 3",
|
||||
" 1 2 3",
|
||||
"1 2 3",
|
||||
"1 23",
|
||||
"1 2 3",
|
||||
"1 2 ",
|
||||
" 1 2 ",
|
||||
"12 3",
|
||||
"12 3 1",
|
||||
"12 3 1",
|
||||
"12 3 1abc123",
|
||||
}
|
||||
|
||||
func BenchmarkStripExcessSpaces(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, v := range stripExcessSpaceCases {
|
||||
StripExcessSpaces(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
139
aws/signer/v4/functional_test.go
Normal file
139
aws/signer/v4/functional_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package v4_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
"github.com/versity/versitygw/aws/internal/awstesting/unit"
|
||||
v4Internal "github.com/versity/versitygw/aws/signer/internal/v4"
|
||||
)
|
||||
|
||||
var standaloneSignCases = []struct {
|
||||
OrigURI string
|
||||
OrigQuery string
|
||||
Region, Service, SubDomain string
|
||||
ExpSig string
|
||||
EscapedURI string
|
||||
}{
|
||||
{
|
||||
OrigURI: `/logs-*/_search`,
|
||||
OrigQuery: `pretty=true`,
|
||||
Region: "us-west-2", Service: "es", SubDomain: "hostname-clusterkey",
|
||||
EscapedURI: `/logs-%2A/_search`,
|
||||
ExpSig: `AWS4-HMAC-SHA256 Credential=AKID/19700101/us-west-2/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=79d0760751907af16f64a537c1242416dacf51204a7dd5284492d15577973b91`,
|
||||
},
|
||||
}
|
||||
|
||||
func TestStandaloneSign_CustomURIEscape(t *testing.T) {
|
||||
var expectSig = `AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/es/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=6601e883cc6d23871fd6c2a394c5677ea2b8c82b04a6446786d64cd74f520967`
|
||||
|
||||
creds, err := unit.Config().Credentials.Retrieve(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
signer := v4.NewSigner(func(signer *v4.SignerOptions) {
|
||||
signer.DisableURIPathEscaping = true
|
||||
})
|
||||
|
||||
host := "https://subdomain.us-east-1.es.amazonaws.com"
|
||||
req, err := http.NewRequest("GET", host, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
req.URL.Path = `/log-*/_search`
|
||||
req.URL.Opaque = "//subdomain.us-east-1.es.amazonaws.com/log-%2A/_search"
|
||||
|
||||
err = signer.SignHTTP(context.Background(), creds, req, v4Internal.EmptyStringSHA256, "es", "us-east-1", time.Unix(0, 0))
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
actual := req.Header.Get("Authorization")
|
||||
if e, a := expectSig, actual; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStandaloneSign(t *testing.T) {
|
||||
creds, err := unit.Config().Credentials.Retrieve(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
signer := v4.NewSigner()
|
||||
|
||||
for _, c := range standaloneSignCases {
|
||||
host := fmt.Sprintf("https://%s.%s.%s.amazonaws.com",
|
||||
c.SubDomain, c.Region, c.Service)
|
||||
|
||||
req, err := http.NewRequest("GET", host, nil)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
// URL.EscapedPath() will be used by the signer to get the
|
||||
// escaped form of the request's URI path.
|
||||
req.URL.Path = c.OrigURI
|
||||
req.URL.RawQuery = c.OrigQuery
|
||||
|
||||
err = signer.SignHTTP(context.Background(), creds, req, v4Internal.EmptyStringSHA256, c.Service, c.Region, time.Unix(0, 0))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
actual := req.Header.Get("Authorization")
|
||||
if e, a := c.ExpSig, actual; e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
if e, a := c.OrigURI, req.URL.Path; e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
if e, a := c.EscapedURI, req.URL.EscapedPath(); e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStandaloneSign_RawPath(t *testing.T) {
|
||||
creds, err := unit.Config().Credentials.Retrieve(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
signer := v4.NewSigner()
|
||||
|
||||
for _, c := range standaloneSignCases {
|
||||
host := fmt.Sprintf("https://%s.%s.%s.amazonaws.com",
|
||||
c.SubDomain, c.Region, c.Service)
|
||||
|
||||
req, err := http.NewRequest("GET", host, nil)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
// URL.EscapedPath() will be used by the signer to get the
|
||||
// escaped form of the request's URI path.
|
||||
req.URL.Path = c.OrigURI
|
||||
req.URL.RawPath = c.EscapedURI
|
||||
req.URL.RawQuery = c.OrigQuery
|
||||
|
||||
err = signer.SignHTTP(context.Background(), creds, req, v4Internal.EmptyStringSHA256, c.Service, c.Region, time.Unix(0, 0))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but received %v", err)
|
||||
}
|
||||
|
||||
actual := req.Header.Get("Authorization")
|
||||
if e, a := c.ExpSig, actual; e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
if e, a := c.OrigURI, req.URL.Path; e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
if e, a := c.EscapedURI, req.URL.EscapedPath(); e != a {
|
||||
t.Errorf("expected %v, but recieved %v", e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
443
aws/signer/v4/middleware.go
Normal file
443
aws/signer/v4/middleware.go
Normal file
@@ -0,0 +1,443 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
|
||||
"github.com/aws/aws-sdk-go-v2/aws/middleware/private/metrics"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
internalauth "github.com/versity/versitygw/aws/internal/auth"
|
||||
"github.com/versity/versitygw/aws/internal/sdk"
|
||||
v4Internal "github.com/versity/versitygw/aws/signer/internal/v4"
|
||||
)
|
||||
|
||||
const computePayloadHashMiddlewareID = "ComputePayloadHash"
|
||||
|
||||
// HashComputationError indicates an error occurred while computing the signing hash
|
||||
type HashComputationError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// Error is the error message
|
||||
func (e *HashComputationError) Error() string {
|
||||
return fmt.Sprintf("failed to compute payload hash: %v", e.Err)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error if one is set
|
||||
func (e *HashComputationError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// SigningError indicates an error condition occurred while performing SigV4 signing
|
||||
type SigningError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *SigningError) Error() string {
|
||||
return fmt.Sprintf("failed to sign request: %v", e.Err)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error cause
|
||||
func (e *SigningError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// UseDynamicPayloadSigningMiddleware swaps the compute payload sha256 middleware with a resolver middleware that
|
||||
// switches between unsigned and signed payload based on TLS state for request.
|
||||
// This middleware should not be used for AWS APIs that do not support unsigned payload signing auth.
|
||||
// By default, SDK uses this middleware for known AWS APIs that support such TLS based auth selection .
|
||||
//
|
||||
// Usage example -
|
||||
// S3 PutObject API allows unsigned payload signing auth usage when TLS is enabled, and uses this middleware to
|
||||
// dynamically switch between unsigned and signed payload based on TLS state for request.
|
||||
func UseDynamicPayloadSigningMiddleware(stack *middleware.Stack) error {
|
||||
_, err := stack.Finalize.Swap(computePayloadHashMiddlewareID, &dynamicPayloadSigningMiddleware{})
|
||||
return err
|
||||
}
|
||||
|
||||
// dynamicPayloadSigningMiddleware dynamically resolves the middleware that computes and set payload sha256 middleware.
|
||||
type dynamicPayloadSigningMiddleware struct {
|
||||
}
|
||||
|
||||
// ID returns the resolver identifier
|
||||
func (m *dynamicPayloadSigningMiddleware) ID() string {
|
||||
return computePayloadHashMiddlewareID
|
||||
}
|
||||
|
||||
// HandleFinalize delegates SHA256 computation according to whether the request
|
||||
// is TLS-enabled.
|
||||
func (m *dynamicPayloadSigningMiddleware) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
req, ok := in.Request.(*smithyhttp.Request)
|
||||
if !ok {
|
||||
return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
|
||||
}
|
||||
|
||||
if req.IsHTTPS() {
|
||||
return (&UnsignedPayload{}).HandleFinalize(ctx, in, next)
|
||||
}
|
||||
return (&ComputePayloadSHA256{}).HandleFinalize(ctx, in, next)
|
||||
}
|
||||
|
||||
// UnsignedPayload sets the SigV4 request payload hash to unsigned.
|
||||
//
|
||||
// Will not set the Unsigned Payload magic SHA value, if a SHA has already been
|
||||
// stored in the context. (e.g. application pre-computed SHA256 before making
|
||||
// API call).
|
||||
//
|
||||
// This middleware does not check the X-Amz-Content-Sha256 header, if that
|
||||
// header is serialized a middleware must translate it into the context.
|
||||
type UnsignedPayload struct{}
|
||||
|
||||
// AddUnsignedPayloadMiddleware adds unsignedPayload to the operation
|
||||
// middleware stack
|
||||
func AddUnsignedPayloadMiddleware(stack *middleware.Stack) error {
|
||||
return stack.Finalize.Insert(&UnsignedPayload{}, "ResolveEndpointV2", middleware.After)
|
||||
}
|
||||
|
||||
// ID returns the unsignedPayload identifier
|
||||
func (m *UnsignedPayload) ID() string {
|
||||
return computePayloadHashMiddlewareID
|
||||
}
|
||||
|
||||
// HandleFinalize sets the payload hash magic value to the unsigned sentinel.
|
||||
func (m *UnsignedPayload) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
if GetPayloadHash(ctx) == "" {
|
||||
ctx = SetPayloadHash(ctx, v4Internal.UnsignedPayload)
|
||||
}
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
// ComputePayloadSHA256 computes SHA256 payload hash to sign.
|
||||
//
|
||||
// Will not set the Unsigned Payload magic SHA value, if a SHA has already been
|
||||
// stored in the context. (e.g. application pre-computed SHA256 before making
|
||||
// API call).
|
||||
//
|
||||
// This middleware does not check the X-Amz-Content-Sha256 header, if that
|
||||
// header is serialized a middleware must translate it into the context.
|
||||
type ComputePayloadSHA256 struct{}
|
||||
|
||||
// AddComputePayloadSHA256Middleware adds computePayloadSHA256 to the
|
||||
// operation middleware stack
|
||||
func AddComputePayloadSHA256Middleware(stack *middleware.Stack) error {
|
||||
return stack.Finalize.Insert(&ComputePayloadSHA256{}, "ResolveEndpointV2", middleware.After)
|
||||
}
|
||||
|
||||
// RemoveComputePayloadSHA256Middleware removes computePayloadSHA256 from the
|
||||
// operation middleware stack
|
||||
func RemoveComputePayloadSHA256Middleware(stack *middleware.Stack) error {
|
||||
_, err := stack.Finalize.Remove(computePayloadHashMiddlewareID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ID is the middleware name
|
||||
func (m *ComputePayloadSHA256) ID() string {
|
||||
return computePayloadHashMiddlewareID
|
||||
}
|
||||
|
||||
// HandleFinalize computes the payload hash for the request, storing it to the
|
||||
// context. This is a no-op if a caller has previously set that value.
|
||||
func (m *ComputePayloadSHA256) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
if GetPayloadHash(ctx) != "" {
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
req, ok := in.Request.(*smithyhttp.Request)
|
||||
if !ok {
|
||||
return out, metadata, &HashComputationError{
|
||||
Err: fmt.Errorf("unexpected request middleware type %T", in.Request),
|
||||
}
|
||||
}
|
||||
|
||||
hash := sha256.New()
|
||||
if stream := req.GetStream(); stream != nil {
|
||||
_, err = io.Copy(hash, stream)
|
||||
if err != nil {
|
||||
return out, metadata, &HashComputationError{
|
||||
Err: fmt.Errorf("failed to compute payload hash, %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
if err := req.RewindStream(); err != nil {
|
||||
return out, metadata, &HashComputationError{
|
||||
Err: fmt.Errorf("failed to seek body to start, %w", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx = SetPayloadHash(ctx, hex.EncodeToString(hash.Sum(nil)))
|
||||
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
// SwapComputePayloadSHA256ForUnsignedPayloadMiddleware replaces the
|
||||
// ComputePayloadSHA256 middleware with the UnsignedPayload middleware.
|
||||
//
|
||||
// Use this to disable computing the Payload SHA256 checksum and instead use
|
||||
// UNSIGNED-PAYLOAD for the SHA256 value.
|
||||
func SwapComputePayloadSHA256ForUnsignedPayloadMiddleware(stack *middleware.Stack) error {
|
||||
_, err := stack.Finalize.Swap(computePayloadHashMiddlewareID, &UnsignedPayload{})
|
||||
return err
|
||||
}
|
||||
|
||||
// ContentSHA256Header sets the X-Amz-Content-Sha256 header value to
|
||||
// the Payload hash stored in the context.
|
||||
type ContentSHA256Header struct{}
|
||||
|
||||
// AddContentSHA256HeaderMiddleware adds ContentSHA256Header to the
|
||||
// operation middleware stack
|
||||
func AddContentSHA256HeaderMiddleware(stack *middleware.Stack) error {
|
||||
return stack.Finalize.Insert(&ContentSHA256Header{}, computePayloadHashMiddlewareID, middleware.After)
|
||||
}
|
||||
|
||||
// RemoveContentSHA256HeaderMiddleware removes contentSHA256Header middleware
|
||||
// from the operation middleware stack
|
||||
func RemoveContentSHA256HeaderMiddleware(stack *middleware.Stack) error {
|
||||
_, err := stack.Finalize.Remove((*ContentSHA256Header)(nil).ID())
|
||||
return err
|
||||
}
|
||||
|
||||
// ID returns the ContentSHA256HeaderMiddleware identifier
|
||||
func (m *ContentSHA256Header) ID() string {
|
||||
return "SigV4ContentSHA256Header"
|
||||
}
|
||||
|
||||
// HandleFinalize sets the X-Amz-Content-Sha256 header value to the Payload hash
|
||||
// stored in the context.
|
||||
func (m *ContentSHA256Header) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
req, ok := in.Request.(*smithyhttp.Request)
|
||||
if !ok {
|
||||
return out, metadata, &HashComputationError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
|
||||
}
|
||||
|
||||
req.Header.Set(v4Internal.ContentSHAKey, GetPayloadHash(ctx))
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
// SignHTTPRequestMiddlewareOptions is the configuration options for
|
||||
// [SignHTTPRequestMiddleware].
|
||||
//
|
||||
// Deprecated: [SignHTTPRequestMiddleware] is deprecated.
|
||||
type SignHTTPRequestMiddlewareOptions struct {
|
||||
CredentialsProvider aws.CredentialsProvider
|
||||
Signer HTTPSigner
|
||||
LogSigning bool
|
||||
}
|
||||
|
||||
// SignHTTPRequestMiddleware is a `FinalizeMiddleware` implementation for SigV4
|
||||
// HTTP Signing.
|
||||
//
|
||||
// Deprecated: AWS service clients no longer use this middleware. Signing as an
|
||||
// SDK operation is now performed through an internal per-service middleware
|
||||
// which opaquely selects and uses the signer from the resolved auth scheme.
|
||||
type SignHTTPRequestMiddleware struct {
|
||||
credentialsProvider aws.CredentialsProvider
|
||||
signer HTTPSigner
|
||||
logSigning bool
|
||||
}
|
||||
|
||||
// NewSignHTTPRequestMiddleware constructs a [SignHTTPRequestMiddleware] using
|
||||
// the given [Signer] for signing requests.
|
||||
//
|
||||
// Deprecated: SignHTTPRequestMiddleware is deprecated.
|
||||
func NewSignHTTPRequestMiddleware(options SignHTTPRequestMiddlewareOptions) *SignHTTPRequestMiddleware {
|
||||
return &SignHTTPRequestMiddleware{
|
||||
credentialsProvider: options.CredentialsProvider,
|
||||
signer: options.Signer,
|
||||
logSigning: options.LogSigning,
|
||||
}
|
||||
}
|
||||
|
||||
// ID is the SignHTTPRequestMiddleware identifier.
|
||||
//
|
||||
// Deprecated: SignHTTPRequestMiddleware is deprecated.
|
||||
func (s *SignHTTPRequestMiddleware) ID() string {
|
||||
return "Signing"
|
||||
}
|
||||
|
||||
// HandleFinalize will take the provided input and sign the request using the
|
||||
// SigV4 authentication scheme.
|
||||
//
|
||||
// Deprecated: SignHTTPRequestMiddleware is deprecated.
|
||||
func (s *SignHTTPRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
if !haveCredentialProvider(s.credentialsProvider) {
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
req, ok := in.Request.(*smithyhttp.Request)
|
||||
if !ok {
|
||||
return out, metadata, &SigningError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
|
||||
}
|
||||
|
||||
signingName, signingRegion := awsmiddleware.GetSigningName(ctx), awsmiddleware.GetSigningRegion(ctx)
|
||||
payloadHash := GetPayloadHash(ctx)
|
||||
if len(payloadHash) == 0 {
|
||||
return out, metadata, &SigningError{Err: fmt.Errorf("computed payload hash missing from context")}
|
||||
}
|
||||
|
||||
mctx := metrics.Context(ctx)
|
||||
|
||||
if mctx != nil {
|
||||
if attempt, err := mctx.Data().LatestAttempt(); err == nil {
|
||||
attempt.CredentialFetchStartTime = sdk.NowTime()
|
||||
}
|
||||
}
|
||||
|
||||
credentials, err := s.credentialsProvider.Retrieve(ctx)
|
||||
|
||||
if mctx != nil {
|
||||
if attempt, err := mctx.Data().LatestAttempt(); err == nil {
|
||||
attempt.CredentialFetchEndTime = sdk.NowTime()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return out, metadata, &SigningError{Err: fmt.Errorf("failed to retrieve credentials: %w", err)}
|
||||
}
|
||||
|
||||
signerOptions := []func(o *SignerOptions){
|
||||
func(o *SignerOptions) {
|
||||
o.Logger = middleware.GetLogger(ctx)
|
||||
o.LogSigning = s.logSigning
|
||||
},
|
||||
}
|
||||
|
||||
// existing DisableURIPathEscaping is equivalent in purpose
|
||||
// to authentication scheme property DisableDoubleEncoding
|
||||
disableDoubleEncoding, overridden := internalauth.GetDisableDoubleEncoding(ctx)
|
||||
if overridden {
|
||||
signerOptions = append(signerOptions, func(o *SignerOptions) {
|
||||
o.DisableURIPathEscaping = disableDoubleEncoding
|
||||
})
|
||||
}
|
||||
|
||||
if mctx != nil {
|
||||
if attempt, err := mctx.Data().LatestAttempt(); err == nil {
|
||||
attempt.SignStartTime = sdk.NowTime()
|
||||
}
|
||||
}
|
||||
|
||||
err = s.signer.SignHTTP(ctx, credentials, req.Request, payloadHash, signingName, signingRegion, sdk.NowTime(), signerOptions...)
|
||||
|
||||
if mctx != nil {
|
||||
if attempt, err := mctx.Data().LatestAttempt(); err == nil {
|
||||
attempt.SignEndTime = sdk.NowTime()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return out, metadata, &SigningError{Err: fmt.Errorf("failed to sign http request, %w", err)}
|
||||
}
|
||||
|
||||
ctx = awsmiddleware.SetSigningCredentials(ctx, credentials)
|
||||
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
// StreamingEventsPayload signs input event stream messages.
|
||||
type StreamingEventsPayload struct{}
|
||||
|
||||
// AddStreamingEventsPayload adds the streamingEventsPayload middleware to the stack.
|
||||
func AddStreamingEventsPayload(stack *middleware.Stack) error {
|
||||
return stack.Finalize.Add(&StreamingEventsPayload{}, middleware.Before)
|
||||
}
|
||||
|
||||
// ID identifies the middleware.
|
||||
func (s *StreamingEventsPayload) ID() string {
|
||||
return computePayloadHashMiddlewareID
|
||||
}
|
||||
|
||||
// HandleFinalize marks the input stream to be signed with SigV4.
|
||||
func (s *StreamingEventsPayload) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
contentSHA := GetPayloadHash(ctx)
|
||||
if len(contentSHA) == 0 {
|
||||
contentSHA = v4Internal.StreamingEventsPayload
|
||||
}
|
||||
|
||||
ctx = SetPayloadHash(ctx, contentSHA)
|
||||
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
// GetSignedRequestSignature attempts to extract the signature of the request.
|
||||
// Returning an error if the request is unsigned, or unable to extract the
|
||||
// signature.
|
||||
func GetSignedRequestSignature(r *http.Request) ([]byte, error) {
|
||||
const authHeaderSignatureElem = "Signature="
|
||||
|
||||
if auth := r.Header.Get(authorizationHeader); len(auth) != 0 {
|
||||
ps := strings.Split(auth, ", ")
|
||||
for _, p := range ps {
|
||||
if idx := strings.Index(p, authHeaderSignatureElem); idx >= 0 {
|
||||
sig := p[len(authHeaderSignatureElem):]
|
||||
if len(sig) == 0 {
|
||||
return nil, fmt.Errorf("invalid request signature authorization header")
|
||||
}
|
||||
return hex.DecodeString(sig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sig := r.URL.Query().Get("X-Amz-Signature"); len(sig) != 0 {
|
||||
return hex.DecodeString(sig)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("request not signed")
|
||||
}
|
||||
|
||||
func haveCredentialProvider(p aws.CredentialsProvider) bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return !aws.IsCredentialsProvider(p, (*aws.AnonymousCredentials)(nil))
|
||||
}
|
||||
|
||||
type payloadHashKey struct{}
|
||||
|
||||
// GetPayloadHash retrieves the payload hash to use for signing
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func GetPayloadHash(ctx context.Context) (v string) {
|
||||
v, _ = middleware.GetStackValue(ctx, payloadHashKey{}).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// SetPayloadHash sets the payload hash to be used for signing the request
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func SetPayloadHash(ctx context.Context, hash string) context.Context {
|
||||
return middleware.WithStackValue(ctx, payloadHashKey{}, hash)
|
||||
}
|
||||
415
aws/signer/v4/middleware_test.go
Normal file
415
aws/signer/v4/middleware_test.go
Normal file
@@ -0,0 +1,415 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/versity/versitygw/aws/internal/awstesting/unit"
|
||||
)
|
||||
|
||||
func TestComputePayloadHashMiddleware(t *testing.T) {
|
||||
cases := []struct {
|
||||
content io.Reader
|
||||
expectedHash string
|
||||
expectedErr interface{}
|
||||
}{
|
||||
0: {
|
||||
content: func() io.Reader {
|
||||
br := bytes.NewReader([]byte("some content"))
|
||||
return br
|
||||
}(),
|
||||
expectedHash: "290f493c44f5d63d06b374d0a5abd292fae38b92cab2fae5efefe1b0e9347f56",
|
||||
},
|
||||
1: {
|
||||
content: func() io.Reader {
|
||||
return &nonSeeker{}
|
||||
}(),
|
||||
expectedErr: &HashComputationError{},
|
||||
},
|
||||
2: {
|
||||
content: func() io.Reader {
|
||||
return &semiSeekable{}
|
||||
}(),
|
||||
expectedErr: &HashComputationError{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range cases {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
c := &ComputePayloadSHA256{}
|
||||
|
||||
next := middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
|
||||
value := GetPayloadHash(ctx)
|
||||
if len(value) == 0 {
|
||||
t.Fatalf("expected payload hash value to be on context")
|
||||
}
|
||||
if e, a := tt.expectedHash, value; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
|
||||
return out, metadata, err
|
||||
})
|
||||
|
||||
stream, err := smithyhttp.NewStackRequest().(*smithyhttp.Request).SetStream(tt.content)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
_, _, err = c.HandleFinalize(context.Background(), middleware.FinalizeInput{Request: stream}, next)
|
||||
if err != nil && tt.expectedErr == nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
} else if err != nil && tt.expectedErr != nil {
|
||||
e, a := tt.expectedErr, err
|
||||
if !errors.As(a, &e) {
|
||||
t.Errorf("expected error type %T, got %T", e, a)
|
||||
}
|
||||
} else if err == nil && tt.expectedErr != nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type httpSignerFunc func(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(*SignerOptions)) error
|
||||
|
||||
func (f httpSignerFunc) SignHTTP(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(*SignerOptions)) error {
|
||||
return f(ctx, credentials, r, payloadHash, service, region, signingTime, optFns...)
|
||||
}
|
||||
|
||||
func TestSignHTTPRequestMiddleware(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
creds aws.CredentialsProvider
|
||||
hash string
|
||||
logSigning bool
|
||||
expectedErr interface{}
|
||||
}{
|
||||
"success": {
|
||||
creds: unit.StubCredentialsProvider{},
|
||||
hash: "0123456789abcdef",
|
||||
},
|
||||
"error": {
|
||||
creds: unit.StubCredentialsProvider{},
|
||||
hash: "",
|
||||
expectedErr: &SigningError{},
|
||||
},
|
||||
"anonymous creds": {
|
||||
creds: aws.AnonymousCredentials{},
|
||||
},
|
||||
"nil creds": {
|
||||
creds: nil,
|
||||
},
|
||||
"with log signing": {
|
||||
creds: unit.StubCredentialsProvider{},
|
||||
hash: "0123456789abcdef",
|
||||
logSigning: true,
|
||||
},
|
||||
}
|
||||
|
||||
const (
|
||||
signingName = "serviceId"
|
||||
signingRegion = "regionName"
|
||||
)
|
||||
|
||||
for name, tt := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
c := &SignHTTPRequestMiddleware{
|
||||
credentialsProvider: tt.creds,
|
||||
signer: httpSignerFunc(
|
||||
func(ctx context.Context,
|
||||
credentials aws.Credentials, r *http.Request, payloadHash string,
|
||||
service string, region string, signingTime time.Time,
|
||||
optFns ...func(*SignerOptions),
|
||||
) error {
|
||||
var options SignerOptions
|
||||
for _, fn := range optFns {
|
||||
fn(&options)
|
||||
}
|
||||
if options.Logger == nil {
|
||||
t.Errorf("expect logger, got none")
|
||||
}
|
||||
if options.LogSigning {
|
||||
options.Logger.Logf(logging.Debug, t.Name())
|
||||
}
|
||||
|
||||
expectCreds, _ := unit.StubCredentialsProvider{}.Retrieve(context.Background())
|
||||
if e, a := expectCreds, credentials; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := tt.hash, payloadHash; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := signingName, service; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := signingRegion, region; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
logSigning: tt.logSigning,
|
||||
}
|
||||
|
||||
next := middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
|
||||
return out, metadata, err
|
||||
})
|
||||
|
||||
ctx := awsmiddleware.SetSigningRegion(
|
||||
awsmiddleware.SetSigningName(context.Background(), signingName),
|
||||
signingRegion)
|
||||
|
||||
var loggerBuf bytes.Buffer
|
||||
logger := logging.NewStandardLogger(&loggerBuf)
|
||||
ctx = middleware.SetLogger(ctx, logger)
|
||||
|
||||
if len(tt.hash) != 0 {
|
||||
ctx = SetPayloadHash(ctx, tt.hash)
|
||||
}
|
||||
|
||||
_, _, err := c.HandleFinalize(ctx, middleware.FinalizeInput{
|
||||
Request: &smithyhttp.Request{Request: &http.Request{}},
|
||||
}, next)
|
||||
if err != nil && tt.expectedErr == nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
} else if err != nil && tt.expectedErr != nil {
|
||||
e, a := tt.expectedErr, err
|
||||
if !errors.As(a, &e) {
|
||||
t.Errorf("expected error type %T, got %T", e, a)
|
||||
}
|
||||
} else if err == nil && tt.expectedErr != nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
|
||||
if tt.logSigning {
|
||||
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
|
||||
t.Errorf("expect %v logged in %v", e, a)
|
||||
}
|
||||
} else {
|
||||
if loggerBuf.Len() != 0 {
|
||||
t.Errorf("expect no log, got %v", loggerBuf.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwapComputePayloadSHA256ForUnsignedPayloadMiddleware(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
InitStep func(*middleware.Stack) error
|
||||
Mutator func(*middleware.Stack) error
|
||||
ExpectErr string
|
||||
ExpectIDs []string
|
||||
}{
|
||||
"swap in place": {
|
||||
InitStep: func(s *middleware.Stack) (err error) {
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("before", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = AddComputePayloadSHA256Middleware(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("after", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Mutator: SwapComputePayloadSHA256ForUnsignedPayloadMiddleware,
|
||||
ExpectIDs: []string{
|
||||
"ResolveEndpointV2",
|
||||
computePayloadHashMiddlewareID, // should snap to after resolve endpoint
|
||||
"before",
|
||||
"after",
|
||||
},
|
||||
},
|
||||
|
||||
"already unsigned payload exists": {
|
||||
InitStep: func(s *middleware.Stack) (err error) {
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("before", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = AddUnsignedPayloadMiddleware(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("after", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Mutator: SwapComputePayloadSHA256ForUnsignedPayloadMiddleware,
|
||||
ExpectIDs: []string{
|
||||
"ResolveEndpointV2",
|
||||
computePayloadHashMiddlewareID,
|
||||
"before",
|
||||
"after",
|
||||
},
|
||||
},
|
||||
|
||||
"no compute payload": {
|
||||
InitStep: func(s *middleware.Stack) (err error) {
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("before", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.Finalize.Add(middleware.FinalizeMiddlewareFunc("after", nil), middleware.After)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Mutator: SwapComputePayloadSHA256ForUnsignedPayloadMiddleware,
|
||||
ExpectErr: "not found, " + computePayloadHashMiddlewareID,
|
||||
},
|
||||
}
|
||||
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
stack := middleware.NewStack(t.Name(), smithyhttp.NewStackRequest)
|
||||
stack.Finalize.Add(&nopResolveEndpoint{}, middleware.After)
|
||||
if err := c.InitStep(stack); err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
err := c.Mutator(stack)
|
||||
if len(c.ExpectErr) != 0 {
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, got none")
|
||||
}
|
||||
if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
|
||||
t.Fatalf("expect error to contain %v, got %v", e, a)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(c.ExpectIDs, stack.Finalize.List()); len(diff) != 0 {
|
||||
t.Errorf("expect match\n%v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseDynamicPayloadSigningMiddleware(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
content io.Reader
|
||||
url string
|
||||
expectedHash string
|
||||
expectedErr interface{}
|
||||
}{
|
||||
"TLS disabled": {
|
||||
content: func() io.Reader {
|
||||
br := bytes.NewReader([]byte("some content"))
|
||||
return br
|
||||
}(),
|
||||
url: "http://localhost.com/",
|
||||
expectedHash: "290f493c44f5d63d06b374d0a5abd292fae38b92cab2fae5efefe1b0e9347f56",
|
||||
},
|
||||
"TLS enabled": {
|
||||
content: func() io.Reader {
|
||||
br := bytes.NewReader([]byte("some content"))
|
||||
return br
|
||||
}(),
|
||||
url: "https://localhost.com/",
|
||||
expectedHash: "UNSIGNED-PAYLOAD",
|
||||
},
|
||||
}
|
||||
|
||||
for name, tt := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
c := &dynamicPayloadSigningMiddleware{}
|
||||
|
||||
next := middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
|
||||
value := GetPayloadHash(ctx)
|
||||
if len(value) == 0 {
|
||||
t.Fatalf("expected payload hash value to be on context")
|
||||
}
|
||||
if e, a := tt.expectedHash, value; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
|
||||
return out, metadata, err
|
||||
})
|
||||
|
||||
req := smithyhttp.NewStackRequest().(*smithyhttp.Request)
|
||||
req.URL, _ = url.Parse(tt.url)
|
||||
stream, err := req.SetStream(tt.content)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
_, _, err = c.HandleFinalize(context.Background(), middleware.FinalizeInput{Request: stream}, next)
|
||||
if err != nil && tt.expectedErr == nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
} else if err != nil && tt.expectedErr != nil {
|
||||
e, a := tt.expectedErr, err
|
||||
if !errors.As(a, &e) {
|
||||
t.Errorf("expected error type %T, got %T", e, a)
|
||||
}
|
||||
} else if err == nil && tt.expectedErr != nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type nonSeeker struct{}
|
||||
|
||||
func (nonSeeker) Read(p []byte) (n int, err error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
type semiSeekable struct {
|
||||
hasSeeked bool
|
||||
}
|
||||
|
||||
func (s *semiSeekable) Seek(offset int64, whence int) (int64, error) {
|
||||
if !s.hasSeeked {
|
||||
s.hasSeeked = true
|
||||
return 0, nil
|
||||
}
|
||||
return 0, fmt.Errorf("io seek error")
|
||||
}
|
||||
|
||||
func (*semiSeekable) Read(p []byte) (n int, err error) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
type nopResolveEndpoint struct{}
|
||||
|
||||
func (*nopResolveEndpoint) ID() string { return "ResolveEndpointV2" }
|
||||
|
||||
func (*nopResolveEndpoint) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
return out, metadata, err
|
||||
}
|
||||
|
||||
var (
|
||||
_ middleware.FinalizeMiddleware = &UnsignedPayload{}
|
||||
_ middleware.FinalizeMiddleware = &ComputePayloadSHA256{}
|
||||
_ middleware.FinalizeMiddleware = &ContentSHA256Header{}
|
||||
_ middleware.FinalizeMiddleware = &SignHTTPRequestMiddleware{}
|
||||
)
|
||||
127
aws/signer/v4/presign_middleware.go
Normal file
127
aws/signer/v4/presign_middleware.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyHTTP "github.com/aws/smithy-go/transport/http"
|
||||
"github.com/versity/versitygw/aws/internal/sdk"
|
||||
)
|
||||
|
||||
// HTTPPresigner is an interface to a SigV4 signer that can sign create a
|
||||
// presigned URL for a HTTP requests.
|
||||
type HTTPPresigner interface {
|
||||
PresignHTTP(
|
||||
ctx context.Context, credentials aws.Credentials, r *http.Request,
|
||||
payloadHash string, service string, region string, signingTime time.Time,
|
||||
optFns ...func(*SignerOptions),
|
||||
) (url string, signedHeader http.Header, err error)
|
||||
}
|
||||
|
||||
// PresignedHTTPRequest provides the URL and signed headers that are included
|
||||
// in the presigned URL.
|
||||
type PresignedHTTPRequest struct {
|
||||
URL string
|
||||
Method string
|
||||
SignedHeader http.Header
|
||||
}
|
||||
|
||||
// PresignHTTPRequestMiddlewareOptions is the options for the PresignHTTPRequestMiddleware middleware.
|
||||
type PresignHTTPRequestMiddlewareOptions struct {
|
||||
CredentialsProvider aws.CredentialsProvider
|
||||
Presigner HTTPPresigner
|
||||
LogSigning bool
|
||||
}
|
||||
|
||||
// PresignHTTPRequestMiddleware provides the Finalize middleware for creating a
|
||||
// presigned URL for an HTTP request.
|
||||
//
|
||||
// Will short circuit the middleware stack and not forward onto the next
|
||||
// Finalize handler.
|
||||
type PresignHTTPRequestMiddleware struct {
|
||||
credentialsProvider aws.CredentialsProvider
|
||||
presigner HTTPPresigner
|
||||
logSigning bool
|
||||
}
|
||||
|
||||
// NewPresignHTTPRequestMiddleware returns a new PresignHTTPRequestMiddleware
|
||||
// initialized with the presigner.
|
||||
func NewPresignHTTPRequestMiddleware(options PresignHTTPRequestMiddlewareOptions) *PresignHTTPRequestMiddleware {
|
||||
return &PresignHTTPRequestMiddleware{
|
||||
credentialsProvider: options.CredentialsProvider,
|
||||
presigner: options.Presigner,
|
||||
logSigning: options.LogSigning,
|
||||
}
|
||||
}
|
||||
|
||||
// ID provides the middleware ID.
|
||||
func (*PresignHTTPRequestMiddleware) ID() string { return "PresignHTTPRequest" }
|
||||
|
||||
// HandleFinalize will take the provided input and create a presigned url for
|
||||
// the http request using the SigV4 presign authentication scheme.
|
||||
//
|
||||
// Since the signed request is not a valid HTTP request
|
||||
func (s *PresignHTTPRequestMiddleware) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
req, ok := in.Request.(*smithyHTTP.Request)
|
||||
if !ok {
|
||||
return out, metadata, &SigningError{
|
||||
Err: fmt.Errorf("unexpected request middleware type %T", in.Request),
|
||||
}
|
||||
}
|
||||
|
||||
httpReq := req.Build(ctx)
|
||||
if !haveCredentialProvider(s.credentialsProvider) {
|
||||
out.Result = &PresignedHTTPRequest{
|
||||
URL: httpReq.URL.String(),
|
||||
Method: httpReq.Method,
|
||||
SignedHeader: http.Header{},
|
||||
}
|
||||
|
||||
return out, metadata, nil
|
||||
}
|
||||
|
||||
signingName := awsmiddleware.GetSigningName(ctx)
|
||||
signingRegion := awsmiddleware.GetSigningRegion(ctx)
|
||||
payloadHash := GetPayloadHash(ctx)
|
||||
if len(payloadHash) == 0 {
|
||||
return out, metadata, &SigningError{
|
||||
Err: fmt.Errorf("computed payload hash missing from context"),
|
||||
}
|
||||
}
|
||||
|
||||
credentials, err := s.credentialsProvider.Retrieve(ctx)
|
||||
if err != nil {
|
||||
return out, metadata, &SigningError{
|
||||
Err: fmt.Errorf("failed to retrieve credentials: %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
u, h, err := s.presigner.PresignHTTP(ctx, credentials,
|
||||
httpReq, payloadHash, signingName, signingRegion, sdk.NowTime(),
|
||||
func(o *SignerOptions) {
|
||||
o.Logger = middleware.GetLogger(ctx)
|
||||
o.LogSigning = s.logSigning
|
||||
})
|
||||
if err != nil {
|
||||
return out, metadata, &SigningError{
|
||||
Err: fmt.Errorf("failed to sign http request, %w", err),
|
||||
}
|
||||
}
|
||||
|
||||
out.Result = &PresignedHTTPRequest{
|
||||
URL: u,
|
||||
Method: httpReq.Method,
|
||||
SignedHeader: h,
|
||||
}
|
||||
|
||||
return out, metadata, nil
|
||||
}
|
||||
224
aws/signer/v4/presign_middleware_test.go
Normal file
224
aws/signer/v4/presign_middleware_test.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/versity/versitygw/aws/internal/awstesting/unit"
|
||||
)
|
||||
|
||||
type httpPresignerFunc func(
|
||||
ctx context.Context, credentials aws.Credentials, r *http.Request,
|
||||
payloadHash string, service string, region string, signingTime time.Time,
|
||||
optFns ...func(*SignerOptions),
|
||||
) (url string, signedHeader http.Header, err error)
|
||||
|
||||
func (f httpPresignerFunc) PresignHTTP(
|
||||
ctx context.Context, credentials aws.Credentials, r *http.Request,
|
||||
payloadHash string, service string, region string, signingTime time.Time,
|
||||
optFns ...func(*SignerOptions),
|
||||
) (
|
||||
url string, signedHeader http.Header, err error,
|
||||
) {
|
||||
return f(ctx, credentials, r, payloadHash, service, region, signingTime, optFns...)
|
||||
}
|
||||
|
||||
func TestPresignHTTPRequestMiddleware(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
Request *http.Request
|
||||
Creds aws.CredentialsProvider
|
||||
PayloadHash string
|
||||
LogSigning bool
|
||||
ExpectResult *PresignedHTTPRequest
|
||||
ExpectErr string
|
||||
}{
|
||||
"success": {
|
||||
Request: &http.Request{
|
||||
URL: func() *url.URL {
|
||||
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||||
return u
|
||||
}(),
|
||||
Header: http.Header{},
|
||||
},
|
||||
Creds: unit.StubCredentialsProvider{},
|
||||
PayloadHash: "0123456789abcdef",
|
||||
ExpectResult: &PresignedHTTPRequest{
|
||||
URL: "https://example.aws/path?query=foo",
|
||||
SignedHeader: http.Header{},
|
||||
},
|
||||
},
|
||||
"error": {
|
||||
Request: func() *http.Request {
|
||||
return &http.Request{}
|
||||
}(),
|
||||
Creds: unit.StubCredentialsProvider{},
|
||||
PayloadHash: "",
|
||||
ExpectErr: "failed to sign request",
|
||||
},
|
||||
"anonymous creds": {
|
||||
Request: &http.Request{
|
||||
URL: func() *url.URL {
|
||||
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||||
return u
|
||||
}(),
|
||||
Header: http.Header{},
|
||||
},
|
||||
Creds: unit.StubCredentialsProvider{},
|
||||
PayloadHash: "",
|
||||
ExpectErr: "failed to sign request",
|
||||
ExpectResult: &PresignedHTTPRequest{
|
||||
URL: "https://example.aws/path?query=foo",
|
||||
SignedHeader: http.Header{},
|
||||
},
|
||||
},
|
||||
"nil creds": {
|
||||
Request: &http.Request{
|
||||
URL: func() *url.URL {
|
||||
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||||
return u
|
||||
}(),
|
||||
Header: http.Header{},
|
||||
},
|
||||
Creds: nil,
|
||||
ExpectResult: &PresignedHTTPRequest{
|
||||
URL: "https://example.aws/path?query=foo",
|
||||
SignedHeader: http.Header{},
|
||||
},
|
||||
},
|
||||
"with log signing": {
|
||||
Request: &http.Request{
|
||||
URL: func() *url.URL {
|
||||
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||||
return u
|
||||
}(),
|
||||
Header: http.Header{},
|
||||
},
|
||||
Creds: unit.StubCredentialsProvider{},
|
||||
PayloadHash: "0123456789abcdef",
|
||||
ExpectResult: &PresignedHTTPRequest{
|
||||
URL: "https://example.aws/path?query=foo",
|
||||
SignedHeader: http.Header{},
|
||||
},
|
||||
|
||||
LogSigning: true,
|
||||
},
|
||||
}
|
||||
|
||||
const (
|
||||
signingName = "serviceId"
|
||||
signingRegion = "regionName"
|
||||
)
|
||||
|
||||
for name, c := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
m := &PresignHTTPRequestMiddleware{
|
||||
credentialsProvider: c.Creds,
|
||||
|
||||
presigner: httpPresignerFunc(func(
|
||||
ctx context.Context, credentials aws.Credentials, r *http.Request,
|
||||
payloadHash string, service string, region string, signingTime time.Time,
|
||||
optFns ...func(*SignerOptions),
|
||||
) (url string, signedHeader http.Header, err error) {
|
||||
var options SignerOptions
|
||||
for _, fn := range optFns {
|
||||
fn(&options)
|
||||
}
|
||||
if options.Logger == nil {
|
||||
t.Errorf("expect logger, got none")
|
||||
}
|
||||
if options.LogSigning {
|
||||
options.Logger.Logf(logging.Debug, t.Name())
|
||||
}
|
||||
|
||||
if !haveCredentialProvider(c.Creds) {
|
||||
t.Errorf("expect presigner not to be called for not credentials provider")
|
||||
}
|
||||
|
||||
expectCreds, _ := unit.StubCredentialsProvider{}.Retrieve(context.Background())
|
||||
if e, a := expectCreds, credentials; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := c.PayloadHash, payloadHash; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := signingName, service; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
if e, a := signingRegion, region; e != a {
|
||||
t.Errorf("expected %v, got %v", e, a)
|
||||
}
|
||||
|
||||
return c.ExpectResult.URL, c.ExpectResult.SignedHeader, nil
|
||||
}),
|
||||
logSigning: c.LogSigning,
|
||||
}
|
||||
|
||||
next := middleware.FinalizeHandlerFunc(
|
||||
func(ctx context.Context, in middleware.FinalizeInput) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
t.Errorf("expect next handler not to be called")
|
||||
return out, metadata, err
|
||||
})
|
||||
|
||||
ctx := awsmiddleware.SetSigningRegion(
|
||||
awsmiddleware.SetSigningName(context.Background(), signingName),
|
||||
signingRegion)
|
||||
|
||||
var loggerBuf bytes.Buffer
|
||||
logger := logging.NewStandardLogger(&loggerBuf)
|
||||
ctx = middleware.SetLogger(ctx, logger)
|
||||
|
||||
if len(c.PayloadHash) != 0 {
|
||||
ctx = SetPayloadHash(ctx, c.PayloadHash)
|
||||
}
|
||||
|
||||
result, _, err := m.HandleFinalize(ctx, middleware.FinalizeInput{
|
||||
Request: &smithyhttp.Request{
|
||||
Request: c.Request,
|
||||
},
|
||||
}, next)
|
||||
if len(c.ExpectErr) != 0 {
|
||||
if err == nil {
|
||||
t.Fatalf("expect error, got none")
|
||||
}
|
||||
if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
|
||||
t.Fatalf("expect error to contain %v, got %v", e, a)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(c.ExpectResult, result.Result); len(diff) != 0 {
|
||||
t.Errorf("expect result match\n%v", diff)
|
||||
}
|
||||
|
||||
if c.LogSigning {
|
||||
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
|
||||
t.Errorf("expect %v logged in %v", e, a)
|
||||
}
|
||||
} else {
|
||||
if loggerBuf.Len() != 0 {
|
||||
t.Errorf("expect no log, got %v", loggerBuf.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
_ middleware.FinalizeMiddleware = &PresignHTTPRequestMiddleware{}
|
||||
)
|
||||
87
aws/signer/v4/stream.go
Normal file
87
aws/signer/v4/stream.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4Internal "github.com/versity/versitygw/aws/signer/internal/v4"
|
||||
)
|
||||
|
||||
// EventStreamSigner is an AWS EventStream protocol signer.
|
||||
type EventStreamSigner interface {
|
||||
GetSignature(ctx context.Context, headers, payload []byte, signingTime time.Time, optFns ...func(*StreamSignerOptions)) ([]byte, error)
|
||||
}
|
||||
|
||||
// StreamSignerOptions is the configuration options for StreamSigner.
|
||||
type StreamSignerOptions struct{}
|
||||
|
||||
// StreamSigner implements Signature Version 4 (SigV4) signing of event stream encoded payloads.
|
||||
type StreamSigner struct {
|
||||
options StreamSignerOptions
|
||||
|
||||
credentials aws.Credentials
|
||||
service string
|
||||
region string
|
||||
|
||||
prevSignature []byte
|
||||
|
||||
signingKeyDeriver *v4Internal.SigningKeyDeriver
|
||||
}
|
||||
|
||||
// NewStreamSigner returns a new AWS EventStream protocol signer.
|
||||
func NewStreamSigner(credentials aws.Credentials, service, region string, seedSignature []byte, optFns ...func(*StreamSignerOptions)) *StreamSigner {
|
||||
o := StreamSignerOptions{}
|
||||
|
||||
for _, fn := range optFns {
|
||||
fn(&o)
|
||||
}
|
||||
|
||||
return &StreamSigner{
|
||||
options: o,
|
||||
credentials: credentials,
|
||||
service: service,
|
||||
region: region,
|
||||
signingKeyDeriver: v4Internal.NewSigningKeyDeriver(),
|
||||
prevSignature: seedSignature,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSignature signs the provided header and payload bytes.
|
||||
func (s *StreamSigner) GetSignature(ctx context.Context, headers, payload []byte, signingTime time.Time, optFns ...func(*StreamSignerOptions)) ([]byte, error) {
|
||||
options := s.options
|
||||
|
||||
for _, fn := range optFns {
|
||||
fn(&options)
|
||||
}
|
||||
|
||||
prevSignature := s.prevSignature
|
||||
|
||||
st := v4Internal.NewSigningTime(signingTime)
|
||||
|
||||
sigKey := s.signingKeyDeriver.DeriveKey(s.credentials, s.service, s.region, st)
|
||||
|
||||
scope := v4Internal.BuildCredentialScope(st, s.region, s.service)
|
||||
|
||||
stringToSign := s.buildEventStreamStringToSign(headers, payload, prevSignature, scope, &st)
|
||||
|
||||
signature := v4Internal.HMACSHA256(sigKey, []byte(stringToSign))
|
||||
s.prevSignature = signature
|
||||
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
func (s *StreamSigner) buildEventStreamStringToSign(headers, payload, previousSignature []byte, credentialScope string, signingTime *v4Internal.SigningTime) string {
|
||||
hash := sha256.New()
|
||||
return strings.Join([]string{
|
||||
"AWS4-HMAC-SHA256-PAYLOAD",
|
||||
signingTime.TimeFormat(),
|
||||
credentialScope,
|
||||
hex.EncodeToString(previousSignature),
|
||||
hex.EncodeToString(makeHash(hash, headers)),
|
||||
hex.EncodeToString(makeHash(hash, payload)),
|
||||
}, "\n")
|
||||
}
|
||||
560
aws/signer/v4/v4.go
Normal file
560
aws/signer/v4/v4.go
Normal file
@@ -0,0 +1,560 @@
|
||||
// Package v4 implements signing for AWS V4 signer
|
||||
//
|
||||
// Provides request signing for request that need to be signed with
|
||||
// AWS V4 Signatures.
|
||||
//
|
||||
// # Standalone Signer
|
||||
//
|
||||
// Generally using the signer outside of the SDK should not require any additional
|
||||
//
|
||||
// The signer does this by taking advantage of the URL.EscapedPath method. If your request URI requires
|
||||
//
|
||||
// additional escaping you many need to use the URL.Opaque to define what the raw URI should be sent
|
||||
// to the service as.
|
||||
//
|
||||
// The signer will first check the URL.Opaque field, and use its value if set.
|
||||
// The signer does require the URL.Opaque field to be set in the form of:
|
||||
//
|
||||
// "//<hostname>/<path>"
|
||||
//
|
||||
// // e.g.
|
||||
// "//example.com/some/path"
|
||||
//
|
||||
// The leading "//" and hostname are required or the URL.Opaque escaping will
|
||||
// not work correctly.
|
||||
//
|
||||
// If URL.Opaque is not set the signer will fallback to the URL.EscapedPath()
|
||||
// method and using the returned value.
|
||||
//
|
||||
// AWS v4 signature validation requires that the canonical string's URI path
|
||||
// element must be the URI escaped form of the HTTP request's path.
|
||||
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
|
||||
//
|
||||
// The Go HTTP client will perform escaping automatically on the request. Some
|
||||
// of these escaping may cause signature validation errors because the HTTP
|
||||
// request differs from the URI path or query that the signature was generated.
|
||||
// https://golang.org/pkg/net/url/#URL.EscapedPath
|
||||
//
|
||||
// Because of this, it is recommended that when using the signer outside of the
|
||||
// SDK that explicitly escaping the request prior to being signed is preferable,
|
||||
// and will help prevent signature validation errors. This can be done by setting
|
||||
// the URL.Opaque or URL.RawPath. The SDK will use URL.Opaque first and then
|
||||
// call URL.EscapedPath() if Opaque is not set.
|
||||
//
|
||||
// Test `TestStandaloneSign` provides a complete example of using the signer
|
||||
// outside of the SDK and pre-escaping the URI path.
|
||||
package v4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"hash"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/smithy-go/encoding/httpbinding"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
v4Internal "github.com/versity/versitygw/aws/signer/internal/v4"
|
||||
)
|
||||
|
||||
const (
|
||||
signingAlgorithm = "AWS4-HMAC-SHA256"
|
||||
authorizationHeader = "Authorization"
|
||||
|
||||
// Version of signing v4
|
||||
Version = "SigV4"
|
||||
)
|
||||
|
||||
// HTTPSigner is an interface to a SigV4 signer that can sign HTTP requests
|
||||
type HTTPSigner interface {
|
||||
SignHTTP(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(*SignerOptions)) error
|
||||
}
|
||||
|
||||
type keyDerivator interface {
|
||||
DeriveKey(credential aws.Credentials, service, region string, signingTime v4Internal.SigningTime) []byte
|
||||
}
|
||||
|
||||
// SignerOptions is the SigV4 Signer options.
|
||||
type SignerOptions struct {
|
||||
// Disables the Signer's moving HTTP header key/value pairs from the HTTP
|
||||
// request header to the request's query string. This is most commonly used
|
||||
// with pre-signed requests preventing headers from being added to the
|
||||
// request's query string.
|
||||
DisableHeaderHoisting bool
|
||||
|
||||
// Disables the automatic escaping of the URI path of the request for the
|
||||
// siganture's canonical string's path. For services that do not need additional
|
||||
// escaping then use this to disable the signer escaping the path.
|
||||
//
|
||||
// S3 is an example of a service that does not need additional escaping.
|
||||
//
|
||||
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
|
||||
DisableURIPathEscaping bool
|
||||
|
||||
// The logger to send log messages to.
|
||||
Logger logging.Logger
|
||||
|
||||
// Enable logging of signed requests.
|
||||
// This will enable logging of the canonical request, the string to sign, and for presigning the subsequent
|
||||
// presigned URL.
|
||||
LogSigning bool
|
||||
|
||||
// Disables setting the session token on the request as part of signing
|
||||
// through X-Amz-Security-Token. This is needed for variations of v4 that
|
||||
// present the token elsewhere.
|
||||
DisableSessionToken bool
|
||||
}
|
||||
|
||||
// Signer applies AWS v4 signing to given request. Use this to sign requests
|
||||
// that need to be signed with AWS V4 Signatures.
|
||||
type Signer struct {
|
||||
options SignerOptions
|
||||
keyDerivator keyDerivator
|
||||
}
|
||||
|
||||
// NewSigner returns a new SigV4 Signer
|
||||
func NewSigner(optFns ...func(signer *SignerOptions)) *Signer {
|
||||
options := SignerOptions{}
|
||||
|
||||
for _, fn := range optFns {
|
||||
fn(&options)
|
||||
}
|
||||
|
||||
return &Signer{options: options, keyDerivator: v4Internal.NewSigningKeyDeriver()}
|
||||
}
|
||||
|
||||
type httpSigner struct {
|
||||
Request *http.Request
|
||||
ServiceName string
|
||||
Region string
|
||||
Time v4Internal.SigningTime
|
||||
Credentials aws.Credentials
|
||||
KeyDerivator keyDerivator
|
||||
IsPreSign bool
|
||||
|
||||
PayloadHash string
|
||||
|
||||
DisableHeaderHoisting bool
|
||||
DisableURIPathEscaping bool
|
||||
DisableSessionToken bool
|
||||
}
|
||||
|
||||
func (s *httpSigner) Build() (signedRequest, error) {
|
||||
req := s.Request
|
||||
|
||||
query := req.URL.Query()
|
||||
headers := req.Header
|
||||
|
||||
s.setRequiredSigningFields(headers, query)
|
||||
|
||||
// Sort Each Query Key's Values
|
||||
for key := range query {
|
||||
sort.Strings(query[key])
|
||||
}
|
||||
|
||||
v4Internal.SanitizeHostForHeader(req)
|
||||
|
||||
credentialScope := s.buildCredentialScope()
|
||||
credentialStr := s.Credentials.AccessKeyID + "/" + credentialScope
|
||||
if s.IsPreSign {
|
||||
query.Set(v4Internal.AmzCredentialKey, credentialStr)
|
||||
}
|
||||
|
||||
unsignedHeaders := headers
|
||||
if s.IsPreSign && !s.DisableHeaderHoisting {
|
||||
var urlValues url.Values
|
||||
urlValues, unsignedHeaders = buildQuery(v4Internal.AllowedQueryHoisting, headers)
|
||||
for k := range urlValues {
|
||||
query[k] = urlValues[k]
|
||||
}
|
||||
}
|
||||
|
||||
host := req.URL.Host
|
||||
if len(req.Host) > 0 {
|
||||
host = req.Host
|
||||
}
|
||||
|
||||
signedHeaders, signedHeadersStr, canonicalHeaderStr := s.buildCanonicalHeaders(host, v4Internal.IgnoredHeaders, unsignedHeaders, s.Request.ContentLength)
|
||||
|
||||
if s.IsPreSign {
|
||||
query.Set(v4Internal.AmzSignedHeadersKey, signedHeadersStr)
|
||||
}
|
||||
|
||||
var rawQuery strings.Builder
|
||||
rawQuery.WriteString(strings.Replace(query.Encode(), "+", "%20", -1))
|
||||
|
||||
canonicalURI := v4Internal.GetURIPath(req.URL)
|
||||
if !s.DisableURIPathEscaping {
|
||||
canonicalURI = httpbinding.EscapePath(canonicalURI, false)
|
||||
}
|
||||
|
||||
canonicalString := s.buildCanonicalString(
|
||||
req.Method,
|
||||
canonicalURI,
|
||||
rawQuery.String(),
|
||||
signedHeadersStr,
|
||||
canonicalHeaderStr,
|
||||
)
|
||||
|
||||
strToSign := s.buildStringToSign(credentialScope, canonicalString)
|
||||
signingSignature, err := s.buildSignature(strToSign)
|
||||
if err != nil {
|
||||
return signedRequest{}, err
|
||||
}
|
||||
|
||||
if s.IsPreSign {
|
||||
rawQuery.WriteString("&X-Amz-Signature=")
|
||||
rawQuery.WriteString(signingSignature)
|
||||
} else {
|
||||
headers[authorizationHeader] = append(headers[authorizationHeader][:0], buildAuthorizationHeader(credentialStr, signedHeadersStr, signingSignature))
|
||||
}
|
||||
|
||||
req.URL.RawQuery = rawQuery.String()
|
||||
|
||||
return signedRequest{
|
||||
Request: req,
|
||||
SignedHeaders: signedHeaders,
|
||||
CanonicalString: canonicalString,
|
||||
StringToSign: strToSign,
|
||||
PreSigned: s.IsPreSign,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildAuthorizationHeader(credentialStr, signedHeadersStr, signingSignature string) string {
|
||||
const credential = "Credential="
|
||||
const signedHeaders = "SignedHeaders="
|
||||
const signature = "Signature="
|
||||
const commaSpace = ", "
|
||||
|
||||
var parts strings.Builder
|
||||
parts.Grow(len(signingAlgorithm) + 1 +
|
||||
len(credential) + len(credentialStr) + 2 +
|
||||
len(signedHeaders) + len(signedHeadersStr) + 2 +
|
||||
len(signature) + len(signingSignature),
|
||||
)
|
||||
parts.WriteString(signingAlgorithm)
|
||||
parts.WriteRune(' ')
|
||||
parts.WriteString(credential)
|
||||
parts.WriteString(credentialStr)
|
||||
parts.WriteString(commaSpace)
|
||||
parts.WriteString(signedHeaders)
|
||||
parts.WriteString(signedHeadersStr)
|
||||
parts.WriteString(commaSpace)
|
||||
parts.WriteString(signature)
|
||||
parts.WriteString(signingSignature)
|
||||
return parts.String()
|
||||
}
|
||||
|
||||
// SignHTTP signs AWS v4 requests with the provided payload hash, service name, region the
|
||||
// request is made to, and time the request is signed at. The signTime allows
|
||||
// you to specify that a request is signed for the future, and cannot be
|
||||
// used until then.
|
||||
//
|
||||
// The payloadHash is the hex encoded SHA-256 hash of the request payload, and
|
||||
// must be provided. Even if the request has no payload (aka body). If the
|
||||
// request has no payload you should use the hex encoded SHA-256 of an empty
|
||||
// string as the payloadHash value.
|
||||
//
|
||||
// "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
||||
//
|
||||
// Some services such as Amazon S3 accept alternative values for the payload
|
||||
// hash, such as "UNSIGNED-PAYLOAD" for requests where the body will not be
|
||||
// included in the request signature.
|
||||
//
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html
|
||||
//
|
||||
// Sign differs from Presign in that it will sign the request using HTTP
|
||||
// header values. This type of signing is intended for http.Request values that
|
||||
// will not be shared, or are shared in a way the header values on the request
|
||||
// will not be lost.
|
||||
//
|
||||
// The passed in request will be modified in place.
|
||||
func (s Signer) SignHTTP(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(options *SignerOptions)) error {
|
||||
options := s.options
|
||||
|
||||
for _, fn := range optFns {
|
||||
fn(&options)
|
||||
}
|
||||
|
||||
signer := &httpSigner{
|
||||
Request: r,
|
||||
PayloadHash: payloadHash,
|
||||
ServiceName: service,
|
||||
Region: region,
|
||||
Credentials: credentials,
|
||||
Time: v4Internal.NewSigningTime(signingTime.UTC()),
|
||||
DisableHeaderHoisting: options.DisableHeaderHoisting,
|
||||
DisableURIPathEscaping: options.DisableURIPathEscaping,
|
||||
DisableSessionToken: options.DisableSessionToken,
|
||||
KeyDerivator: s.keyDerivator,
|
||||
}
|
||||
|
||||
signedRequest, err := signer.Build()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logSigningInfo(ctx, options, &signedRequest, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PresignHTTP signs AWS v4 requests with the payload hash, service name, region
|
||||
// the request is made to, and time the request is signed at. The signTime
|
||||
// allows you to specify that a request is signed for the future, and cannot
|
||||
// be used until then.
|
||||
//
|
||||
// Returns the signed URL and the map of HTTP headers that were included in the
|
||||
// signature or an error if signing the request failed. For presigned requests
|
||||
// these headers and their values must be included on the HTTP request when it
|
||||
// is made. This is helpful to know what header values need to be shared with
|
||||
// the party the presigned request will be distributed to.
|
||||
//
|
||||
// The payloadHash is the hex encoded SHA-256 hash of the request payload, and
|
||||
// must be provided. Even if the request has no payload (aka body). If the
|
||||
// request has no payload you should use the hex encoded SHA-256 of an empty
|
||||
// string as the payloadHash value.
|
||||
//
|
||||
// "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
||||
//
|
||||
// Some services such as Amazon S3 accept alternative values for the payload
|
||||
// hash, such as "UNSIGNED-PAYLOAD" for requests where the body will not be
|
||||
// included in the request signature.
|
||||
//
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html
|
||||
//
|
||||
// PresignHTTP differs from SignHTTP in that it will sign the request using
|
||||
// query string instead of header values. This allows you to share the
|
||||
// Presigned Request's URL with third parties, or distribute it throughout your
|
||||
// system with minimal dependencies.
|
||||
//
|
||||
// PresignHTTP will not set the expires time of the presigned request
|
||||
// automatically. To specify the expire duration for a request add the
|
||||
// "X-Amz-Expires" query parameter on the request with the value as the
|
||||
// duration in seconds the presigned URL should be considered valid for. This
|
||||
// parameter is not used by all AWS services, and is most notable used by
|
||||
// Amazon S3 APIs.
|
||||
//
|
||||
// expires := 20 * time.Minute
|
||||
// query := req.URL.Query()
|
||||
// query.Set("X-Amz-Expires", strconv.FormatInt(int64(expires/time.Second), 10))
|
||||
// req.URL.RawQuery = query.Encode()
|
||||
//
|
||||
// This method does not modify the provided request.
|
||||
func (s *Signer) PresignHTTP(
|
||||
ctx context.Context, credentials aws.Credentials, r *http.Request,
|
||||
payloadHash string, service string, region string, signingTime time.Time,
|
||||
optFns ...func(*SignerOptions),
|
||||
) (signedURI string, signedHeaders http.Header, err error) {
|
||||
options := s.options
|
||||
|
||||
for _, fn := range optFns {
|
||||
fn(&options)
|
||||
}
|
||||
|
||||
signer := &httpSigner{
|
||||
Request: r.Clone(r.Context()),
|
||||
PayloadHash: payloadHash,
|
||||
ServiceName: service,
|
||||
Region: region,
|
||||
Credentials: credentials,
|
||||
Time: v4Internal.NewSigningTime(signingTime.UTC()),
|
||||
IsPreSign: true,
|
||||
DisableHeaderHoisting: options.DisableHeaderHoisting,
|
||||
DisableURIPathEscaping: options.DisableURIPathEscaping,
|
||||
DisableSessionToken: options.DisableSessionToken,
|
||||
KeyDerivator: s.keyDerivator,
|
||||
}
|
||||
|
||||
signedRequest, err := signer.Build()
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
logSigningInfo(ctx, options, &signedRequest, true)
|
||||
|
||||
signedHeaders = make(http.Header)
|
||||
|
||||
// For the signed headers we canonicalize the header keys in the returned map.
|
||||
// This avoids situations where can standard library double headers like host header. For example the standard
|
||||
// library will set the Host header, even if it is present in lower-case form.
|
||||
for k, v := range signedRequest.SignedHeaders {
|
||||
key := textproto.CanonicalMIMEHeaderKey(k)
|
||||
signedHeaders[key] = append(signedHeaders[key], v...)
|
||||
}
|
||||
|
||||
return signedRequest.Request.URL.String(), signedHeaders, nil
|
||||
}
|
||||
|
||||
func (s *httpSigner) buildCredentialScope() string {
|
||||
return v4Internal.BuildCredentialScope(s.Time, s.Region, s.ServiceName)
|
||||
}
|
||||
|
||||
func buildQuery(r v4Internal.Rule, header http.Header) (url.Values, http.Header) {
|
||||
query := url.Values{}
|
||||
unsignedHeaders := http.Header{}
|
||||
for k, h := range header {
|
||||
if r.IsValid(k) {
|
||||
query[k] = h
|
||||
} else {
|
||||
unsignedHeaders[k] = h
|
||||
}
|
||||
}
|
||||
|
||||
return query, unsignedHeaders
|
||||
}
|
||||
|
||||
func (s *httpSigner) buildCanonicalHeaders(host string, rule v4Internal.Rule, header http.Header, length int64) (signed http.Header, signedHeaders, canonicalHeadersStr string) {
|
||||
signed = make(http.Header)
|
||||
|
||||
var headers []string
|
||||
const hostHeader = "host"
|
||||
headers = append(headers, hostHeader)
|
||||
signed[hostHeader] = append(signed[hostHeader], host)
|
||||
|
||||
const contentLengthHeader = "content-length"
|
||||
if length > 0 {
|
||||
headers = append(headers, contentLengthHeader)
|
||||
signed[contentLengthHeader] = append(signed[contentLengthHeader], strconv.FormatInt(length, 10))
|
||||
}
|
||||
|
||||
for k, v := range header {
|
||||
if !rule.IsValid(k) {
|
||||
continue // ignored header
|
||||
}
|
||||
if strings.EqualFold(k, contentLengthHeader) {
|
||||
// prevent signing already handled content-length header.
|
||||
continue
|
||||
}
|
||||
|
||||
lowerCaseKey := strings.ToLower(k)
|
||||
if _, ok := signed[lowerCaseKey]; ok {
|
||||
// include additional values
|
||||
signed[lowerCaseKey] = append(signed[lowerCaseKey], v...)
|
||||
continue
|
||||
}
|
||||
|
||||
headers = append(headers, lowerCaseKey)
|
||||
signed[lowerCaseKey] = v
|
||||
}
|
||||
sort.Strings(headers)
|
||||
|
||||
signedHeaders = strings.Join(headers, ";")
|
||||
|
||||
var canonicalHeaders strings.Builder
|
||||
n := len(headers)
|
||||
const colon = ':'
|
||||
for i := 0; i < n; i++ {
|
||||
if headers[i] == hostHeader {
|
||||
canonicalHeaders.WriteString(hostHeader)
|
||||
canonicalHeaders.WriteRune(colon)
|
||||
canonicalHeaders.WriteString(v4Internal.StripExcessSpaces(host))
|
||||
} else {
|
||||
canonicalHeaders.WriteString(headers[i])
|
||||
canonicalHeaders.WriteRune(colon)
|
||||
// Trim out leading, trailing, and dedup inner spaces from signed header values.
|
||||
values := signed[headers[i]]
|
||||
for j, v := range values {
|
||||
cleanedValue := strings.TrimSpace(v4Internal.StripExcessSpaces(v))
|
||||
canonicalHeaders.WriteString(cleanedValue)
|
||||
if j < len(values)-1 {
|
||||
canonicalHeaders.WriteRune(',')
|
||||
}
|
||||
}
|
||||
}
|
||||
canonicalHeaders.WriteRune('\n')
|
||||
}
|
||||
canonicalHeadersStr = canonicalHeaders.String()
|
||||
|
||||
return signed, signedHeaders, canonicalHeadersStr
|
||||
}
|
||||
|
||||
func (s *httpSigner) buildCanonicalString(method, uri, query, signedHeaders, canonicalHeaders string) string {
|
||||
return strings.Join([]string{
|
||||
method,
|
||||
uri,
|
||||
query,
|
||||
canonicalHeaders,
|
||||
signedHeaders,
|
||||
s.PayloadHash,
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func (s *httpSigner) buildStringToSign(credentialScope, canonicalRequestString string) string {
|
||||
return strings.Join([]string{
|
||||
signingAlgorithm,
|
||||
s.Time.TimeFormat(),
|
||||
credentialScope,
|
||||
hex.EncodeToString(makeHash(sha256.New(), []byte(canonicalRequestString))),
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func makeHash(hash hash.Hash, b []byte) []byte {
|
||||
hash.Reset()
|
||||
hash.Write(b)
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
func (s *httpSigner) buildSignature(strToSign string) (string, error) {
|
||||
key := s.KeyDerivator.DeriveKey(s.Credentials, s.ServiceName, s.Region, s.Time)
|
||||
return hex.EncodeToString(v4Internal.HMACSHA256(key, []byte(strToSign))), nil
|
||||
}
|
||||
|
||||
func (s *httpSigner) setRequiredSigningFields(headers http.Header, query url.Values) {
|
||||
amzDate := s.Time.TimeFormat()
|
||||
|
||||
if s.IsPreSign {
|
||||
query.Set(v4Internal.AmzAlgorithmKey, signingAlgorithm)
|
||||
sessionToken := s.Credentials.SessionToken
|
||||
if !s.DisableSessionToken && len(sessionToken) > 0 {
|
||||
query.Set("X-Amz-Security-Token", sessionToken)
|
||||
}
|
||||
|
||||
query.Set(v4Internal.AmzDateKey, amzDate)
|
||||
return
|
||||
}
|
||||
|
||||
headers[v4Internal.AmzDateKey] = append(headers[v4Internal.AmzDateKey][:0], amzDate)
|
||||
|
||||
if !s.DisableSessionToken && len(s.Credentials.SessionToken) > 0 {
|
||||
headers[v4Internal.AmzSecurityTokenKey] = append(headers[v4Internal.AmzSecurityTokenKey][:0], s.Credentials.SessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
func logSigningInfo(ctx context.Context, options SignerOptions, request *signedRequest, isPresign bool) {
|
||||
if !options.LogSigning {
|
||||
return
|
||||
}
|
||||
signedURLMsg := ""
|
||||
if isPresign {
|
||||
signedURLMsg = fmt.Sprintf(logSignedURLMsg, request.Request.URL.String())
|
||||
}
|
||||
logger := logging.WithContext(ctx, options.Logger)
|
||||
logger.Logf(logging.Debug, logSignInfoMsg, request.CanonicalString, request.StringToSign, signedURLMsg)
|
||||
}
|
||||
|
||||
type signedRequest struct {
|
||||
Request *http.Request
|
||||
SignedHeaders http.Header
|
||||
CanonicalString string
|
||||
StringToSign string
|
||||
PreSigned bool
|
||||
}
|
||||
|
||||
const logSignInfoMsg = `Request Signature:
|
||||
---[ CANONICAL STRING ]-----------------------------
|
||||
%s
|
||||
---[ STRING TO SIGN ]--------------------------------
|
||||
%s%s
|
||||
-----------------------------------------------------`
|
||||
const logSignedURLMsg = `
|
||||
---[ SIGNED URL ]------------------------------------
|
||||
%s`
|
||||
356
aws/signer/v4/v4_test.go
Normal file
356
aws/signer/v4/v4_test.go
Normal file
@@ -0,0 +1,356 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
v4Internal "github.com/versity/versitygw/aws/signer/internal/v4"
|
||||
)
|
||||
|
||||
var testCredentials = aws.Credentials{AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "SESSION"}
|
||||
|
||||
func buildRequest(serviceName, region, body string) (*http.Request, string) {
|
||||
reader := strings.NewReader(body)
|
||||
return buildRequestWithBodyReader(serviceName, region, reader)
|
||||
}
|
||||
|
||||
func buildRequestWithBodyReader(serviceName, region string, body io.Reader) (*http.Request, string) {
|
||||
var bodyLen int
|
||||
|
||||
type lenner interface {
|
||||
Len() int
|
||||
}
|
||||
if lr, ok := body.(lenner); ok {
|
||||
bodyLen = lr.Len()
|
||||
}
|
||||
|
||||
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
|
||||
req, _ := http.NewRequest("POST", endpoint, body)
|
||||
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
|
||||
req.Header.Set("X-Amz-Target", "prefix.Operation")
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
|
||||
if bodyLen > 0 {
|
||||
req.ContentLength = int64(bodyLen)
|
||||
}
|
||||
|
||||
req.Header.Set("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
|
||||
req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
|
||||
req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
|
||||
|
||||
h := sha256.New()
|
||||
_, _ = io.Copy(h, body)
|
||||
payloadHash := hex.EncodeToString(h.Sum(nil))
|
||||
|
||||
return req, payloadHash
|
||||
}
|
||||
|
||||
func TestPresignRequest(t *testing.T) {
|
||||
req, body := buildRequest("dynamodb", "us-east-1", "{}")
|
||||
|
||||
query := req.URL.Query()
|
||||
query.Set("X-Amz-Expires", "300")
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
signer := NewSigner()
|
||||
signed, headers, err := signer.PresignHTTP(context.Background(), testCredentials, req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
expectedDate := "19700101T000000Z"
|
||||
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
|
||||
expectedSig := "122f0b9e091e4ba84286097e2b3404a1f1f4c4aad479adda95b7dff0ccbe5581"
|
||||
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
|
||||
expectedTarget := "prefix.Operation"
|
||||
|
||||
q, err := url.ParseQuery(signed[strings.Index(signed, "?"):])
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
if e, a := expectedSig, q.Get("X-Amz-Signature"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedCred, q.Get("X-Amz-Credential"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedHeaders, q.Get("X-Amz-SignedHeaders"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if a := q.Get("X-Amz-Meta-Other-Header"); len(a) != 0 {
|
||||
t.Errorf("expect %v to be empty", a)
|
||||
}
|
||||
if e, a := expectedTarget, q.Get("X-Amz-Target"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
for _, h := range strings.Split(expectedHeaders, ";") {
|
||||
v := headers.Get(h)
|
||||
if len(v) == 0 {
|
||||
t.Errorf("expect %v, to be present in header map", h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPresignBodyWithArrayRequest(t *testing.T) {
|
||||
req, body := buildRequest("dynamodb", "us-east-1", "{}")
|
||||
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
|
||||
|
||||
query := req.URL.Query()
|
||||
query.Set("X-Amz-Expires", "300")
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
signer := NewSigner()
|
||||
signed, headers, err := signer.PresignHTTP(context.Background(), testCredentials, req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
q, err := url.ParseQuery(signed[strings.Index(signed, "?"):])
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
expectedDate := "19700101T000000Z"
|
||||
expectedHeaders := "content-length;content-type;host;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore"
|
||||
expectedSig := "e3ac55addee8711b76c6d608d762cff285fe8b627a057f8b5ec9268cf82c08b1"
|
||||
expectedCred := "AKID/19700101/us-east-1/dynamodb/aws4_request"
|
||||
expectedTarget := "prefix.Operation"
|
||||
|
||||
if e, a := expectedSig, q.Get("X-Amz-Signature"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedCred, q.Get("X-Amz-Credential"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedHeaders, q.Get("X-Amz-SignedHeaders"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if a := q.Get("X-Amz-Meta-Other-Header"); len(a) != 0 {
|
||||
t.Errorf("expect %v to be empty, was not", a)
|
||||
}
|
||||
if e, a := expectedTarget, q.Get("X-Amz-Target"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
|
||||
for _, h := range strings.Split(expectedHeaders, ";") {
|
||||
v := headers.Get(h)
|
||||
if len(v) == 0 {
|
||||
t.Errorf("expect %v, to be present in header map", h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignRequest(t *testing.T) {
|
||||
req, body := buildRequest("dynamodb", "us-east-1", "{}")
|
||||
signer := NewSigner()
|
||||
err := signer.SignHTTP(context.Background(), testCredentials, req, body, "dynamodb", "us-east-1", time.Unix(0, 0))
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
expectedDate := "19700101T000000Z"
|
||||
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=a518299330494908a70222cec6899f6f32f297f8595f6df1776d998936652ad9"
|
||||
|
||||
q := req.Header
|
||||
if e, a := expectedSig, q.Get("Authorization"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCanonicalRequest(t *testing.T) {
|
||||
req, _ := buildRequest("dynamodb", "us-east-1", "{}")
|
||||
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
|
||||
|
||||
ctx := &httpSigner{
|
||||
ServiceName: "dynamodb",
|
||||
Region: "us-east-1",
|
||||
Request: req,
|
||||
Time: v4Internal.NewSigningTime(time.Now()),
|
||||
KeyDerivator: v4Internal.NewSigningKeyDeriver(),
|
||||
}
|
||||
|
||||
build, err := ctx.Build()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
expected := "https://example.org/bucket/key-._~,!@#$%^&*()?Foo=a&Foo=m&Foo=o&Foo=z"
|
||||
if e, a := expected, build.Request.URL.String(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigner_SignHTTP_NoReplaceRequestBody(t *testing.T) {
|
||||
req, bodyHash := buildRequest("dynamodb", "us-east-1", "{}")
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
||||
|
||||
s := NewSigner()
|
||||
|
||||
origBody := req.Body
|
||||
|
||||
err := s.SignHTTP(context.Background(), testCredentials, req, bodyHash, "dynamodb", "us-east-1", time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
if req.Body != origBody {
|
||||
t.Errorf("expect request body to not be chagned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestHost(t *testing.T) {
|
||||
req, _ := buildRequest("dynamodb", "us-east-1", "{}")
|
||||
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
|
||||
req.Host = "myhost"
|
||||
|
||||
query := req.URL.Query()
|
||||
query.Set("X-Amz-Expires", "5")
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
ctx := &httpSigner{
|
||||
ServiceName: "dynamodb",
|
||||
Region: "us-east-1",
|
||||
Request: req,
|
||||
Time: v4Internal.NewSigningTime(time.Now()),
|
||||
KeyDerivator: v4Internal.NewSigningKeyDeriver(),
|
||||
}
|
||||
|
||||
build, err := ctx.Build()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(build.CanonicalString, "host:"+req.Host) {
|
||||
t.Errorf("canonical host header invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSign_buildCanonicalHeadersContentLengthPresent(t *testing.T) {
|
||||
body := `{"description": "this is a test"}`
|
||||
req, _ := buildRequest("dynamodb", "us-east-1", body)
|
||||
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
|
||||
req.Host = "myhost"
|
||||
|
||||
contentLength := fmt.Sprintf("%d", len([]byte(body)))
|
||||
req.Header.Add("Content-Length", contentLength)
|
||||
|
||||
query := req.URL.Query()
|
||||
query.Set("X-Amz-Expires", "5")
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
ctx := &httpSigner{
|
||||
ServiceName: "dynamodb",
|
||||
Region: "us-east-1",
|
||||
Request: req,
|
||||
Time: v4Internal.NewSigningTime(time.Now()),
|
||||
KeyDerivator: v4Internal.NewSigningKeyDeriver(),
|
||||
}
|
||||
|
||||
build, err := ctx.Build()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(build.CanonicalString, "content-length:"+contentLength+"\n") {
|
||||
t.Errorf("canonical header content-length invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSign_buildCanonicalHeaders(t *testing.T) {
|
||||
serviceName := "mockAPI"
|
||||
region := "mock-region"
|
||||
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
|
||||
|
||||
req, err := http.NewRequest("POST", endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request, %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("FooInnerSpace", " inner space ")
|
||||
req.Header.Set("FooLeadingSpace", " leading-space")
|
||||
req.Header.Add("FooMultipleSpace", "no-space")
|
||||
req.Header.Add("FooMultipleSpace", "\ttab-space")
|
||||
req.Header.Add("FooMultipleSpace", "trailing-space ")
|
||||
req.Header.Set("FooNoSpace", "no-space")
|
||||
req.Header.Set("FooTabSpace", "\ttab-space\t")
|
||||
req.Header.Set("FooTrailingSpace", "trailing-space ")
|
||||
req.Header.Set("FooWrappedSpace", " wrapped-space ")
|
||||
|
||||
ctx := &httpSigner{
|
||||
ServiceName: serviceName,
|
||||
Region: region,
|
||||
Request: req,
|
||||
Time: v4Internal.NewSigningTime(time.Date(2021, 10, 20, 12, 42, 0, 0, time.UTC)),
|
||||
KeyDerivator: v4Internal.NewSigningKeyDeriver(),
|
||||
}
|
||||
|
||||
build, err := ctx.Build()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
expectCanonicalString := strings.Join([]string{
|
||||
`POST`,
|
||||
`/`,
|
||||
``,
|
||||
`fooinnerspace:inner space`,
|
||||
`fooleadingspace:leading-space`,
|
||||
`foomultiplespace:no-space,tab-space,trailing-space`,
|
||||
`foonospace:no-space`,
|
||||
`footabspace:tab-space`,
|
||||
`footrailingspace:trailing-space`,
|
||||
`foowrappedspace:wrapped-space`,
|
||||
`host:mockAPI.mock-region.amazonaws.com`,
|
||||
`x-amz-date:20211020T124200Z`,
|
||||
``,
|
||||
`fooinnerspace;fooleadingspace;foomultiplespace;foonospace;footabspace;footrailingspace;foowrappedspace;host;x-amz-date`,
|
||||
``,
|
||||
}, "\n")
|
||||
if diff := cmp.Diff(expectCanonicalString, build.CanonicalString); diff != "" {
|
||||
t.Errorf("expect match, got\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPresignRequest(b *testing.B) {
|
||||
signer := NewSigner()
|
||||
req, bodyHash := buildRequest("dynamodb", "us-east-1", "{}")
|
||||
|
||||
query := req.URL.Query()
|
||||
query.Set("X-Amz-Expires", "5")
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
signer.PresignHTTP(context.Background(), testCredentials, req, bodyHash, "dynamodb", "us-east-1", time.Now())
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSignRequest(b *testing.B) {
|
||||
signer := NewSigner()
|
||||
req, bodyHash := buildRequest("dynamodb", "us-east-1", "{}")
|
||||
for i := 0; i < b.N; i++ {
|
||||
signer.SignHTTP(context.Background(), testCredentials, req, bodyHash, "dynamodb", "us-east-1", time.Now())
|
||||
}
|
||||
}
|
||||
1
aws/staticcheck.conf
Normal file
1
aws/staticcheck.conf
Normal file
@@ -0,0 +1 @@
|
||||
checks = []
|
||||
1
go.mod
1
go.mod
@@ -32,6 +32,7 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.28.1 // indirect
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.5 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/nats-io/nkeys v0.4.7 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -71,6 +71,8 @@ github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
|
||||
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
|
||||
Reference in New Issue
Block a user