diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 27f4f7e37..0e5a78ed9 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -358,35 +358,49 @@ type nopCache struct{} func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil } func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {} +type auditIDLoggerFunc func(path string, statusCode int, auditID string) + +func logFailedRequest(path string, statusCode int, auditID string) { + plog.Info("Received auditID for failed request", + "path", path, + "statusCode", statusCode, + "auditID", auditID) +} + // maybePrintAuditID will choose to log the auditID when certain failure cases are detected, // to give a breadcrumb for an admin to follow. // Older Supervisors and other OIDC identity providers may not provide this header. -func maybePrintAuditID(rt http.RoundTripper) http.RoundTripper { +func maybePrintAuditID(rt http.RoundTripper, logFunc auditIDLoggerFunc) http.RoundTripper { return roundtripper.WrapFunc(rt, func(r *http.Request) (*http.Response, error) { - path := r.URL.Path response, responseErr := rt.RoundTrip(r) - if response != nil && response.Header.Get("audit-ID") != "" { - switch { - case response.StatusCode >= http.StatusMultipleChoices && response.StatusCode < http.StatusBadRequest: - // failing oauth2/authorize redirects from audit-enabled Supervisors - location, err := url.Parse(response.Header.Get(httpLocationHeaderName)) - if err != nil || location.Query().Get("error") != "" { - plog.Info("Received auditID for failed request", - "path", path, - "statusCode", response.StatusCode, - "auditID", response.Header.Get("audit-ID")) - } - case response.StatusCode >= http.StatusBadRequest: - // failing discovery, oauth2/authorize, or oauth2/token responses from audit-enabled Supervisors + if response == nil || + responseErr != nil || + response.Header.Get("audit-ID") == "" || + response.Request == nil || + response.Request.URL == nil { + return response, responseErr + } - plog.Info("Received auditID for failed request", - "path", path, - "statusCode", response.StatusCode, - "auditID", response.Header.Get("audit-ID")) - default: - // noop + auditID := response.Header.Get("audit-ID") + // Use the request from the response in case other round-trippers modified the request + path := response.Request.URL.Path + + switch statusCode := response.StatusCode; { + case statusCode < http.StatusMultipleChoices: // (-inf,300) + break // noop + case response.StatusCode < http.StatusBadRequest: // [300,400) + // Rejected oauth2/authorize redirects from audit-enabled Supervisors will ALWAYS include + // the "error" parameter since it is required. + // See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1 for more details. + location, err := url.Parse(response.Header.Get(httpLocationHeaderName)) + if err != nil || location == nil || location.Query().Get("error") == "" { + break } + logFunc(path, statusCode, auditID) + default: // [400,inf) + // failing discovery, oauth2/authorize, or oauth2/token responses from audit-enabled Supervisors. + logFunc(path, statusCode, auditID) } return response, responseErr }) @@ -436,7 +450,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er } } - h.httpClient.Transport = maybePrintAuditID(h.httpClient.Transport) + h.httpClient.Transport = maybePrintAuditID(h.httpClient.Transport, logFailedRequest) if h.cliToSendCredentials { if h.loginFlow != "" { diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 8f7abc98f..8ede17dc0 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -3911,3 +3911,154 @@ func TestLoggers(t *testing.T) { // NOTE: We can't really test logs with the default (e.g. no logger option specified) } + +func TestMaybePrintAuditID(t *testing.T) { + canonicalAuditIdHeaderName := "Audit-Id" + + buildResponse := func(statusCode int) *http.Response { + return &http.Response{ + Header: http.Header{ + canonicalAuditIdHeaderName: []string{"some-audit-id", "some-other-audit-id-that-will-never-be-seen"}, + }, + StatusCode: statusCode, + Request: &http.Request{ + URL: &url.URL{ + Path: "some-path-from-response-request", + }, + }, + } + } + tests := []struct { + name string + response *http.Response + responseErr error + want func(t *testing.T, called func()) auditIDLoggerFunc + wantCalled bool + }{ + { + name: "happy HTTP response - no error", + response: buildResponse(http.StatusOK), //nolint:bodyclose // there is no Body. + responseErr: nil, + want: func(t *testing.T, called func()) auditIDLoggerFunc { + return func(_ string, _ int, _ string) { + called() + } + }, + wantCalled: false, // make it obvious + }, + { + name: "HTTP response with no response.request.url will not log", + response: func() *http.Response { + response := buildResponse(http.StatusOK) + response.Request.URL = nil + return response + }(), //nolint:bodyclose // there is no Body. + responseErr: nil, + want: func(t *testing.T, called func()) auditIDLoggerFunc { + return func(_ string, _ int, _ string) { + called() + } + }, + wantCalled: false, // make it obvious + }, + { + name: "302 with error parameter in location and audit-ID will log", + response: func() *http.Response { + response := buildResponse(http.StatusFound) + response.Header.Set("Location", "https://example.com?error=some-error") + return response + }(), //nolint:bodyclose // there is no Body. + responseErr: nil, + want: func(t *testing.T, called func()) auditIDLoggerFunc { + return func(path string, statusCode int, auditID string) { + called() + require.Equal(t, "some-path-from-response-request", path) + require.Equal(t, http.StatusFound, statusCode) + require.Equal(t, "some-audit-id", auditID) + } + }, + wantCalled: true, + }, + { + name: "303 with error parameter in location and audit-ID will log", + response: func() *http.Response { + response := buildResponse(http.StatusSeeOther) + response.Header.Set("Location", "https://example.com?error=some-error") + return response + }(), //nolint:bodyclose // there is no Body. + responseErr: nil, + want: func(t *testing.T, called func()) auditIDLoggerFunc { + return func(path string, statusCode int, auditID string) { + called() + require.Equal(t, "some-path-from-response-request", path) + require.Equal(t, http.StatusSeeOther, statusCode) + require.Equal(t, "some-audit-id", auditID) + } + }, + wantCalled: true, + }, + { + name: "303 without error parameter in location and audit-ID will not log", + response: func() *http.Response { + response := buildResponse(http.StatusSeeOther) + response.Header.Set("Location", "https://example.com?foo=bar") + return response + }(), //nolint:bodyclose // there is no Body. + responseErr: nil, + want: func(t *testing.T, called func()) auditIDLoggerFunc { + return func(path string, statusCode int, auditID string) { + called() + } + }, + wantCalled: false, // make it obvious + }, + { + name: "404 with error parameter in location and audit-ID will log", + response: buildResponse(http.StatusNotFound), //nolint:bodyclose // there is no Body. + responseErr: nil, + want: func(t *testing.T, called func()) auditIDLoggerFunc { + return func(path string, statusCode int, auditID string) { + called() + require.Equal(t, "some-path-from-response-request", path) + require.Equal(t, http.StatusNotFound, statusCode) + require.Equal(t, "some-audit-id", auditID) + } + }, + wantCalled: true, + }, + { + name: "when the roundtrip returns an error, will not log", + responseErr: errors.New("some error"), + want: func(t *testing.T, called func()) auditIDLoggerFunc { + return func(path string, statusCode int, auditID string) { + called() + } + }, + wantCalled: false, // make it obvious + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.NotNil(t, test.want) + + mockRequest := &http.Request{ + URL: &url.URL{ + Path: "should-never-use-this-path", + }, + } + var mockRt roundtripper.Func = func(r *http.Request) (*http.Response, error) { + require.Equal(t, mockRequest, r) + return test.response, test.responseErr + } + called := false + subjectRt := maybePrintAuditID(mockRt, test.want(t, func() { + called = true + })) + actualResponse, err := subjectRt.RoundTrip(mockRequest) //nolint:bodyclose // there is no Body. + require.Equal(t, test.responseErr, err) // This roundtripper only returns mocked errors. + require.Equal(t, test.response, actualResponse) + require.Equal(t, test.wantCalled, called, "expected logFunc to be called") + }) + } +}