From ccd0ba2762df1dd826c5afdd64f86418c30fcbfc Mon Sep 17 00:00:00 2001 From: Luke McCrone Date: Fri, 27 Mar 2026 09:18:52 -0300 Subject: [PATCH] test: remove eval, go command generation overhaul --- tests/drivers/rest.sh | 35 +- tests/rest_scripts/command/Renderer.go | 11 + .../command/createBucketCommand.go | 16 +- tests/rest_scripts/command/curlRequest.go | 90 ++++ tests/rest_scripts/command/openSSLRequest.go | 110 ++++ .../command/putBucketCorsCommand.go | 8 +- .../command/putBucketTaggingCommand.go | 6 +- .../rest_scripts/command/putObjectCommand.go | 10 +- .../command/putObjectTaggingCommand.go | 8 +- .../rest_scripts/command/putTaggingCommand.go | 12 +- tests/rest_scripts/command/s3Command.go | 478 ------------------ tests/rest_scripts/command/s3Request.go | 259 ++++++++++ .../rest_scripts/command/s3RequestBuilder.go | 169 +++++++ tests/rest_scripts/config/config.go | 1 - tests/rest_scripts/generateCommand.go | 82 +-- 15 files changed, 736 insertions(+), 559 deletions(-) create mode 100644 tests/rest_scripts/command/Renderer.go create mode 100644 tests/rest_scripts/command/curlRequest.go create mode 100644 tests/rest_scripts/command/openSSLRequest.go delete mode 100644 tests/rest_scripts/command/s3Command.go create mode 100644 tests/rest_scripts/command/s3Request.go create mode 100644 tests/rest_scripts/command/s3RequestBuilder.go delete mode 100644 tests/rest_scripts/config/config.go diff --git a/tests/drivers/rest.sh b/tests/drivers/rest.sh index 3c0f4889..fc81d9c8 100644 --- a/tests/drivers/rest.sh +++ b/tests/drivers/rest.sh @@ -61,7 +61,11 @@ send_rest_command() { fi log 5 "output file: $output_file" else - output_file="$TEST_FILE_FOLDER/output.txt" + if ! file_name=$(get_file_name 2>&1); then + log 2 "error getting file name: $file_name" + return 1 + fi + output_file="$TEST_FILE_FOLDER/$file_name" fi local env_array=("env" "COMMAND_LOG=$COMMAND_LOG" "OUTPUT_FILE=$output_file") if [ "$1" != "" ]; then @@ -175,17 +179,22 @@ send_rest_command_expect_success_callback() { } rest_go_command_perform_send() { - if ! curl_command=$(go run ./tests/rest_scripts/generateCommand.go -awsAccessKeyId "$AWS_ACCESS_KEY_ID" -awsSecretAccessKey "$AWS_SECRET_ACCESS_KEY" -awsRegion "$AWS_REGION" -url "$AWS_ENDPOINT_URL" "$@" 2>&1); then + if ! xml_file=$(get_file_name 2>&1); then + log 2 "error getting XML file name: $xml_file" + return 1 + fi + if ! curl_command=$(go run ./tests/rest_scripts/generateCommand.go -awsAccessKeyId "$AWS_ACCESS_KEY_ID" -awsSecretAccessKey "$AWS_SECRET_ACCESS_KEY" -awsRegion "$AWS_REGION" -url "$AWS_ENDPOINT_URL" "-writeXMLPayloadToFile" "$xml_file" "$@" 2>&1); then log 2 "error: $curl_command" return 1 fi - local full_command="send_command $curl_command" - log 5 "full command: $full_command" - if ! result=$(eval "${full_command[*]}" 2>&1); then + curl_command=$(echo -n "$curl_command" | tr -d '\n') + mapfile -t curl_command_array < <( + printf '%s' "$curl_command" | python3 -c 'import shlex, sys; [print(arg) for arg in shlex.split(sys.stdin.read())]' + ) + if ! result=$(send_command "${curl_command_array[@]}" 2>&1); then log 2 "error sending command: $result" return 1 fi - log 5 "result: $result" echo "$result" } @@ -372,14 +381,18 @@ check_for_header_key_and_value() { } check_argument_name_and_value() { - if ! check_param_count_v2 "data file" 1 $#; then + if ! check_param_count_v2 "data file, argument name, argument value" 3 $#; then return 1 fi - if ! check_error_parameter "$1" "ArgumentName" "$argument_name"; then + if ! xml_data=$(print_xml_data_to_file "$1" 2>&1); then + log 2 "error getting XML data: $xml_data" + return 1 + fi + if ! check_error_parameter "$xml_data" "ArgumentName" "$2"; then log 2 "error checking 'ArgumentName' parameter" return 1 fi - if ! check_error_parameter "$1" "ArgumentValue" "$argument_value"; then + if ! check_error_parameter "$xml_data" "ArgumentValue" "$3"; then log 2 "error checking 'ArgumentValue' parameter" return 1 fi @@ -390,9 +403,7 @@ send_rest_go_command_expect_error_with_arg_name_value() { if ! check_param_count_gt "response code, error code, message, arg name, arg value, params" 5 $#; then return 1 fi - argument_name=$4 - argument_value=$5 - if ! send_rest_go_command_expect_error_callback "$1" "$2" "$3" "check_argument_name_and_value" "${@:6}"; then + if ! send_rest_go_command_expect_error_callback "$1" "$2" "$3" "check_argument_name_and_value" "${@:6}" "--" "$4" "$5"; then log 2 "error checking error response values" return 1 fi diff --git a/tests/rest_scripts/command/Renderer.go b/tests/rest_scripts/command/Renderer.go new file mode 100644 index 00000000..d6c4170d --- /dev/null +++ b/tests/rest_scripts/command/Renderer.go @@ -0,0 +1,11 @@ +package command + +type Renderer interface { + CalculateDateTimeParams() + DeriveHost() error + DeriveBucketAndKeyPath() + PerformPayloadCalculations() error + DeriveHeaderValues() error + CalculateSignature() error + Render() error +} diff --git a/tests/rest_scripts/command/createBucketCommand.go b/tests/rest_scripts/command/createBucketCommand.go index f6eca8ba..182ef964 100644 --- a/tests/rest_scripts/command/createBucketCommand.go +++ b/tests/rest_scripts/command/createBucketCommand.go @@ -13,16 +13,16 @@ type CreateBucketCommandXML struct { } type CreateBucketCommand struct { - *S3Command + *S3RequestBuilder Config *CreateBucketCommandXML } -func NewCreateBucketCommand(s3Command *S3Command, locationConstraint string, constraintSet bool) (*CreateBucketCommand, error) { - if s3Command.BucketName == "" { +func NewCreateBucketCommand(s3Command *S3RequestBuilder, locationConstraint string, constraintSet bool) (*CreateBucketCommand, error) { + if s3Command.Config.BucketName == "" { return nil, errors.New("CreateBucket must have bucket name") } - s3Command.Method = "PUT" - s3Command.Query = "" + s3Command.Config.Method = "PUT" + s3Command.Config.Query = "" var config *CreateBucketCommandXML = nil if constraintSet { config = &CreateBucketCommandXML{ @@ -31,15 +31,15 @@ func NewCreateBucketCommand(s3Command *S3Command, locationConstraint string, con } } command := &CreateBucketCommand{ - S3Command: s3Command, - Config: config, + S3RequestBuilder: s3Command, + Config: config, } if constraintSet { xmlData, err := xml.Marshal(command.Config) if err != nil { return nil, fmt.Errorf("error marshalling XML: %w", err) } - command.Payload = "\n" + string(xmlData) + s3Command.Config.Payload = "\n" + string(xmlData) } return command, nil } diff --git a/tests/rest_scripts/command/curlRequest.go b/tests/rest_scripts/command/curlRequest.go new file mode 100644 index 00000000..99b88371 --- /dev/null +++ b/tests/rest_scripts/command/curlRequest.go @@ -0,0 +1,90 @@ +package command + +import ( + "fmt" + "github.com/versity/versitygw/tests/rest_scripts/logger" + "net/url" + "os" + "path/filepath" + "strings" +) + +type CurlCommand struct { + *S3Request + + curlCommandString string +} + +func (c *CurlCommand) PerformPayloadCalculations() error { + return c.performBasePayloadCalculations() +} + +func (c *CurlCommand) DeriveHeaderValues() error { + c.deriveUniversalHeaderValues() + if err := c.deriveConfigSpecificHeaderValues(); err != nil { + return fmt.Errorf("error deriving config-specific header values: %w", err) + } + return nil +} + +func (c *CurlCommand) Render() error { + curlOpts := "-iks" + if c.Config.Method == "HEAD" { + curlOpts += "I" + } + curlCommand := []string{"curl", curlOpts} + if c.Config.Method != "GET" { + curlCommand = append(curlCommand, fmt.Sprintf("-X %s ", c.Config.Method)) + } + fullPath := c.Config.Url + c.path + awsUrl, err := url.Parse(fullPath) + if err != nil { + return fmt.Errorf("error parsing URL: %w", err) + } + if c.Config.Query != "" { + canonicalQuery, err := canonicalizeQuery(c.Config.Query) + if err != nil { + return fmt.Errorf("error parsing query: %w", err) + } + awsUrl.RawQuery = canonicalQuery + } + + enclosedPath := fmt.Sprintf("\"%s\"", awsUrl.String()) + curlCommand = append(curlCommand, enclosedPath) + authorizationString := c.buildAuthorizationString() + curlCommand = append(curlCommand, "-H", fmt.Sprintf("\"%s\"", authorizationString)) + for _, headerValue := range c.headerValues { + headerString := fmt.Sprintf("\"%s: %s\"", headerValue.Key, headerValue.Value) + curlCommand = append(curlCommand, "-H", headerString) + } + if c.Config.PayloadFile != "" { + curlCommand = append(curlCommand, "-T", fmt.Sprintf("\"%s\"", c.Config.PayloadFile)) + } else if c.Config.Payload != "" { + var err error + curlCommand, err = c.appendCurlPayload(curlCommand) + if err != nil { + return err + } + } + c.curlCommandString = strings.Join(curlCommand, " ") + logger.PrintDebug("curl command: %s", c.curlCommandString) + return nil +} + +func (c *CurlCommand) String() string { + return c.curlCommandString +} + +func (c *CurlCommand) appendCurlPayload(curlCommand []string) ([]string, error) { + if c.Config.WriteXMLPayloadToFile == "" { + return nil, fmt.Errorf("curl XML payloads must be written to file with 'writeXMLPayloadToFile' param") + } + if err := os.MkdirAll(filepath.Dir(c.Config.WriteXMLPayloadToFile), 0o755); err != nil { + return nil, fmt.Errorf("error creating payload folder: %w", err) + } + if err := os.WriteFile(c.Config.WriteXMLPayloadToFile, []byte(c.Config.Payload), 0o644); err != nil { + return nil, fmt.Errorf("error writing payload to file '%s': %w", c.Config.WriteXMLPayloadToFile, err) + } + curlCommand = append(curlCommand, "-H", "\"Content-Type: application/xml\"", "--data-binary", fmt.Sprintf("\"@%s\"", c.Config.WriteXMLPayloadToFile)) + return curlCommand, nil +} diff --git a/tests/rest_scripts/command/openSSLRequest.go b/tests/rest_scripts/command/openSSLRequest.go new file mode 100644 index 00000000..acf5b539 --- /dev/null +++ b/tests/rest_scripts/command/openSSLRequest.go @@ -0,0 +1,110 @@ +package command + +import ( + "fmt" + "github.com/versity/versitygw/tests/rest_scripts/logger" + "os" + "strings" +) + +type OpenSSLCommand struct { + *S3Request + + payloadManager OpenSSLPayloadManager + contentLength int64 +} + +func (o *OpenSSLCommand) PerformPayloadCalculations() error { + if err := o.performBasePayloadCalculations(); err != nil { + return fmt.Errorf("error performing base payload calculations: %w", err) + } + if err := o.initializePayloadAndGetContentLength(); err != nil { + return fmt.Errorf("error initializing openssl-specific payload: %w", err) + } + return nil +} + +func (o *OpenSSLCommand) DeriveHeaderValues() error { + o.deriveUniversalHeaderValues() + if !o.Config.OmitContentLength { + o.headerValues = append(o.headerValues, + &HeaderValue{"Content-Length", fmt.Sprintf("%d", o.contentLength), true}) + } + if err := o.deriveConfigSpecificHeaderValues(); err != nil { + return fmt.Errorf("error deriving config-specific header values: %w", err) + } + return nil +} + +func (o *OpenSSLCommand) Render() error { + if o.Config.Query != "" { + o.path += "?" + o.Config.Query + } + openSSLCommand := []string{fmt.Sprintf("%s %s HTTP/1.1", o.Config.Method, o.path)} + openSSLCommand = append(openSSLCommand, o.buildAuthorizationString()) + for _, headerValue := range o.headerValues { + if headerValue.Key == "host" && o.Config.MissingHostParam { + continue + } + openSSLCommand = append(openSSLCommand, fmt.Sprintf("%s:%s", headerValue.Key, headerValue.Value)) + } + + file, err := os.Create(o.Config.FilePath) + if err != nil { + return fmt.Errorf("error opening file: %w", err) + } + defer func() { + file.Close() + }() + openSSLCommandBytes := []byte(strings.Join(openSSLCommand, "\r\n")) + if _, err = file.Write(openSSLCommandBytes); err != nil { + return fmt.Errorf("error writing to file: %w", err) + } + if _, err := file.Write([]byte{'\r', '\n', '\r', '\n'}); err != nil { + return fmt.Errorf("error writing to file: %w", err) + } + if o.Config.PayloadFile != "" || o.Config.Payload != "" { + if err = o.writePayload(file); err != nil { + return fmt.Errorf("error writing openssl payload: %w", err) + } + } + return nil +} + +func (o *OpenSSLCommand) writePayload(file *os.File) error { + if awsPayload, ok := o.payloadManager.(*PayloadStreamingAWS4HMACSHA256); ok { + awsPayload.AddInitialSignatureAndSigningKey(o.signature, o.signingKey) + } + switch o.Config.PayloadType { + case UnsignedPayload, "", StreamingUnsignedPayloadTrailer, StreamingAWS4HMACSHA256Payload, StreamingAWS4HMACSHA256PayloadTrailer: + if err := o.payloadManager.WritePayload(o.Config.FilePath); err != nil { + return fmt.Errorf("error writing payload to openssl file: %w", err) + } + default: + return fmt.Errorf("unsupported payload type: %s", o.Config.PayloadType) + } + return nil +} + +func (o *OpenSSLCommand) initializePayloadAndGetContentLength() error { + switch o.Config.PayloadType { + case StreamingAWS4HMACSHA256Payload, StreamingAWS4HMACSHA256PayloadTrailer: + serviceString := fmt.Sprintf("%s/%s/%s/aws4_request", o.yearMonthDay, o.Config.AwsRegion, o.Config.ServiceName) + o.payloadManager = NewPayloadStreamingAWS4HMACSHA256(o.dataSource, int64(o.Config.ChunkSize), PayloadType(o.Config.PayloadType), serviceString, o.currentDateTime, o.yearMonthDay, o.Config.ChecksumType) + case StreamingUnsignedPayloadTrailer: + streamingUnsignedPayloadTrailerImpl := NewStreamingUnsignedPayloadWithTrailer(o.dataSource, int64(o.Config.ChunkSize), o.Config.ChecksumType) + streamingUnsignedPayloadTrailerImpl.OmitTrailerOrKey(o.Config.OmitPayloadTrailer, o.Config.OmitPayloadTrailerKey) + o.payloadManager = streamingUnsignedPayloadTrailerImpl + case UnsignedPayload, "": + o.payloadManager = NewWholePayload(o.dataSource) + default: + return fmt.Errorf("unsupported OpenSSL payload type: '%s'", o.Config.PayloadType) + } + var err error + o.contentLength, err = o.payloadManager.GetContentLength() + if err != nil { + return fmt.Errorf("error calculating Content-Length: %w", err) + } + logger.PrintDebug("Predicted payload size: %d", o.contentLength) + return nil +} diff --git a/tests/rest_scripts/command/putBucketCorsCommand.go b/tests/rest_scripts/command/putBucketCorsCommand.go index b773cdbd..ea178081 100644 --- a/tests/rest_scripts/command/putBucketCorsCommand.go +++ b/tests/rest_scripts/command/putBucketCorsCommand.go @@ -32,9 +32,9 @@ type CORSConfiguration struct { CORSRules []*CORSRule } -func NewPutBucketCORSCommand(command *S3Command, ruleStrings []string) (*S3Command, error) { - command.Method = "PUT" - command.Query = "cors" +func NewPutBucketCORSCommand(command *S3RequestBuilder, ruleStrings []string) (*S3RequestBuilder, error) { + command.Config.Method = "PUT" + command.Config.Query = "cors" corsConfiguration := &CORSConfiguration{ XMLNamespace: "https://s3.amazonaws.com/doc/2006-03-01/", } @@ -50,7 +50,7 @@ func NewPutBucketCORSCommand(command *S3Command, ruleStrings []string) (*S3Comma if err != nil { return nil, fmt.Errorf("error marshalling XML: %w", err) } - command.Payload = "\n" + string(xmlData) + command.Config.Payload = "\n" + string(xmlData) return command, nil } diff --git a/tests/rest_scripts/command/putBucketTaggingCommand.go b/tests/rest_scripts/command/putBucketTaggingCommand.go index 663fdf3e..034be85c 100644 --- a/tests/rest_scripts/command/putBucketTaggingCommand.go +++ b/tests/rest_scripts/command/putBucketTaggingCommand.go @@ -9,13 +9,13 @@ type PutBucketTaggingCommand struct { *PutTaggingCommand } -func NewPutBucketTaggingCommand(s3Command *S3Command, fields *TaggingFields) (*PutBucketTaggingCommand, error) { - if s3Command.BucketName == "" { +func NewPutBucketTaggingCommand(s3Command *S3RequestBuilder, fields *TaggingFields) (*PutBucketTaggingCommand, error) { + if s3Command.Config.BucketName == "" { return nil, errors.New("PutBucketTagging must have bucket name") } command := &PutBucketTaggingCommand{ &PutTaggingCommand{ - S3Command: s3Command, + S3RequestBuilder: s3Command, }, } if err := command.createTaggingPayload(fields); err != nil { diff --git a/tests/rest_scripts/command/putObjectCommand.go b/tests/rest_scripts/command/putObjectCommand.go index d0cca003..8e5c88a0 100644 --- a/tests/rest_scripts/command/putObjectCommand.go +++ b/tests/rest_scripts/command/putObjectCommand.go @@ -4,14 +4,14 @@ import ( "errors" ) -func NewPutObjectCommand(s3Command *S3Command) (*S3Command, error) { - if s3Command.BucketName == "" { +func NewPutObjectCommand(s3Command *S3RequestBuilder) (*S3RequestBuilder, error) { + if s3Command.Config.BucketName == "" { return nil, errors.New("PutObject must have bucket name") } - if s3Command.ObjectKey == "" { + if s3Command.Config.ObjectKey == "" { return nil, errors.New("PutObject must have object key") } - s3Command.Method = "PUT" - s3Command.Query = "" + s3Command.Config.Method = "PUT" + s3Command.Config.Query = "" return s3Command, nil } diff --git a/tests/rest_scripts/command/putObjectTaggingCommand.go b/tests/rest_scripts/command/putObjectTaggingCommand.go index c986cce9..5fee6432 100644 --- a/tests/rest_scripts/command/putObjectTaggingCommand.go +++ b/tests/rest_scripts/command/putObjectTaggingCommand.go @@ -9,16 +9,16 @@ type PutObjectTaggingCommand struct { *PutTaggingCommand } -func NewPutObjectTaggingCommand(s3Command *S3Command, fields *TaggingFields) (*PutObjectTaggingCommand, error) { - if s3Command.BucketName == "" { +func NewPutObjectTaggingCommand(s3Command *S3RequestBuilder, fields *TaggingFields) (*PutObjectTaggingCommand, error) { + if s3Command.Config.BucketName == "" { return nil, errors.New("PutObjectTagging must have bucket name") } - if s3Command.ObjectKey == "" { + if s3Command.Config.ObjectKey == "" { return nil, errors.New("PutObjectTagging must have object key") } command := &PutObjectTaggingCommand{ &PutTaggingCommand{ - S3Command: s3Command, + S3RequestBuilder: s3Command, }, } if err := command.createTaggingPayload(fields); err != nil { diff --git a/tests/rest_scripts/command/putTaggingCommand.go b/tests/rest_scripts/command/putTaggingCommand.go index 4045b6b3..4a1652a6 100644 --- a/tests/rest_scripts/command/putTaggingCommand.go +++ b/tests/rest_scripts/command/putTaggingCommand.go @@ -7,17 +7,17 @@ import ( ) type PutTaggingCommand struct { - *S3Command + *S3RequestBuilder TagCount *int Tags *Tagging } func (p *PutTaggingCommand) createTaggingPayload(fields *TaggingFields) error { - p.Method = "PUT" - if p.Query != "" { - p.Query = "tagging=&" + p.Query + p.Config.Method = "PUT" + if p.Config.Query != "" { + p.Config.Query = "tagging=&" + p.Config.Query } else { - p.Query = "tagging=" + p.Config.Query = "tagging=" } if len(fields.TagKeys) != len(fields.TagValues) { return errors.New("must be same number of tag keys and tag values") @@ -40,6 +40,6 @@ func (p *PutTaggingCommand) createTaggingPayload(fields *TaggingFields) error { if err != nil { return fmt.Errorf("error marshalling XML: %w", err) } - p.Payload = "\n" + string(xmlData) + p.Config.Payload = "\n" + string(xmlData) return nil } diff --git a/tests/rest_scripts/command/s3Command.go b/tests/rest_scripts/command/s3Command.go deleted file mode 100644 index 00b21821..00000000 --- a/tests/rest_scripts/command/s3Command.go +++ /dev/null @@ -1,478 +0,0 @@ -package command - -import ( - "crypto/hmac" - "crypto/md5" - "crypto/sha256" - "encoding/base64" - "encoding/hex" - "fmt" - "net/url" - "os" - "sort" - "strings" - "time" - - logger "github.com/versity/versitygw/tests/rest_scripts/logger" -) - -const ( - CURL = "curl" - OPENSSL = "openssl" -) - -const ( - UnsignedPayload = "UNSIGNED-PAYLOAD" - StreamingAWS4HMACSHA256Payload = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD" - StreamingAWS4HMACSHA256PayloadTrailer = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER" - StreamingUnsignedPayloadTrailer = "STREAMING-UNSIGNED-PAYLOAD-TRAILER" - StreamingAWS4ECDSAP256SHA256Payload = "STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD" - StreamingAWS4ECDSAP256SHA256PayloadTrailer = "STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD-TRAILER" -) - -type PayloadType string - -const ( - ChecksumCRC32 = "crc32" - ChecksumCRC32C = "crc32c" - ChecksumCRC64NVME = "crc64nvme" - ChecksumSHA1 = "sha1" - ChecksumSHA256 = "sha256" -) - -const SHA256HashZeroBytes = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - -type HeaderValue struct { - Key string - Value string - Signed bool -} - -type S3Command struct { - Client string - Method string - Url string - BucketName string - ObjectKey string - Query string - AwsRegion string - AwsAccessKeyId string - AwsSecretAccessKey string - ServiceName string - SignedParams map[string]string - UnsignedParams map[string]string - PayloadFile string - IncorrectSignature bool - AuthorizationHeaderMalformed bool - AuthorizationScheme string - IncorrectCredential string - IncorrectYearMonthDay bool - InvalidYearMonthDay bool - Payload string - ContentMD5 bool - IncorrectContentMD5 bool - CustomContentMD5 string - MissingHostParam bool - FilePath string - CustomHostParam string - CustomHostParamSet bool - PayloadType string - ChunkSize int - ChecksumType string - OmitPayloadTrailer bool - OmitPayloadTrailerKey bool - OmitContentLength bool - OmitSHA256Hash bool - CustomSHA256Hash string - OmitDate bool - CustomDate string - - dataSource DataSource - currentDateTime string - host string - payloadHash string - headerValues []*HeaderValue - canonicalRequestHash string - path string - signedParamString string - yearMonthDay string - signature string - signingKey []byte - contentLength int64 - payloadOpenSSL OpenSSLPayloadManager -} - -func (s *S3Command) OpenSSLCommand() error { - if err := s.prepareForBuild(); err != nil { - return fmt.Errorf("error preparing for command building: %w", err) - } - if err := s.buildOpenSSLCommand(); err != nil { - return fmt.Errorf("error building openSSL command: %w", err) - } - return nil -} - -func (s *S3Command) CurlShellCommand() (string, error) { - if err := s.prepareForBuild(); err != nil { - return "", fmt.Errorf("error preparing for command building: %w", err) - } - return s.buildCurlShellCommand() -} - -func (s *S3Command) prepareForBuild() error { - now := time.Now().UTC() - if s.CustomDate != "" { - s.currentDateTime = s.CustomDate - } else if s.IncorrectYearMonthDay { - s.currentDateTime = now.Add(-48 * time.Hour).Format("20060102T150405Z") - } else { - s.currentDateTime = now.Format("20060102T150405Z") - } - protocolAndHost := strings.Split(s.Url, "://") - if len(protocolAndHost) != 2 { - return fmt.Errorf("invalid URL value: %s", s.Url) - } - s.host = protocolAndHost[1] - s.yearMonthDay = strings.Split(s.currentDateTime, "T")[0] - if s.InvalidYearMonthDay { - s.yearMonthDay = s.yearMonthDay[:len(s.yearMonthDay)-2] - } - s.path = "/" + s.BucketName - if s.ObjectKey != "" { - s.path += "/" + s.ObjectKey - } - if err := s.preparePayload(); err != nil { - return fmt.Errorf("error preparing payload: %w", err) - } - if err := s.addHeaderValues(); err != nil { - return fmt.Errorf("error adding header values: %w", err) - } - s.generateCanonicalRequestString() - s.getStsSignature() - return nil -} - -func (s *S3Command) preparePayload() error { - if s.PayloadFile != "" { - s.dataSource = NewFileDataSource(s.PayloadFile) - } else if s.Payload != "" { - s.dataSource = NewStringDataSource(s.Payload) - } - if s.CustomSHA256Hash != "" { - s.payloadHash = s.CustomSHA256Hash - } else if s.PayloadType != "" { - s.payloadHash = s.PayloadType - } else if s.dataSource != nil { - var err error - s.payloadHash, err = s.dataSource.CalculateSHA256HashString() - if err != nil { - return fmt.Errorf("error calculating sha256 hash") - } - } else { - s.payloadHash = SHA256HashZeroBytes - } - if s.Client == OPENSSL { - if err := s.initializeOpenSSLPayloadAndGetContentLength(); err != nil { - return fmt.Errorf("error initializing openssl payload: %w", err) - } - } - return nil -} - -func (s *S3Command) initializeOpenSSLPayloadAndGetContentLength() error { - switch s.PayloadType { - case StreamingAWS4HMACSHA256Payload, StreamingAWS4HMACSHA256PayloadTrailer: - serviceString := fmt.Sprintf("%s/%s/%s/aws4_request", s.yearMonthDay, s.AwsRegion, s.ServiceName) - s.payloadOpenSSL = NewPayloadStreamingAWS4HMACSHA256(s.dataSource, int64(s.ChunkSize), PayloadType(s.PayloadType), serviceString, s.currentDateTime, s.yearMonthDay, s.ChecksumType) - case StreamingUnsignedPayloadTrailer: - streamingUnsignedPayloadTrailerImpl := NewStreamingUnsignedPayloadWithTrailer(s.dataSource, int64(s.ChunkSize), s.ChecksumType) - streamingUnsignedPayloadTrailerImpl.OmitTrailerOrKey(s.OmitPayloadTrailer, s.OmitPayloadTrailerKey) - s.payloadOpenSSL = streamingUnsignedPayloadTrailerImpl - case UnsignedPayload, "": - s.payloadOpenSSL = NewWholePayload(s.dataSource) - default: - return fmt.Errorf("unsupported OpenSSL payload type: '%s'", s.PayloadType) - } - var err error - s.contentLength, err = s.payloadOpenSSL.GetContentLength() - if err != nil { - return fmt.Errorf("error calculating Content-Length: %w", err) - } - logger.PrintDebug("Predicted payload size: %d", s.contentLength) - return nil -} - -func (s *S3Command) addBaseHeaderValues() { - if s.MissingHostParam { - s.headerValues = append(s.headerValues, &HeaderValue{"host", "", true}) - } else if s.CustomHostParamSet { - s.headerValues = append(s.headerValues, &HeaderValue{"host", s.CustomHostParam, true}) - } else { - s.headerValues = append(s.headerValues, &HeaderValue{"host", s.host, true}) - } - if !s.OmitSHA256Hash { - s.headerValues = append(s.headerValues, &HeaderValue{"x-amz-content-sha256", s.payloadHash, true}) - } - if !s.OmitDate { - s.headerValues = append(s.headerValues, &HeaderValue{"x-amz-date", s.currentDateTime, true}) - } - if s.Client == OPENSSL && !s.OmitContentLength { - s.headerValues = append(s.headerValues, - &HeaderValue{"Content-Length", fmt.Sprintf("%d", s.contentLength), true}) - } -} - -func (s *S3Command) addHeaderValues() error { - s.headerValues = []*HeaderValue{} - - s.addBaseHeaderValues() - - if s.PayloadType == StreamingAWS4HMACSHA256PayloadTrailer && s.ChecksumType != "" { - s.headerValues = append(s.headerValues, &HeaderValue{"x-amz-trailer", fmt.Sprintf("x-amz-checksum-%s", s.ChecksumType), true}) - } - if s.dataSource != nil && s.PayloadType != UnsignedPayload { - payloadSize, err := s.dataSource.SourceDataByteSize() - if err != nil { - return fmt.Errorf("error getting payload size: %w", err) - } - s.headerValues = append(s.headerValues, - &HeaderValue{"x-amz-decoded-content-length", fmt.Sprintf("%d", payloadSize), true}) - } - for key, value := range s.SignedParams { - s.headerValues = append(s.headerValues, &HeaderValue{key, value, true}) - } - if s.ContentMD5 || s.IncorrectContentMD5 || s.CustomContentMD5 != "" { - if err := s.addContentMD5Header(); err != nil { - return fmt.Errorf("error adding Content-MD5 header: %w", err) - } - } - for key, value := range s.UnsignedParams { - s.headerValues = append(s.headerValues, &HeaderValue{key, value, false}) - } - sort.Slice(s.headerValues, - func(i, j int) bool { - return strings.ToLower(s.headerValues[i].Key) < strings.ToLower(s.headerValues[j].Key) - }) - return nil -} - -func (s *S3Command) modifyHash(md5Hash []byte) { - if md5Hash[0] == 'a' { - md5Hash[0] = 'A' - } else { - md5Hash[0] = 'a' - } -} - -func (s *S3Command) addContentMD5Header() error { - var payloadData []byte - var err error - if s.PayloadFile != "" { - if payloadData, err = os.ReadFile(s.PayloadFile); err != nil { - return fmt.Errorf("error reading file %s: %w", s.PayloadFile, err) - } - } else { - logger.PrintDebug("Payload: %s", s.Payload) - payloadData = []byte(strings.Replace(s.Payload, "\\", "", -1)) - } - - var contentMD5 string - if s.CustomContentMD5 != "" { - contentMD5 = s.CustomContentMD5 - } else { - hasher := md5.New() - hasher.Write(payloadData) - md5Hash := hasher.Sum(nil) - if s.IncorrectContentMD5 { - s.modifyHash(md5Hash) - } - contentMD5 = base64.StdEncoding.EncodeToString(md5Hash) - } - - s.headerValues = append(s.headerValues, &HeaderValue{"Content-MD5", contentMD5, true}) - return nil -} - -func (s *S3Command) generateCanonicalRequestString() { - canonicalRequestLines := []string{s.Method} - - canonicalRequestLines = append(canonicalRequestLines, s.path) - var queryRequestLine string - if strings.Contains(s.Query, "&") { - queries := strings.Split(s.Query, "&") - if !strings.HasSuffix(queries[0], "=") && !strings.Contains(queries[0], "=") { - queries[0] += "=" - queryRequestLine = strings.Join(queries, "&") - } - } else if s.Query != "" && !strings.HasSuffix(s.Query, "=") && !strings.Contains(s.Query, "=") { - queryRequestLine = s.Query + "=" - } - if queryRequestLine == "" { - queryRequestLine = s.Query - } - canonicalQuery, err := canonicalizeQuery(queryRequestLine) - if err != nil { - logger.PrintDebug("error parsing query '%s': %v", queryRequestLine, err) - canonicalQuery = queryRequestLine - } - canonicalRequestLines = append(canonicalRequestLines, canonicalQuery) - - var signedParams []string - for _, headerValue := range s.headerValues { - if headerValue.Signed { - key := strings.ToLower(headerValue.Key) - canonicalRequestLines = append(canonicalRequestLines, key+":"+headerValue.Value) - signedParams = append(signedParams, key) - } - } - - canonicalRequestLines = append(canonicalRequestLines, "") - s.signedParamString = strings.Join(signedParams, ";") - canonicalRequestLines = append(canonicalRequestLines, s.signedParamString, s.payloadHash) - - canonicalRequestString := strings.Join(canonicalRequestLines, "\n") - logger.PrintDebug("Canonical request string: %s\n", canonicalRequestString) - - canonicalRequestHashBytes := sha256.Sum256([]byte(canonicalRequestString)) - s.canonicalRequestHash = hex.EncodeToString(canonicalRequestHashBytes[:]) -} - -func (s *S3Command) getStsSignature() { - thirdLine := fmt.Sprintf("%s/%s/%s/aws4_request", s.yearMonthDay, s.AwsRegion, s.ServiceName) - stsDataLines := []string{ - s.AuthorizationScheme, - s.currentDateTime, - thirdLine, - s.canonicalRequestHash, - } - stsDataString := strings.Join(stsDataLines, "\n") - - // Derive signing key step by step - dateKey := hmacSHA256([]byte("AWS4"+s.AwsSecretAccessKey), s.yearMonthDay) - dateRegionKey := hmacSHA256(dateKey, s.AwsRegion) - dateRegionServiceKey := hmacSHA256(dateRegionKey, s.ServiceName) - s.signingKey = hmacSHA256(dateRegionServiceKey, "aws4_request") - - // Generate signature - signatureBytes := hmacSHA256(s.signingKey, stsDataString) - if s.IncorrectSignature { - if signatureBytes[0] == 'a' { - signatureBytes[0] = 'A' - } else { - signatureBytes[0] = 'a' - } - } - - // Print hex-encoded signature - s.signature = hex.EncodeToString(signatureBytes) -} - -func (s *S3Command) buildCurlShellCommand() (string, error) { - if s.MissingHostParam { - return "", fmt.Errorf("missingHostParam option only available for OpenSSL commands") - } - curlOpts := "-iks" - if s.Method == "HEAD" { - curlOpts += "I" - } - curlCommand := []string{"curl", curlOpts} - if s.Method != "GET" { - curlCommand = append(curlCommand, fmt.Sprintf("-X %s ", s.Method)) - } - fullPath := s.Url + s.path - awsUrl, err := url.Parse(fullPath) - if err != nil { - return "", fmt.Errorf("error parsing URL: %w", err) - } - if s.Query != "" { - canonicalQuery, err := canonicalizeQuery(s.Query) - if err != nil { - return "", fmt.Errorf("error parsing query: %w", err) - } - awsUrl.RawQuery = canonicalQuery - } - enclosedPath := fmt.Sprintf("\"%s\"", awsUrl.String()) - curlCommand = append(curlCommand, enclosedPath) - authorizationString := s.buildAuthorizationString() - curlCommand = append(curlCommand, "-H", fmt.Sprintf("\"%s\"", authorizationString)) - for _, headerValue := range s.headerValues { - headerString := fmt.Sprintf("\"%s: %s\"", headerValue.Key, headerValue.Value) - curlCommand = append(curlCommand, "-H", headerString) - } - if s.PayloadFile != "" { - curlCommand = append(curlCommand, "-T", s.PayloadFile) - } else if s.Payload != "" { - s.Payload = strings.Replace(s.Payload, "\"", "\\\"", -1) - curlCommand = append(curlCommand, "-H", "\"Content-Type: application/xml\"", "-d", fmt.Sprintf("\"%s\"", s.Payload)) - } - curlStringCommand := strings.Join(curlCommand, " ") - logger.PrintDebug("curl command: %s", curlStringCommand) - return curlStringCommand, nil -} - -func (s *S3Command) buildAuthorizationString() string { - var credentialString string - if s.IncorrectCredential == "" { - credentialString = fmt.Sprintf("%s/%s/%s/%s/aws4_request", s.AwsAccessKeyId, s.yearMonthDay, s.AwsRegion, s.ServiceName) - } else { - credentialString = s.IncorrectCredential - } - return fmt.Sprintf("Authorization: %s Credential=%s,SignedHeaders=%s,Signature=%s", - s.AuthorizationScheme, credentialString, s.signedParamString, s.signature) -} - -func (s *S3Command) buildOpenSSLCommand() error { - if s.Query != "" { - s.path += "?" + s.Query - } - openSSLCommand := []string{fmt.Sprintf("%s %s HTTP/1.1", s.Method, s.path)} - openSSLCommand = append(openSSLCommand, s.buildAuthorizationString()) - for _, headerValue := range s.headerValues { - if headerValue.Key == "host" && s.MissingHostParam { - continue - } - openSSLCommand = append(openSSLCommand, fmt.Sprintf("%s:%s", headerValue.Key, headerValue.Value)) - } - - file, err := os.Create(s.FilePath) - if err != nil { - return fmt.Errorf("error opening file: %w", err) - } - defer func() { - file.Close() - }() - openSSLCommandBytes := []byte(strings.Join(openSSLCommand, "\r\n")) - if _, err = file.Write(openSSLCommandBytes); err != nil { - return fmt.Errorf("error writing to file: %w", err) - } - if _, err := file.Write([]byte{'\r', '\n', '\r', '\n'}); err != nil { - return fmt.Errorf("error writing to file: %w", err) - } - if s.PayloadFile != "" || s.Payload != "" { - if err = s.writeOpenSSLPayload(file); err != nil { - return fmt.Errorf("error writing openssl payload: %w", err) - } - } - return nil -} - -func (s *S3Command) writeOpenSSLPayload(file *os.File) error { - if awsPayload, ok := s.payloadOpenSSL.(*PayloadStreamingAWS4HMACSHA256); ok { - awsPayload.AddInitialSignatureAndSigningKey(s.signature, s.signingKey) - } - switch s.PayloadType { - case UnsignedPayload, "", StreamingUnsignedPayloadTrailer, StreamingAWS4HMACSHA256Payload, StreamingAWS4HMACSHA256PayloadTrailer: - if err := s.payloadOpenSSL.WritePayload(s.FilePath); err != nil { - return fmt.Errorf("error writing payload to openssl file: %w", err) - } - default: - return fmt.Errorf("unsupported payload type: %s", s.PayloadType) - } - return nil -} - -func hmacSHA256(key []byte, data string) []byte { - h := hmac.New(sha256.New, key) - h.Write([]byte(data)) - return h.Sum(nil) -} diff --git a/tests/rest_scripts/command/s3Request.go b/tests/rest_scripts/command/s3Request.go new file mode 100644 index 00000000..a21a34c2 --- /dev/null +++ b/tests/rest_scripts/command/s3Request.go @@ -0,0 +1,259 @@ +package command + +import ( + "crypto/md5" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "github.com/versity/versitygw/tests/rest_scripts/logger" + "os" + "sort" + "strings" + "time" +) + +type S3Request struct { + Config *S3RequestConfigData + + currentDateTime string + yearMonthDay string + path string + signedParamString string + signature string + host string + headerValues []*HeaderValue + payloadHash string + dataSource DataSource + canonicalRequestHash string + signingKey []byte +} + +func (s *S3Request) CalculateDateTimeParams() { + now := time.Now().UTC() + if s.Config.CustomDate != "" { + s.currentDateTime = s.Config.CustomDate + } else if s.Config.IncorrectYearMonthDay { + s.currentDateTime = now.Add(-48 * time.Hour).Format("20060102T150405Z") + } else { + s.currentDateTime = now.Format("20060102T150405Z") + } + s.yearMonthDay = strings.Split(s.currentDateTime, "T")[0] + if s.Config.InvalidYearMonthDay { + s.yearMonthDay = s.yearMonthDay[:len(s.yearMonthDay)-2] + } +} + +func (s *S3Request) DeriveHost() error { + protocolAndHost := strings.Split(s.Config.Url, "://") + if len(protocolAndHost) != 2 { + return fmt.Errorf("invalid URL value: %s", s.Config.Url) + } + s.host = protocolAndHost[1] + return nil +} + +func (s *S3Request) DeriveBucketAndKeyPath() { + s.path = "/" + s.Config.BucketName + if s.Config.ObjectKey != "" { + s.path += "/" + s.Config.ObjectKey + } +} + +func (s *S3Request) CalculateSignature() error { + + if err := s.calculateCanonicalRequestHash(); err != nil { + return fmt.Errorf("error calculating canonical request hash: %w", err) + } + + thirdLine := fmt.Sprintf("%s/%s/%s/aws4_request", s.yearMonthDay, s.Config.AwsRegion, s.Config.ServiceName) + stsDataLines := []string{ + s.Config.AuthorizationScheme, + s.currentDateTime, + thirdLine, + s.canonicalRequestHash, + } + stsDataString := strings.Join(stsDataLines, "\n") + + // Derive signing key step by step + dateKey := hmacSHA256([]byte("AWS4"+s.Config.AwsSecretAccessKey), s.yearMonthDay) + dateRegionKey := hmacSHA256(dateKey, s.Config.AwsRegion) + dateRegionServiceKey := hmacSHA256(dateRegionKey, s.Config.ServiceName) + s.signingKey = hmacSHA256(dateRegionServiceKey, "aws4_request") + + // Generate signature + signatureBytes := hmacSHA256(s.signingKey, stsDataString) + if s.Config.IncorrectSignature { + if signatureBytes[0] == 'a' { + signatureBytes[0] = 'A' + } else { + signatureBytes[0] = 'a' + } + } + + // Print hex-encoded signature + s.signature = hex.EncodeToString(signatureBytes) + return nil +} + +func (s *S3Request) calculateCanonicalRequestHash() error { + canonicalRequestLines := []string{s.Config.Method} + + s.path = encodeS3Key(s.path) + canonicalRequestLines = append(canonicalRequestLines, s.path) + + canonicalQuery, err := s.getCanonicalQuery() + if err != nil { + return fmt.Errorf("error generating canoncial query: %w", err) + } + canonicalRequestLines = append(canonicalRequestLines, canonicalQuery) + + var signedParams []string + for _, headerValue := range s.headerValues { + if headerValue.Signed { + key := strings.ToLower(headerValue.Key) + canonicalRequestLines = append(canonicalRequestLines, key+":"+headerValue.Value) + signedParams = append(signedParams, key) + } + } + + canonicalRequestLines = append(canonicalRequestLines, "") + s.signedParamString = strings.Join(signedParams, ";") + canonicalRequestLines = append(canonicalRequestLines, s.signedParamString, s.payloadHash) + + canonicalRequestString := strings.Join(canonicalRequestLines, "\n") + logger.PrintDebug("Canonical request string: %s\n", canonicalRequestString) + + canonicalRequestHashBytes := sha256.Sum256([]byte(canonicalRequestString)) + s.canonicalRequestHash = hex.EncodeToString(canonicalRequestHashBytes[:]) + return nil +} + +func (s *S3Request) getCanonicalQuery() (string, error) { + var queryRequestLine string + if strings.Contains(s.Config.Query, "&") { + queries := strings.Split(s.Config.Query, "&") + if !strings.HasSuffix(queries[0], "=") && !strings.Contains(queries[0], "=") { + queries[0] += "=" + queryRequestLine = strings.Join(queries, "&") + } + } else if s.Config.Query != "" && !strings.HasSuffix(s.Config.Query, "=") && !strings.Contains(s.Config.Query, "=") { + queryRequestLine = s.Config.Query + "=" + } + if queryRequestLine == "" { + queryRequestLine = s.Config.Query + } + canonicalQuery, err := canonicalizeQuery(queryRequestLine) + if err != nil { + return "", fmt.Errorf("error parsing query '%s': %v", queryRequestLine, err) + } + return canonicalQuery, nil +} + +func (s *S3Request) performBasePayloadCalculations() error { + if s.Config.PayloadFile != "" { + s.dataSource = NewFileDataSource(s.Config.PayloadFile) + } else if s.Config.Payload != "" { + s.dataSource = NewStringDataSource(s.Config.Payload) + } + if s.Config.CustomSHA256Hash != "" { + s.payloadHash = s.Config.CustomSHA256Hash + } else if s.Config.PayloadType != "" { + s.payloadHash = s.Config.PayloadType + } else if s.dataSource != nil { + var err error + s.payloadHash, err = s.dataSource.CalculateSHA256HashString() + if err != nil { + return fmt.Errorf("error calculating sha256 hash: %w", err) + } + } else { + s.payloadHash = SHA256HashZeroBytes + } + return nil +} + +func (s *S3Request) deriveUniversalHeaderValues() { + if s.Config.MissingHostParam { + s.headerValues = append(s.headerValues, &HeaderValue{"host", "", true}) + } else if s.Config.CustomHostParamSet { + s.headerValues = append(s.headerValues, &HeaderValue{"host", s.Config.CustomHostParam, true}) + } else { + s.headerValues = append(s.headerValues, &HeaderValue{"host", s.host, true}) + } + if !s.Config.OmitSHA256Hash { + s.headerValues = append(s.headerValues, &HeaderValue{"x-amz-content-sha256", s.payloadHash, true}) + } + if !s.Config.OmitDate { + s.headerValues = append(s.headerValues, &HeaderValue{"x-amz-date", s.currentDateTime, true}) + } +} + +func (s *S3Request) deriveConfigSpecificHeaderValues() error { + if s.Config.PayloadType == StreamingAWS4HMACSHA256PayloadTrailer && s.Config.ChecksumType != "" { + s.headerValues = append(s.headerValues, &HeaderValue{"x-amz-trailer", fmt.Sprintf("x-amz-checksum-%s", s.Config.ChecksumType), true}) + } + if s.dataSource != nil && s.Config.PayloadType != UnsignedPayload { + payloadSize, err := s.dataSource.SourceDataByteSize() + if err != nil { + return fmt.Errorf("error getting payload size: %w", err) + } + s.headerValues = append(s.headerValues, + &HeaderValue{"x-amz-decoded-content-length", fmt.Sprintf("%d", payloadSize), true}) + } + for key, value := range s.Config.SignedParams { + s.headerValues = append(s.headerValues, &HeaderValue{key, value, true}) + } + if s.Config.ContentMD5 || s.Config.IncorrectContentMD5 || s.Config.CustomContentMD5 != "" { + if err := s.addContentMD5Header(); err != nil { + return fmt.Errorf("error adding Content-MD5 header: %w", err) + } + } + for key, value := range s.Config.UnsignedParams { + s.headerValues = append(s.headerValues, &HeaderValue{key, value, false}) + } + sort.Slice(s.headerValues, + func(i, j int) bool { + return strings.ToLower(s.headerValues[i].Key) < strings.ToLower(s.headerValues[j].Key) + }) + return nil +} + +func (s *S3Request) addContentMD5Header() error { + var payloadData []byte + var err error + if s.Config.PayloadFile != "" { + if payloadData, err = os.ReadFile(s.Config.PayloadFile); err != nil { + return fmt.Errorf("error reading file %s: %w", s.Config.PayloadFile, err) + } + } else { + logger.PrintDebug("Payload: %s", s.Config.Payload) + payloadData = []byte(strings.Replace(s.Config.Payload, "\\", "", -1)) + } + + var contentMD5 string + if s.Config.CustomContentMD5 != "" { + contentMD5 = s.Config.CustomContentMD5 + } else { + hasher := md5.New() + hasher.Write(payloadData) + md5Hash := hasher.Sum(nil) + if s.Config.IncorrectContentMD5 { + modifyHash(md5Hash) + } + contentMD5 = base64.StdEncoding.EncodeToString(md5Hash) + } + + s.headerValues = append(s.headerValues, &HeaderValue{"Content-MD5", contentMD5, true}) + return nil +} + +func (s *S3Request) buildAuthorizationString() string { + var credentialString string + if s.Config.IncorrectCredential == "" { + credentialString = fmt.Sprintf("%s/%s/%s/%s/aws4_request", s.Config.AwsAccessKeyId, s.yearMonthDay, s.Config.AwsRegion, s.Config.ServiceName) + } else { + credentialString = s.Config.IncorrectCredential + } + return fmt.Sprintf("Authorization: %s Credential=%s,SignedHeaders=%s,Signature=%s", + s.Config.AuthorizationScheme, credentialString, s.signedParamString, s.signature) +} diff --git a/tests/rest_scripts/command/s3RequestBuilder.go b/tests/rest_scripts/command/s3RequestBuilder.go new file mode 100644 index 00000000..7954f9a3 --- /dev/null +++ b/tests/rest_scripts/command/s3RequestBuilder.go @@ -0,0 +1,169 @@ +package command + +import ( + "crypto/hmac" + "crypto/sha256" + "fmt" + "strings" +) + +const ( + CURL = "curl" + OPENSSL = "openssl" +) + +const ( + UnsignedPayload = "UNSIGNED-PAYLOAD" + StreamingAWS4HMACSHA256Payload = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD" + StreamingAWS4HMACSHA256PayloadTrailer = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER" + StreamingUnsignedPayloadTrailer = "STREAMING-UNSIGNED-PAYLOAD-TRAILER" + StreamingAWS4ECDSAP256SHA256Payload = "STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD" + StreamingAWS4ECDSAP256SHA256PayloadTrailer = "STREAMING-AWS4-ECDSA-P256-SHA256-PAYLOAD-TRAILER" +) + +type PayloadType string + +const ( + ChecksumCRC32 = "crc32" + ChecksumCRC32C = "crc32c" + ChecksumCRC64NVME = "crc64nvme" + ChecksumSHA1 = "sha1" + ChecksumSHA256 = "sha256" +) + +const SHA256HashZeroBytes = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + +type HeaderValue struct { + Key string + Value string + Signed bool +} + +type S3RequestConfigData struct { + Client string + Method string + Url string + BucketName string + ObjectKey string + Query string + AwsRegion string + AwsAccessKeyId string + AwsSecretAccessKey string + ServiceName string + SignedParams map[string]string + UnsignedParams map[string]string + PayloadFile string + IncorrectSignature bool + AuthorizationHeaderMalformed bool + AuthorizationScheme string + IncorrectCredential string + IncorrectYearMonthDay bool + InvalidYearMonthDay bool + Payload string + ContentMD5 bool + IncorrectContentMD5 bool + CustomContentMD5 string + MissingHostParam bool + FilePath string + CustomHostParam string + CustomHostParamSet bool + PayloadType string + ChunkSize int + ChecksumType string + OmitPayloadTrailer bool + OmitPayloadTrailerKey bool + OmitContentLength bool + OmitSHA256Hash bool + CustomSHA256Hash string + OmitDate bool + CustomDate string + WriteXMLPayloadToFile string +} + +type S3RequestBuilder struct { + Config *S3RequestConfigData +} + +func (s *S3RequestBuilder) OpenSSLCommand() error { + openSSLCommand := &OpenSSLCommand{ + S3Request: &S3Request{ + Config: s.Config, + }, + } + if err := s.RenderCommand(openSSLCommand); err != nil { + return fmt.Errorf("error rendering OpenSSL command: %w", err) + } + return nil +} + +func (s *S3RequestBuilder) CurlShellCommand() (string, error) { + curlCommand := &CurlCommand{ + S3Request: &S3Request{ + Config: s.Config, + }, + } + if err := s.RenderCommand(curlCommand); err != nil { + return "", fmt.Errorf("error rendering curl command: %w", err) + } + return curlCommand.String(), nil +} + +func (s *S3RequestBuilder) RenderCommand(renderer Renderer) error { + renderer.CalculateDateTimeParams() + if err := renderer.DeriveHost(); err != nil { + return fmt.Errorf("error deriving host: %w", err) + } + renderer.DeriveBucketAndKeyPath() + if err := renderer.PerformPayloadCalculations(); err != nil { + return fmt.Errorf("error performing payload calculations: %w", err) + } + if err := renderer.DeriveHeaderValues(); err != nil { + return fmt.Errorf("error deriving header values: %w", err) + } + if err := renderer.CalculateSignature(); err != nil { + return fmt.Errorf("error calculating signature: %w", err) + } + if err := renderer.Render(); err != nil { + return fmt.Errorf("error rendering command: %w", err) + } + return nil +} + +func encodeS3Key(key string) string { + parts := strings.Split(key, "/") + for i, p := range parts { + parts[i] = awsEscapePath(p) + } + return strings.Join(parts, "/") +} + +func awsEscapePath(key string) string { + var b strings.Builder + b.Grow(len(key)) + for i := 0; i < len(key); i++ { + c := key[i] + if (c >= 'A' && c <= 'Z') || + (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || + c == '-' || c == '_' || c == '.' || c == '~' || c == '/' { + b.WriteByte(c) + continue + } + fmt.Fprintf(&b, "%%%02X", c) + } + return b.String() +} + +func modifyHash(md5Hash []byte) { + if md5Hash[0] == 'a' { + md5Hash[0] = 'A' + } else { + md5Hash[0] = 'a' + } +} + +func hmacSHA256(key []byte, data string) []byte { + h := hmac.New(sha256.New, key) + h.Write([]byte(data)) + return h.Sum(nil) +} diff --git a/tests/rest_scripts/config/config.go b/tests/rest_scripts/config/config.go deleted file mode 100644 index d912156b..00000000 --- a/tests/rest_scripts/config/config.go +++ /dev/null @@ -1 +0,0 @@ -package config diff --git a/tests/rest_scripts/generateCommand.go b/tests/rest_scripts/generateCommand.go index 44de1209..9cce417d 100644 --- a/tests/rest_scripts/generateCommand.go +++ b/tests/rest_scripts/generateCommand.go @@ -75,6 +75,8 @@ type restParams map[string]string var paramSeparator *string +var writeXMLPayloadToFile *string + func (r *restParams) String() string { return fmt.Sprintf("%v", *r) } @@ -110,43 +112,46 @@ func main() { log.Fatalf("Error validating config: %v", err) } - baseCommand := &command.S3Command{ - Method: *method, - Url: *url, - BucketName: *bucketName, - ObjectKey: *objectKey, - Query: *query, - AwsRegion: *awsRegion, - AwsAccessKeyId: *awsAccessKeyId, - AwsSecretAccessKey: *awsSecretAccessKey, - ServiceName: *serviceName, - SignedParams: signedParamsMap, - UnsignedParams: unsignedParamsMap, - PayloadFile: *payloadFile, - IncorrectSignature: *incorrectSignature, - AuthorizationScheme: *authorizationScheme, - IncorrectCredential: *incorrectCredential, - IncorrectYearMonthDay: *incorrectYearMonthDay, - InvalidYearMonthDay: *invalidYearMonthDay, - Payload: *payload, - ContentMD5: *contentMD5, - IncorrectContentMD5: *incorrectContentMD5, - CustomContentMD5: *customContentMD5, - MissingHostParam: *missingHostParam, - FilePath: *filePath, - CustomHostParam: *customHostParam, - CustomHostParamSet: customHostParamSet, - PayloadType: *payloadType, - ChunkSize: *chunkSize, - ChecksumType: *checksumType, - OmitPayloadTrailer: *omitPayloadTrailer, - OmitPayloadTrailerKey: *omitPayloadTrailerKey, - OmitContentLength: *omitContentLength, - OmitSHA256Hash: *omitSHA256Hash, - CustomSHA256Hash: *customSHA256Hash, - Client: *client, - OmitDate: *omitDate, - CustomDate: *customDate, + baseCommand := &command.S3RequestBuilder{ + Config: &command.S3RequestConfigData{ + Method: *method, + Url: *url, + BucketName: *bucketName, + ObjectKey: *objectKey, + Query: *query, + AwsRegion: *awsRegion, + AwsAccessKeyId: *awsAccessKeyId, + AwsSecretAccessKey: *awsSecretAccessKey, + ServiceName: *serviceName, + SignedParams: signedParamsMap, + UnsignedParams: unsignedParamsMap, + PayloadFile: *payloadFile, + IncorrectSignature: *incorrectSignature, + AuthorizationScheme: *authorizationScheme, + IncorrectCredential: *incorrectCredential, + IncorrectYearMonthDay: *incorrectYearMonthDay, + InvalidYearMonthDay: *invalidYearMonthDay, + Payload: *payload, + ContentMD5: *contentMD5, + IncorrectContentMD5: *incorrectContentMD5, + CustomContentMD5: *customContentMD5, + MissingHostParam: *missingHostParam, + FilePath: *filePath, + CustomHostParam: *customHostParam, + CustomHostParamSet: customHostParamSet, + PayloadType: *payloadType, + ChunkSize: *chunkSize, + ChecksumType: *checksumType, + OmitPayloadTrailer: *omitPayloadTrailer, + OmitPayloadTrailerKey: *omitPayloadTrailerKey, + OmitContentLength: *omitContentLength, + OmitSHA256Hash: *omitSHA256Hash, + CustomSHA256Hash: *customSHA256Hash, + Client: *client, + OmitDate: *omitDate, + CustomDate: *customDate, + WriteXMLPayloadToFile: *writeXMLPayloadToFile, + }, } s3Command, err := getS3CommandType(baseCommand) @@ -158,7 +163,7 @@ func main() { } } -func getS3CommandType(baseCommand *command.S3Command) (command.S3CommandConverter, error) { +func getS3CommandType(baseCommand *command.S3RequestBuilder) (command.S3CommandConverter, error) { var s3Command command.S3CommandConverter var err error switch *commandType { @@ -261,6 +266,7 @@ func checkFlags() error { flag.Var(&tagValues, "tagValue", "Tag value (can add multiple)") locationConstraint = flag.String("locationConstraint", "", "Location constraint for bucket creation") flag.Var(&corsRules, "corsRule", "CORS rule for PutBucketCORS command (can add multiple)") + writeXMLPayloadToFile = flag.String("writeXMLPayloadToFile", "", "for curl commands, file to write XML payloads to") // Parse the flags flag.Parse()