From 51c86795af36ad01e30ab0254dcef3e44b586e4d Mon Sep 17 00:00:00 2001 From: Joshua Casey Date: Tue, 19 Nov 2024 13:29:06 -0600 Subject: [PATCH] Backfill unit tests for cmd/pinniped/cmd/audit_id.go --- cmd/pinniped/cmd/audit_id.go | 34 +++++++-- cmd/pinniped/cmd/audit_id_test.go | 116 ++++++++++++++++++++++++++++++ pkg/oidcclient/login_test.go | 3 +- 3 files changed, 145 insertions(+), 8 deletions(-) create mode 100644 cmd/pinniped/cmd/audit_id_test.go diff --git a/cmd/pinniped/cmd/audit_id.go b/cmd/pinniped/cmd/audit_id.go index 591c7a895..51879c95a 100644 --- a/cmd/pinniped/cmd/audit_id.go +++ b/cmd/pinniped/cmd/audit_id.go @@ -10,17 +10,37 @@ import ( "go.pinniped.dev/internal/plog" ) +type auditIDLoggerFunc func(path string, statusCode int, auditID string) + +func logAuditID(path string, statusCode int, auditID string) { + plog.Info("Received auditID for failed request", + "path", path, + "statusCode", statusCode, + "auditID", auditID) +} + func LogAuditIDTransportWrapper(rt http.RoundTripper) http.RoundTripper { + return logAuditIDTransportWrapper(rt, logAuditID) +} + +func logAuditIDTransportWrapper(rt http.RoundTripper, auditIDLoggerFunc auditIDLoggerFunc) http.RoundTripper { return roundtripper.WrapFunc(rt, func(r *http.Request) (*http.Response, error) { response, responseErr := rt.RoundTrip(r) - if response != nil && response.Header.Get("audit-ID") != "" { - plog.Info("Received auditID for request", - // Use the request path from the response's request, in case the - // original request was modified by any other roudtrippers in the chain. - "path", response.Request.URL.Path, - "statusCode", response.StatusCode, - "auditID", response.Header.Get("audit-ID")) + + if responseErr != nil || + response == nil || + response.Header.Get("audit-ID") == "" || + response.Request == nil || + response.Request.URL == nil { + return response, responseErr } + + // Use the request path from the response's request, in case the + // original request was modified by any other roudtrippers in the chain. + auditIDLoggerFunc(response.Request.URL.Path, + response.StatusCode, + response.Header.Get("audit-ID")) + return response, responseErr }) } diff --git a/cmd/pinniped/cmd/audit_id_test.go b/cmd/pinniped/cmd/audit_id_test.go new file mode 100644 index 000000000..3ac73053b --- /dev/null +++ b/cmd/pinniped/cmd/audit_id_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "errors" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/httputil/roundtripper" +) + +func TestLogAuditIDTransportWrapper(t *testing.T) { + canonicalAuditIdHeaderName := "Audit-Id" + + tests := []struct { + name string + response *http.Response + responseErr error + want func(t *testing.T, called func()) auditIDLoggerFunc + wantCalled bool + }{ + { + name: "happy HTTP response - no error and no log", + response: &http.Response{ // no headers + StatusCode: http.StatusOK, + Request: &http.Request{ + URL: &url.URL{ + Path: "some-path-from-response-request", + }, + }, + }, + responseErr: nil, + want: func(t *testing.T, called func()) auditIDLoggerFunc { + return func(_ string, _ int, _ string) { + called() + } + }, + wantCalled: false, // make it obvious + }, + { + name: "nil HTTP response - no error and no log", + response: nil, + responseErr: nil, + want: func(t *testing.T, called func()) auditIDLoggerFunc { + return func(_ string, _ int, _ string) { + called() + } + }, + wantCalled: false, // make it obvious + }, + { + name: "err HTTP response - no error and no log", + response: nil, + responseErr: errors.New("some error"), + want: func(t *testing.T, called func()) auditIDLoggerFunc { + return func(_ string, _ int, _ string) { + called() + } + }, + wantCalled: false, // make it obvious + }, + { + name: "happy HTTP response with audit-ID - logs", + response: &http.Response{ + Header: http.Header{ + canonicalAuditIdHeaderName: []string{"some-audit-id", "some-other-audit-id-that-will-never-be-seen"}, + }, + StatusCode: http.StatusBadGateway, // statusCode does not matter + Request: &http.Request{ + URL: &url.URL{ + Path: "some-path-from-response-request", + }, + }, + }, + want: func(t *testing.T, called func()) auditIDLoggerFunc { + return func(path string, statusCode int, auditID string) { + called() + require.Equal(t, "some-path-from-response-request", path) + require.Equal(t, http.StatusBadGateway, statusCode) + require.Equal(t, "some-audit-id", auditID) + } + }, + wantCalled: true, // make it obvious + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.NotNil(t, test.want) + + mockRequest := &http.Request{ + URL: &url.URL{ + Path: "should-never-use-this-path", + }, + } + var mockRt roundtripper.Func = func(r *http.Request) (*http.Response, error) { + require.Equal(t, mockRequest, r) + return test.response, test.responseErr + } + called := false + subjectRt := logAuditIDTransportWrapper(mockRt, test.want(t, func() { + called = true + })) + actualResponse, err := subjectRt.RoundTrip(mockRequest) //nolint:bodyclose // there is no Body. + require.Equal(t, test.responseErr, err) // This roundtripper only returns mocked errors. + require.Equal(t, test.response, actualResponse) + require.Equal(t, test.wantCalled, called, + "want logFunc to be called: %t, actually was called: %t", test.wantCalled, called) + }) + } +} diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 8ede17dc0..12242011e 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -4058,7 +4058,8 @@ func TestMaybePrintAuditID(t *testing.T) { actualResponse, err := subjectRt.RoundTrip(mockRequest) //nolint:bodyclose // there is no Body. require.Equal(t, test.responseErr, err) // This roundtripper only returns mocked errors. require.Equal(t, test.response, actualResponse) - require.Equal(t, test.wantCalled, called, "expected logFunc to be called") + require.Equal(t, test.wantCalled, called, + "want logFunc to be called: %t, actually was called: %t", test.wantCalled, called) }) } }