Backfill unit tests for audit logging from the CLI

This commit is contained in:
Joshua Casey
2024-11-19 12:06:39 -06:00
parent 6bf9b64778
commit 8dffd60f0b
2 changed files with 187 additions and 22 deletions

View File

@@ -358,35 +358,49 @@ type nopCache struct{}
func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil }
func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {}
type auditIDLoggerFunc func(path string, statusCode int, auditID string)
func logFailedRequest(path string, statusCode int, auditID string) {
plog.Info("Received auditID for failed request",
"path", path,
"statusCode", statusCode,
"auditID", auditID)
}
// maybePrintAuditID will choose to log the auditID when certain failure cases are detected,
// to give a breadcrumb for an admin to follow.
// Older Supervisors and other OIDC identity providers may not provide this header.
func maybePrintAuditID(rt http.RoundTripper) http.RoundTripper {
func maybePrintAuditID(rt http.RoundTripper, logFunc auditIDLoggerFunc) http.RoundTripper {
return roundtripper.WrapFunc(rt, func(r *http.Request) (*http.Response, error) {
path := r.URL.Path
response, responseErr := rt.RoundTrip(r)
if response != nil && response.Header.Get("audit-ID") != "" {
switch {
case response.StatusCode >= http.StatusMultipleChoices && response.StatusCode < http.StatusBadRequest:
// failing oauth2/authorize redirects from audit-enabled Supervisors
location, err := url.Parse(response.Header.Get(httpLocationHeaderName))
if err != nil || location.Query().Get("error") != "" {
plog.Info("Received auditID for failed request",
"path", path,
"statusCode", response.StatusCode,
"auditID", response.Header.Get("audit-ID"))
}
case response.StatusCode >= http.StatusBadRequest:
// failing discovery, oauth2/authorize, or oauth2/token responses from audit-enabled Supervisors
if response == nil ||
responseErr != nil ||
response.Header.Get("audit-ID") == "" ||
response.Request == nil ||
response.Request.URL == nil {
return response, responseErr
}
plog.Info("Received auditID for failed request",
"path", path,
"statusCode", response.StatusCode,
"auditID", response.Header.Get("audit-ID"))
default:
// noop
auditID := response.Header.Get("audit-ID")
// Use the request from the response in case other round-trippers modified the request
path := response.Request.URL.Path
switch statusCode := response.StatusCode; {
case statusCode < http.StatusMultipleChoices: // (-inf,300)
break // noop
case response.StatusCode < http.StatusBadRequest: // [300,400)
// Rejected oauth2/authorize redirects from audit-enabled Supervisors will ALWAYS include
// the "error" parameter since it is required.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1 for more details.
location, err := url.Parse(response.Header.Get(httpLocationHeaderName))
if err != nil || location == nil || location.Query().Get("error") == "" {
break
}
logFunc(path, statusCode, auditID)
default: // [400,inf)
// failing discovery, oauth2/authorize, or oauth2/token responses from audit-enabled Supervisors.
logFunc(path, statusCode, auditID)
}
return response, responseErr
})
@@ -436,7 +450,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
}
}
h.httpClient.Transport = maybePrintAuditID(h.httpClient.Transport)
h.httpClient.Transport = maybePrintAuditID(h.httpClient.Transport, logFailedRequest)
if h.cliToSendCredentials {
if h.loginFlow != "" {

View File

@@ -3911,3 +3911,154 @@ func TestLoggers(t *testing.T) {
// NOTE: We can't really test logs with the default (e.g. no logger option specified)
}
func TestMaybePrintAuditID(t *testing.T) {
canonicalAuditIdHeaderName := "Audit-Id"
buildResponse := func(statusCode int) *http.Response {
return &http.Response{
Header: http.Header{
canonicalAuditIdHeaderName: []string{"some-audit-id", "some-other-audit-id-that-will-never-be-seen"},
},
StatusCode: statusCode,
Request: &http.Request{
URL: &url.URL{
Path: "some-path-from-response-request",
},
},
}
}
tests := []struct {
name string
response *http.Response
responseErr error
want func(t *testing.T, called func()) auditIDLoggerFunc
wantCalled bool
}{
{
name: "happy HTTP response - no error",
response: buildResponse(http.StatusOK), //nolint:bodyclose // there is no Body.
responseErr: nil,
want: func(t *testing.T, called func()) auditIDLoggerFunc {
return func(_ string, _ int, _ string) {
called()
}
},
wantCalled: false, // make it obvious
},
{
name: "HTTP response with no response.request.url will not log",
response: func() *http.Response {
response := buildResponse(http.StatusOK)
response.Request.URL = nil
return response
}(), //nolint:bodyclose // there is no Body.
responseErr: nil,
want: func(t *testing.T, called func()) auditIDLoggerFunc {
return func(_ string, _ int, _ string) {
called()
}
},
wantCalled: false, // make it obvious
},
{
name: "302 with error parameter in location and audit-ID will log",
response: func() *http.Response {
response := buildResponse(http.StatusFound)
response.Header.Set("Location", "https://example.com?error=some-error")
return response
}(), //nolint:bodyclose // there is no Body.
responseErr: nil,
want: func(t *testing.T, called func()) auditIDLoggerFunc {
return func(path string, statusCode int, auditID string) {
called()
require.Equal(t, "some-path-from-response-request", path)
require.Equal(t, http.StatusFound, statusCode)
require.Equal(t, "some-audit-id", auditID)
}
},
wantCalled: true,
},
{
name: "303 with error parameter in location and audit-ID will log",
response: func() *http.Response {
response := buildResponse(http.StatusSeeOther)
response.Header.Set("Location", "https://example.com?error=some-error")
return response
}(), //nolint:bodyclose // there is no Body.
responseErr: nil,
want: func(t *testing.T, called func()) auditIDLoggerFunc {
return func(path string, statusCode int, auditID string) {
called()
require.Equal(t, "some-path-from-response-request", path)
require.Equal(t, http.StatusSeeOther, statusCode)
require.Equal(t, "some-audit-id", auditID)
}
},
wantCalled: true,
},
{
name: "303 without error parameter in location and audit-ID will not log",
response: func() *http.Response {
response := buildResponse(http.StatusSeeOther)
response.Header.Set("Location", "https://example.com?foo=bar")
return response
}(), //nolint:bodyclose // there is no Body.
responseErr: nil,
want: func(t *testing.T, called func()) auditIDLoggerFunc {
return func(path string, statusCode int, auditID string) {
called()
}
},
wantCalled: false, // make it obvious
},
{
name: "404 with error parameter in location and audit-ID will log",
response: buildResponse(http.StatusNotFound), //nolint:bodyclose // there is no Body.
responseErr: nil,
want: func(t *testing.T, called func()) auditIDLoggerFunc {
return func(path string, statusCode int, auditID string) {
called()
require.Equal(t, "some-path-from-response-request", path)
require.Equal(t, http.StatusNotFound, statusCode)
require.Equal(t, "some-audit-id", auditID)
}
},
wantCalled: true,
},
{
name: "when the roundtrip returns an error, will not log",
responseErr: errors.New("some error"),
want: func(t *testing.T, called func()) auditIDLoggerFunc {
return func(path string, statusCode int, auditID string) {
called()
}
},
wantCalled: false, // make it obvious
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
require.NotNil(t, test.want)
mockRequest := &http.Request{
URL: &url.URL{
Path: "should-never-use-this-path",
},
}
var mockRt roundtripper.Func = func(r *http.Request) (*http.Response, error) {
require.Equal(t, mockRequest, r)
return test.response, test.responseErr
}
called := false
subjectRt := maybePrintAuditID(mockRt, test.want(t, func() {
called = true
}))
actualResponse, err := subjectRt.RoundTrip(mockRequest) //nolint:bodyclose // there is no Body.
require.Equal(t, test.responseErr, err) // This roundtripper only returns mocked errors.
require.Equal(t, test.response, actualResponse)
require.Equal(t, test.wantCalled, called, "expected logFunc to be called")
})
}
}