mirror of
https://github.com/vmware-tanzu/pinniped.git
synced 2026-01-04 20:24:26 +00:00
Backfill unit tests for audit logging from the CLI
This commit is contained in:
@@ -358,35 +358,49 @@ type nopCache struct{}
|
||||
func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil }
|
||||
func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {}
|
||||
|
||||
type auditIDLoggerFunc func(path string, statusCode int, auditID string)
|
||||
|
||||
func logFailedRequest(path string, statusCode int, auditID string) {
|
||||
plog.Info("Received auditID for failed request",
|
||||
"path", path,
|
||||
"statusCode", statusCode,
|
||||
"auditID", auditID)
|
||||
}
|
||||
|
||||
// maybePrintAuditID will choose to log the auditID when certain failure cases are detected,
|
||||
// to give a breadcrumb for an admin to follow.
|
||||
// Older Supervisors and other OIDC identity providers may not provide this header.
|
||||
func maybePrintAuditID(rt http.RoundTripper) http.RoundTripper {
|
||||
func maybePrintAuditID(rt http.RoundTripper, logFunc auditIDLoggerFunc) http.RoundTripper {
|
||||
return roundtripper.WrapFunc(rt, func(r *http.Request) (*http.Response, error) {
|
||||
path := r.URL.Path
|
||||
response, responseErr := rt.RoundTrip(r)
|
||||
if response != nil && response.Header.Get("audit-ID") != "" {
|
||||
switch {
|
||||
case response.StatusCode >= http.StatusMultipleChoices && response.StatusCode < http.StatusBadRequest:
|
||||
// failing oauth2/authorize redirects from audit-enabled Supervisors
|
||||
|
||||
location, err := url.Parse(response.Header.Get(httpLocationHeaderName))
|
||||
if err != nil || location.Query().Get("error") != "" {
|
||||
plog.Info("Received auditID for failed request",
|
||||
"path", path,
|
||||
"statusCode", response.StatusCode,
|
||||
"auditID", response.Header.Get("audit-ID"))
|
||||
}
|
||||
case response.StatusCode >= http.StatusBadRequest:
|
||||
// failing discovery, oauth2/authorize, or oauth2/token responses from audit-enabled Supervisors
|
||||
if response == nil ||
|
||||
responseErr != nil ||
|
||||
response.Header.Get("audit-ID") == "" ||
|
||||
response.Request == nil ||
|
||||
response.Request.URL == nil {
|
||||
return response, responseErr
|
||||
}
|
||||
|
||||
plog.Info("Received auditID for failed request",
|
||||
"path", path,
|
||||
"statusCode", response.StatusCode,
|
||||
"auditID", response.Header.Get("audit-ID"))
|
||||
default:
|
||||
// noop
|
||||
auditID := response.Header.Get("audit-ID")
|
||||
// Use the request from the response in case other round-trippers modified the request
|
||||
path := response.Request.URL.Path
|
||||
|
||||
switch statusCode := response.StatusCode; {
|
||||
case statusCode < http.StatusMultipleChoices: // (-inf,300)
|
||||
break // noop
|
||||
case response.StatusCode < http.StatusBadRequest: // [300,400)
|
||||
// Rejected oauth2/authorize redirects from audit-enabled Supervisors will ALWAYS include
|
||||
// the "error" parameter since it is required.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1 for more details.
|
||||
location, err := url.Parse(response.Header.Get(httpLocationHeaderName))
|
||||
if err != nil || location == nil || location.Query().Get("error") == "" {
|
||||
break
|
||||
}
|
||||
logFunc(path, statusCode, auditID)
|
||||
default: // [400,inf)
|
||||
// failing discovery, oauth2/authorize, or oauth2/token responses from audit-enabled Supervisors.
|
||||
logFunc(path, statusCode, auditID)
|
||||
}
|
||||
return response, responseErr
|
||||
})
|
||||
@@ -436,7 +450,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
||||
}
|
||||
}
|
||||
|
||||
h.httpClient.Transport = maybePrintAuditID(h.httpClient.Transport)
|
||||
h.httpClient.Transport = maybePrintAuditID(h.httpClient.Transport, logFailedRequest)
|
||||
|
||||
if h.cliToSendCredentials {
|
||||
if h.loginFlow != "" {
|
||||
|
||||
@@ -3911,3 +3911,154 @@ func TestLoggers(t *testing.T) {
|
||||
|
||||
// NOTE: We can't really test logs with the default (e.g. no logger option specified)
|
||||
}
|
||||
|
||||
func TestMaybePrintAuditID(t *testing.T) {
|
||||
canonicalAuditIdHeaderName := "Audit-Id"
|
||||
|
||||
buildResponse := func(statusCode int) *http.Response {
|
||||
return &http.Response{
|
||||
Header: http.Header{
|
||||
canonicalAuditIdHeaderName: []string{"some-audit-id", "some-other-audit-id-that-will-never-be-seen"},
|
||||
},
|
||||
StatusCode: statusCode,
|
||||
Request: &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "some-path-from-response-request",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
response *http.Response
|
||||
responseErr error
|
||||
want func(t *testing.T, called func()) auditIDLoggerFunc
|
||||
wantCalled bool
|
||||
}{
|
||||
{
|
||||
name: "happy HTTP response - no error",
|
||||
response: buildResponse(http.StatusOK), //nolint:bodyclose // there is no Body.
|
||||
responseErr: nil,
|
||||
want: func(t *testing.T, called func()) auditIDLoggerFunc {
|
||||
return func(_ string, _ int, _ string) {
|
||||
called()
|
||||
}
|
||||
},
|
||||
wantCalled: false, // make it obvious
|
||||
},
|
||||
{
|
||||
name: "HTTP response with no response.request.url will not log",
|
||||
response: func() *http.Response {
|
||||
response := buildResponse(http.StatusOK)
|
||||
response.Request.URL = nil
|
||||
return response
|
||||
}(), //nolint:bodyclose // there is no Body.
|
||||
responseErr: nil,
|
||||
want: func(t *testing.T, called func()) auditIDLoggerFunc {
|
||||
return func(_ string, _ int, _ string) {
|
||||
called()
|
||||
}
|
||||
},
|
||||
wantCalled: false, // make it obvious
|
||||
},
|
||||
{
|
||||
name: "302 with error parameter in location and audit-ID will log",
|
||||
response: func() *http.Response {
|
||||
response := buildResponse(http.StatusFound)
|
||||
response.Header.Set("Location", "https://example.com?error=some-error")
|
||||
return response
|
||||
}(), //nolint:bodyclose // there is no Body.
|
||||
responseErr: nil,
|
||||
want: func(t *testing.T, called func()) auditIDLoggerFunc {
|
||||
return func(path string, statusCode int, auditID string) {
|
||||
called()
|
||||
require.Equal(t, "some-path-from-response-request", path)
|
||||
require.Equal(t, http.StatusFound, statusCode)
|
||||
require.Equal(t, "some-audit-id", auditID)
|
||||
}
|
||||
},
|
||||
wantCalled: true,
|
||||
},
|
||||
{
|
||||
name: "303 with error parameter in location and audit-ID will log",
|
||||
response: func() *http.Response {
|
||||
response := buildResponse(http.StatusSeeOther)
|
||||
response.Header.Set("Location", "https://example.com?error=some-error")
|
||||
return response
|
||||
}(), //nolint:bodyclose // there is no Body.
|
||||
responseErr: nil,
|
||||
want: func(t *testing.T, called func()) auditIDLoggerFunc {
|
||||
return func(path string, statusCode int, auditID string) {
|
||||
called()
|
||||
require.Equal(t, "some-path-from-response-request", path)
|
||||
require.Equal(t, http.StatusSeeOther, statusCode)
|
||||
require.Equal(t, "some-audit-id", auditID)
|
||||
}
|
||||
},
|
||||
wantCalled: true,
|
||||
},
|
||||
{
|
||||
name: "303 without error parameter in location and audit-ID will not log",
|
||||
response: func() *http.Response {
|
||||
response := buildResponse(http.StatusSeeOther)
|
||||
response.Header.Set("Location", "https://example.com?foo=bar")
|
||||
return response
|
||||
}(), //nolint:bodyclose // there is no Body.
|
||||
responseErr: nil,
|
||||
want: func(t *testing.T, called func()) auditIDLoggerFunc {
|
||||
return func(path string, statusCode int, auditID string) {
|
||||
called()
|
||||
}
|
||||
},
|
||||
wantCalled: false, // make it obvious
|
||||
},
|
||||
{
|
||||
name: "404 with error parameter in location and audit-ID will log",
|
||||
response: buildResponse(http.StatusNotFound), //nolint:bodyclose // there is no Body.
|
||||
responseErr: nil,
|
||||
want: func(t *testing.T, called func()) auditIDLoggerFunc {
|
||||
return func(path string, statusCode int, auditID string) {
|
||||
called()
|
||||
require.Equal(t, "some-path-from-response-request", path)
|
||||
require.Equal(t, http.StatusNotFound, statusCode)
|
||||
require.Equal(t, "some-audit-id", auditID)
|
||||
}
|
||||
},
|
||||
wantCalled: true,
|
||||
},
|
||||
{
|
||||
name: "when the roundtrip returns an error, will not log",
|
||||
responseErr: errors.New("some error"),
|
||||
want: func(t *testing.T, called func()) auditIDLoggerFunc {
|
||||
return func(path string, statusCode int, auditID string) {
|
||||
called()
|
||||
}
|
||||
},
|
||||
wantCalled: false, // make it obvious
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
require.NotNil(t, test.want)
|
||||
|
||||
mockRequest := &http.Request{
|
||||
URL: &url.URL{
|
||||
Path: "should-never-use-this-path",
|
||||
},
|
||||
}
|
||||
var mockRt roundtripper.Func = func(r *http.Request) (*http.Response, error) {
|
||||
require.Equal(t, mockRequest, r)
|
||||
return test.response, test.responseErr
|
||||
}
|
||||
called := false
|
||||
subjectRt := maybePrintAuditID(mockRt, test.want(t, func() {
|
||||
called = true
|
||||
}))
|
||||
actualResponse, err := subjectRt.RoundTrip(mockRequest) //nolint:bodyclose // there is no Body.
|
||||
require.Equal(t, test.responseErr, err) // This roundtripper only returns mocked errors.
|
||||
require.Equal(t, test.response, actualResponse)
|
||||
require.Equal(t, test.wantCalled, called, "expected logFunc to be called")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user