Files
versitygw/tests/rest_scripts/command/s3Request.go
2026-03-27 15:00:56 -03:00

260 lines
8.3 KiB
Go

package command
import (
"crypto/md5"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"github.com/versity/versitygw/tests/rest_scripts/logger"
"os"
"sort"
"strings"
"time"
)
// S3Request carries the request configuration together with the intermediate
// values produced while a signed S3 REST request is assembled.
type S3Request struct {
	// Config is the caller-supplied description of the request to build.
	Config *S3RequestConfigData

	// currentDateTime is the request timestamp and yearMonthDay is its date
	// portion; both are set by CalculateDateTimeParams.
	currentDateTime, yearMonthDay string

	// path is the request path set by DeriveBucketAndKeyPath,
	// signedParamString the ";"-joined signed header names, signature the
	// hex-encoded signature, and host the part of Config.Url after "://".
	path, signedParamString, signature, host string

	// headerValues lists the headers of the request, signed and unsigned.
	headerValues []*HeaderValue

	// payloadHash is the payload hash value used in the canonical request.
	payloadHash string

	// dataSource provides the payload; it stays nil when neither a payload
	// file nor a payload string is configured.
	dataSource DataSource

	// canonicalRequestHash is the hex-encoded SHA-256 of the canonical request.
	canonicalRequestHash string

	// signingKey is the key derived in CalculateSignature.
	signingKey []byte
}
// CalculateDateTimeParams sets the request timestamp (currentDateTime) and
// its date portion (yearMonthDay).
//
// The timestamp is Config.CustomDate when that is set; otherwise it is the
// current UTC time, shifted 48 hours into the past when
// Config.IncorrectYearMonthDay is set. The date portion is everything before
// the first "T" of the timestamp. When Config.InvalidYearMonthDay is set the
// last two characters of the date portion are dropped.
func (s *S3Request) CalculateDateTimeParams() {
	const timestampLayout = "20060102T150405Z"

	switch {
	case s.Config.CustomDate != "":
		s.currentDateTime = s.Config.CustomDate
	case s.Config.IncorrectYearMonthDay:
		s.currentDateTime = time.Now().UTC().Add(-48 * time.Hour).Format(timestampLayout)
	default:
		s.currentDateTime = time.Now().UTC().Format(timestampLayout)
	}

	s.yearMonthDay = s.currentDateTime
	if i := strings.Index(s.currentDateTime, "T"); i >= 0 {
		s.yearMonthDay = s.currentDateTime[:i]
	}

	if s.Config.InvalidYearMonthDay {
		// A custom date may leave fewer than two characters before the "T";
		// clamp so the truncation yields an empty string instead of panicking.
		keep := len(s.yearMonthDay) - 2
		if keep < 0 {
			keep = 0
		}
		s.yearMonthDay = s.yearMonthDay[:keep]
	}
}
// DeriveHost stores in s.host everything that follows "://" in Config.Url.
// It returns an error, leaving s.host untouched, unless the URL contains
// exactly one "://" separator.
func (s *S3Request) DeriveHost() error {
	const schemeSeparator = "://"

	rawURL := s.Config.Url
	if strings.Count(rawURL, schemeSeparator) != 1 {
		return fmt.Errorf("invalid URL value: %s", rawURL)
	}

	hostStart := strings.Index(rawURL, schemeSeparator) + len(schemeSeparator)
	s.host = rawURL[hostStart:]
	return nil
}
// DeriveBucketAndKeyPath sets s.path to "/<bucket>", followed by "/<key>"
// when Config.ObjectKey is not empty.
func (s *S3Request) DeriveBucketAndKeyPath() {
	bucketPath := "/" + s.Config.BucketName
	if s.Config.ObjectKey == "" {
		s.path = bucketPath
		return
	}
	s.path = bucketPath + "/" + s.Config.ObjectKey
}
// CalculateSignature computes the request signature and stores its hex
// encoding in s.signature.
//
// It first computes the canonical request hash, then builds the string to
// sign from the authorization scheme, the timestamp, the
// "<date>/<region>/<service>/aws4_request" scope and that hash. The signing
// key is derived by chaining hmacSHA256 over the date, region, service name
// and "aws4_request", starting from "AWS4" + the secret access key, and is
// kept in s.signingKey. When Config.IncorrectSignature is set, the first raw
// signature byte is overwritten before encoding so the result is wrong.
func (s *S3Request) CalculateSignature() error {
	if err := s.calculateCanonicalRequestHash(); err != nil {
		return fmt.Errorf("error calculating canonical request hash: %w", err)
	}
	cfg := s.Config

	credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", s.yearMonthDay, cfg.AwsRegion, cfg.ServiceName)
	stringToSign := strings.Join([]string{
		cfg.AuthorizationScheme,
		s.currentDateTime,
		credentialScope,
		s.canonicalRequestHash,
	}, "\n")

	// Each step keys the next HMAC with the previous step's output.
	derivedKey := []byte("AWS4" + cfg.AwsSecretAccessKey)
	for _, scopePart := range []string{s.yearMonthDay, cfg.AwsRegion, cfg.ServiceName, "aws4_request"} {
		derivedKey = hmacSHA256(derivedKey, scopePart)
	}
	s.signingKey = derivedKey

	rawSignature := hmacSHA256(s.signingKey, stringToSign)
	if cfg.IncorrectSignature {
		// Swap in a different first byte so the signature cannot match.
		replacement := byte('a')
		if rawSignature[0] == 'a' {
			replacement = 'A'
		}
		rawSignature[0] = replacement
	}

	s.signature = hex.EncodeToString(rawSignature)
	return nil
}
// calculateCanonicalRequestHash builds the canonical request and stores the
// hex-encoded SHA-256 of it in s.canonicalRequestHash.
//
// The canonical request consists of, one per line: the HTTP method, the
// request path after encodeS3Key (which is also written back to s.path), the
// canonical query string, one "<lowercased key>:<value>" line per signed
// header, an empty line, the ";"-joined signed header names (also stored in
// s.signedParamString) and the payload hash. Signed headers are emitted in
// the order they appear in s.headerValues.
//
// It returns an error when the canonical query cannot be generated.
func (s *S3Request) calculateCanonicalRequestHash() error {
	s.path = encodeS3Key(s.path)

	canonicalQuery, err := s.getCanonicalQuery()
	if err != nil {
		return fmt.Errorf("error generating canonical query: %w", err)
	}
	canonicalRequestLines := []string{s.Config.Method, s.path, canonicalQuery}

	var signedParams []string
	for _, headerValue := range s.headerValues {
		if !headerValue.Signed {
			continue
		}
		key := strings.ToLower(headerValue.Key)
		canonicalRequestLines = append(canonicalRequestLines, key+":"+headerValue.Value)
		signedParams = append(signedParams, key)
	}
	s.signedParamString = strings.Join(signedParams, ";")

	// The empty entry produces the blank line that ends the header section.
	canonicalRequestLines = append(canonicalRequestLines, "", s.signedParamString, s.payloadHash)
	canonicalRequestString := strings.Join(canonicalRequestLines, "\n")
	logger.PrintDebug("Canonical request string: %s\n", canonicalRequestString)

	canonicalRequestHashBytes := sha256.Sum256([]byte(canonicalRequestString))
	s.canonicalRequestHash = hex.EncodeToString(canonicalRequestHashBytes[:])
	return nil
}
// getCanonicalQuery returns the canonical form of Config.Query as produced
// by canonicalizeQuery.
//
// Before canonicalization, a non-empty query whose first parameter (the text
// before the first "&", or the whole query when there is none) contains no
// "=" gets one appended to that parameter, so "uploads" becomes "uploads="
// and "uploads&x=1" becomes "uploads=&x=1". Later parameters are passed on
// unchanged.
//
// Errors from canonicalizeQuery are wrapped so callers can inspect them with
// errors.Is / errors.As.
func (s *S3Request) getCanonicalQuery() (string, error) {
	queryRequestLine := s.Config.Query

	firstParam, remainder := queryRequestLine, ""
	if i := strings.Index(queryRequestLine, "&"); i >= 0 {
		// remainder keeps its leading "&" so the pieces can be rejoined as-is.
		firstParam, remainder = queryRequestLine[:i], queryRequestLine[i:]
	}
	if queryRequestLine != "" && !strings.Contains(firstParam, "=") {
		queryRequestLine = firstParam + "=" + remainder
	}

	canonicalQuery, err := canonicalizeQuery(queryRequestLine)
	if err != nil {
		return "", fmt.Errorf("error parsing query '%s': %w", queryRequestLine, err)
	}
	return canonicalQuery, nil
}
// performBasePayloadCalculations selects the payload data source and the
// payload hash value.
//
// The data source comes from Config.PayloadFile when set, otherwise from
// Config.Payload when set, and is left untouched when neither is given. The
// payload hash is, in order of precedence: Config.CustomSHA256Hash,
// Config.PayloadType, the SHA-256 hash string reported by the data source,
// or SHA256HashZeroBytes when there is no data source.
func (s *S3Request) performBasePayloadCalculations() error {
	cfg := s.Config

	switch {
	case cfg.PayloadFile != "":
		s.dataSource = NewFileDataSource(cfg.PayloadFile)
	case cfg.Payload != "":
		s.dataSource = NewStringDataSource(cfg.Payload)
	}

	switch {
	case cfg.CustomSHA256Hash != "":
		s.payloadHash = cfg.CustomSHA256Hash
	case cfg.PayloadType != "":
		s.payloadHash = cfg.PayloadType
	case s.dataSource != nil:
		hash, err := s.dataSource.CalculateSHA256HashString()
		// The returned value is stored even when the call reports an error.
		s.payloadHash = hash
		if err != nil {
			return fmt.Errorf("error calculating sha256 hash: %w", err)
		}
	default:
		s.payloadHash = SHA256HashZeroBytes
	}
	return nil
}
// deriveUniversalHeaderValues appends the signed headers every request
// carries to s.headerValues.
//
// A "host" header is always added: with an empty value when
// Config.MissingHostParam is set, with Config.CustomHostParam when
// Config.CustomHostParamSet is set, and with s.host otherwise.
// "x-amz-content-sha256" (the payload hash) and "x-amz-date" (the request
// timestamp) follow unless suppressed by Config.OmitSHA256Hash and
// Config.OmitDate respectively.
func (s *S3Request) deriveUniversalHeaderValues() {
	hostValue := s.host
	switch {
	case s.Config.MissingHostParam:
		hostValue = ""
	case s.Config.CustomHostParamSet:
		hostValue = s.Config.CustomHostParam
	}
	s.headerValues = append(s.headerValues, &HeaderValue{"host", hostValue, true})

	if includeHash := !s.Config.OmitSHA256Hash; includeHash {
		s.headerValues = append(s.headerValues, &HeaderValue{"x-amz-content-sha256", s.payloadHash, true})
	}
	if includeDate := !s.Config.OmitDate; includeDate {
		s.headerValues = append(s.headerValues, &HeaderValue{"x-amz-date", s.currentDateTime, true})
	}
}
// deriveConfigSpecificHeaderValues appends the headers that depend on the
// request configuration and then sorts s.headerValues by lowercased key.
//
// Headers are appended in this order:
//   - "x-amz-trailer" naming the checksum header, when the payload type is
//     StreamingAWS4HMACSHA256PayloadTrailer and a checksum type is set;
//   - "x-amz-decoded-content-length" with the data source size, when a data
//     source exists and the payload type is not UnsignedPayload;
//   - every entry of Config.SignedParams, as signed headers;
//   - "Content-MD5", when any Content-MD5 option is configured;
//   - every entry of Config.UnsignedParams, as unsigned headers.
//
// It returns an error when the payload size cannot be determined or the
// Content-MD5 header cannot be added.
func (s *S3Request) deriveConfigSpecificHeaderValues() error {
	cfg := s.Config

	if cfg.PayloadType == StreamingAWS4HMACSHA256PayloadTrailer && cfg.ChecksumType != "" {
		trailer := fmt.Sprintf("x-amz-checksum-%s", cfg.ChecksumType)
		s.headerValues = append(s.headerValues, &HeaderValue{"x-amz-trailer", trailer, true})
	}

	if s.dataSource != nil && cfg.PayloadType != UnsignedPayload {
		size, err := s.dataSource.SourceDataByteSize()
		if err != nil {
			return fmt.Errorf("error getting payload size: %w", err)
		}
		decodedLength := fmt.Sprintf("%d", size)
		s.headerValues = append(s.headerValues, &HeaderValue{"x-amz-decoded-content-length", decodedLength, true})
	}

	for name, val := range cfg.SignedParams {
		s.headerValues = append(s.headerValues, &HeaderValue{name, val, true})
	}

	wantsContentMD5 := cfg.ContentMD5 || cfg.IncorrectContentMD5 || cfg.CustomContentMD5 != ""
	if wantsContentMD5 {
		if err := s.addContentMD5Header(); err != nil {
			return fmt.Errorf("error adding Content-MD5 header: %w", err)
		}
	}

	for name, val := range cfg.UnsignedParams {
		s.headerValues = append(s.headerValues, &HeaderValue{name, val, false})
	}

	// Order headers case-insensitively by name.
	lowerKey := func(idx int) string { return strings.ToLower(s.headerValues[idx].Key) }
	sort.Slice(s.headerValues, func(a, b int) bool {
		return lowerKey(a) < lowerKey(b)
	})
	return nil
}
// addContentMD5Header appends a signed "Content-MD5" header to
// s.headerValues.
//
// The payload is read from Config.PayloadFile when set; otherwise
// Config.Payload is used with every backslash removed. The header value is
// Config.CustomContentMD5 when set, and otherwise the base64-encoded MD5 of
// the payload, which is passed through modifyHash first when
// Config.IncorrectContentMD5 is set. The payload is read even when a custom
// value is configured, so an unreadable payload file is always an error.
func (s *S3Request) addContentMD5Header() error {
	cfg := s.Config

	var body []byte
	if cfg.PayloadFile == "" {
		logger.PrintDebug("Payload: %s", cfg.Payload)
		body = []byte(strings.ReplaceAll(cfg.Payload, "\\", ""))
	} else {
		fileData, err := os.ReadFile(cfg.PayloadFile)
		if err != nil {
			return fmt.Errorf("error reading file %s: %w", cfg.PayloadFile, err)
		}
		body = fileData
	}

	headerValue := cfg.CustomContentMD5
	if headerValue == "" {
		digest := md5.Sum(body)
		if cfg.IncorrectContentMD5 {
			modifyHash(digest[:])
		}
		headerValue = base64.StdEncoding.EncodeToString(digest[:])
	}

	s.headerValues = append(s.headerValues, &HeaderValue{"Content-MD5", headerValue, true})
	return nil
}
// buildAuthorizationString returns the complete "Authorization: ..." header
// line made of the authorization scheme, the credential, the signed header
// names and the signature.
//
// The credential is Config.IncorrectCredential when that is set; otherwise
// it is "<access key id>/<date>/<region>/<service>/aws4_request".
func (s *S3Request) buildAuthorizationString() string {
	cfg := s.Config

	credential := cfg.IncorrectCredential
	if credential == "" {
		credential = fmt.Sprintf("%s/%s/%s/%s/aws4_request",
			cfg.AwsAccessKeyId, s.yearMonthDay, cfg.AwsRegion, cfg.ServiceName)
	}

	return fmt.Sprintf("Authorization: %s Credential=%s,SignedHeaders=%s,Signature=%s",
		cfg.AuthorizationScheme, credential, s.signedParamString, s.signature)
}