mirror of
https://github.com/vmware-tanzu/pinniped.git
synced 2026-01-03 11:45:45 +00:00
Backfill unit tests for cmd/pinniped/cmd/audit_id.go
This commit is contained in:
@@ -10,17 +10,37 @@ import (
|
|||||||
"go.pinniped.dev/internal/plog"
|
"go.pinniped.dev/internal/plog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type auditIDLoggerFunc func(path string, statusCode int, auditID string)
|
||||||
|
|
||||||
|
func logAuditID(path string, statusCode int, auditID string) {
|
||||||
|
plog.Info("Received auditID for failed request",
|
||||||
|
"path", path,
|
||||||
|
"statusCode", statusCode,
|
||||||
|
"auditID", auditID)
|
||||||
|
}
|
||||||
|
|
||||||
func LogAuditIDTransportWrapper(rt http.RoundTripper) http.RoundTripper {
|
func LogAuditIDTransportWrapper(rt http.RoundTripper) http.RoundTripper {
|
||||||
|
return logAuditIDTransportWrapper(rt, logAuditID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func logAuditIDTransportWrapper(rt http.RoundTripper, auditIDLoggerFunc auditIDLoggerFunc) http.RoundTripper {
|
||||||
return roundtripper.WrapFunc(rt, func(r *http.Request) (*http.Response, error) {
|
return roundtripper.WrapFunc(rt, func(r *http.Request) (*http.Response, error) {
|
||||||
response, responseErr := rt.RoundTrip(r)
|
response, responseErr := rt.RoundTrip(r)
|
||||||
if response != nil && response.Header.Get("audit-ID") != "" {
|
|
||||||
plog.Info("Received auditID for request",
|
if responseErr != nil ||
|
||||||
// Use the request path from the response's request, in case the
|
response == nil ||
|
||||||
// original request was modified by any other roudtrippers in the chain.
|
response.Header.Get("audit-ID") == "" ||
|
||||||
"path", response.Request.URL.Path,
|
response.Request == nil ||
|
||||||
"statusCode", response.StatusCode,
|
response.Request.URL == nil {
|
||||||
"auditID", response.Header.Get("audit-ID"))
|
return response, responseErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use the request path from the response's request, in case the
|
||||||
|
// original request was modified by any other roudtrippers in the chain.
|
||||||
|
auditIDLoggerFunc(response.Request.URL.Path,
|
||||||
|
response.StatusCode,
|
||||||
|
response.Header.Get("audit-ID"))
|
||||||
|
|
||||||
return response, responseErr
|
return response, responseErr
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
116
cmd/pinniped/cmd/audit_id_test.go
Normal file
116
cmd/pinniped/cmd/audit_id_test.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.pinniped.dev/internal/httputil/roundtripper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLogAuditIDTransportWrapper(t *testing.T) {
|
||||||
|
canonicalAuditIdHeaderName := "Audit-Id"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
response *http.Response
|
||||||
|
responseErr error
|
||||||
|
want func(t *testing.T, called func()) auditIDLoggerFunc
|
||||||
|
wantCalled bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "happy HTTP response - no error and no log",
|
||||||
|
response: &http.Response{ // no headers
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Request: &http.Request{
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "some-path-from-response-request",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
responseErr: nil,
|
||||||
|
want: func(t *testing.T, called func()) auditIDLoggerFunc {
|
||||||
|
return func(_ string, _ int, _ string) {
|
||||||
|
called()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantCalled: false, // make it obvious
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil HTTP response - no error and no log",
|
||||||
|
response: nil,
|
||||||
|
responseErr: nil,
|
||||||
|
want: func(t *testing.T, called func()) auditIDLoggerFunc {
|
||||||
|
return func(_ string, _ int, _ string) {
|
||||||
|
called()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantCalled: false, // make it obvious
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "err HTTP response - no error and no log",
|
||||||
|
response: nil,
|
||||||
|
responseErr: errors.New("some error"),
|
||||||
|
want: func(t *testing.T, called func()) auditIDLoggerFunc {
|
||||||
|
return func(_ string, _ int, _ string) {
|
||||||
|
called()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantCalled: false, // make it obvious
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "happy HTTP response with audit-ID - logs",
|
||||||
|
response: &http.Response{
|
||||||
|
Header: http.Header{
|
||||||
|
canonicalAuditIdHeaderName: []string{"some-audit-id", "some-other-audit-id-that-will-never-be-seen"},
|
||||||
|
},
|
||||||
|
StatusCode: http.StatusBadGateway, // statusCode does not matter
|
||||||
|
Request: &http.Request{
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "some-path-from-response-request",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: func(t *testing.T, called func()) auditIDLoggerFunc {
|
||||||
|
return func(path string, statusCode int, auditID string) {
|
||||||
|
called()
|
||||||
|
require.Equal(t, "some-path-from-response-request", path)
|
||||||
|
require.Equal(t, http.StatusBadGateway, statusCode)
|
||||||
|
require.Equal(t, "some-audit-id", auditID)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantCalled: true, // make it obvious
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
require.NotNil(t, test.want)
|
||||||
|
|
||||||
|
mockRequest := &http.Request{
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "should-never-use-this-path",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var mockRt roundtripper.Func = func(r *http.Request) (*http.Response, error) {
|
||||||
|
require.Equal(t, mockRequest, r)
|
||||||
|
return test.response, test.responseErr
|
||||||
|
}
|
||||||
|
called := false
|
||||||
|
subjectRt := logAuditIDTransportWrapper(mockRt, test.want(t, func() {
|
||||||
|
called = true
|
||||||
|
}))
|
||||||
|
actualResponse, err := subjectRt.RoundTrip(mockRequest) //nolint:bodyclose // there is no Body.
|
||||||
|
require.Equal(t, test.responseErr, err) // This roundtripper only returns mocked errors.
|
||||||
|
require.Equal(t, test.response, actualResponse)
|
||||||
|
require.Equal(t, test.wantCalled, called,
|
||||||
|
"want logFunc to be called: %t, actually was called: %t", test.wantCalled, called)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4058,7 +4058,8 @@ func TestMaybePrintAuditID(t *testing.T) {
|
|||||||
actualResponse, err := subjectRt.RoundTrip(mockRequest) //nolint:bodyclose // there is no Body.
|
actualResponse, err := subjectRt.RoundTrip(mockRequest) //nolint:bodyclose // there is no Body.
|
||||||
require.Equal(t, test.responseErr, err) // This roundtripper only returns mocked errors.
|
require.Equal(t, test.responseErr, err) // This roundtripper only returns mocked errors.
|
||||||
require.Equal(t, test.response, actualResponse)
|
require.Equal(t, test.response, actualResponse)
|
||||||
require.Equal(t, test.wantCalled, called, "expected logFunc to be called")
|
require.Equal(t, test.wantCalled, called,
|
||||||
|
"want logFunc to be called: %t, actually was called: %t", test.wantCalled, called)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user