Backfill unit tests for cmd/pinniped/cmd/audit_id.go

This commit is contained in:
Joshua Casey
2024-11-19 13:29:06 -06:00
parent 8dffd60f0b
commit 51c86795af
3 changed files with 145 additions and 8 deletions

View File

@@ -10,17 +10,37 @@ import (
"go.pinniped.dev/internal/plog"
)
type auditIDLoggerFunc func(path string, statusCode int, auditID string)
func logAuditID(path string, statusCode int, auditID string) {
plog.Info("Received auditID for failed request",
"path", path,
"statusCode", statusCode,
"auditID", auditID)
}
func LogAuditIDTransportWrapper(rt http.RoundTripper) http.RoundTripper {
return logAuditIDTransportWrapper(rt, logAuditID)
}
func logAuditIDTransportWrapper(rt http.RoundTripper, auditIDLoggerFunc auditIDLoggerFunc) http.RoundTripper {
return roundtripper.WrapFunc(rt, func(r *http.Request) (*http.Response, error) {
response, responseErr := rt.RoundTrip(r)
if response != nil && response.Header.Get("audit-ID") != "" {
plog.Info("Received auditID for request",
// Use the request path from the response's request, in case the
// original request was modified by any other roudtrippers in the chain.
"path", response.Request.URL.Path,
"statusCode", response.StatusCode,
"auditID", response.Header.Get("audit-ID"))
if responseErr != nil ||
response == nil ||
response.Header.Get("audit-ID") == "" ||
response.Request == nil ||
response.Request.URL == nil {
return response, responseErr
}
// Use the request path from the response's request, in case the
// original request was modified by any other roudtrippers in the chain.
auditIDLoggerFunc(response.Request.URL.Path,
response.StatusCode,
response.Header.Get("audit-ID"))
return response, responseErr
})
}

View File

@@ -0,0 +1,116 @@
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"errors"
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/require"
"go.pinniped.dev/internal/httputil/roundtripper"
)
func TestLogAuditIDTransportWrapper(t *testing.T) {
canonicalAuditIdHeaderName := "Audit-Id"
tests := []struct {
name string
response *http.Response
responseErr error
want func(t *testing.T, called func()) auditIDLoggerFunc
wantCalled bool
}{
{
name: "happy HTTP response - no error and no log",
response: &http.Response{ // no headers
StatusCode: http.StatusOK,
Request: &http.Request{
URL: &url.URL{
Path: "some-path-from-response-request",
},
},
},
responseErr: nil,
want: func(t *testing.T, called func()) auditIDLoggerFunc {
return func(_ string, _ int, _ string) {
called()
}
},
wantCalled: false, // make it obvious
},
{
name: "nil HTTP response - no error and no log",
response: nil,
responseErr: nil,
want: func(t *testing.T, called func()) auditIDLoggerFunc {
return func(_ string, _ int, _ string) {
called()
}
},
wantCalled: false, // make it obvious
},
{
name: "err HTTP response - no error and no log",
response: nil,
responseErr: errors.New("some error"),
want: func(t *testing.T, called func()) auditIDLoggerFunc {
return func(_ string, _ int, _ string) {
called()
}
},
wantCalled: false, // make it obvious
},
{
name: "happy HTTP response with audit-ID - logs",
response: &http.Response{
Header: http.Header{
canonicalAuditIdHeaderName: []string{"some-audit-id", "some-other-audit-id-that-will-never-be-seen"},
},
StatusCode: http.StatusBadGateway, // statusCode does not matter
Request: &http.Request{
URL: &url.URL{
Path: "some-path-from-response-request",
},
},
},
want: func(t *testing.T, called func()) auditIDLoggerFunc {
return func(path string, statusCode int, auditID string) {
called()
require.Equal(t, "some-path-from-response-request", path)
require.Equal(t, http.StatusBadGateway, statusCode)
require.Equal(t, "some-audit-id", auditID)
}
},
wantCalled: true, // make it obvious
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
require.NotNil(t, test.want)
mockRequest := &http.Request{
URL: &url.URL{
Path: "should-never-use-this-path",
},
}
var mockRt roundtripper.Func = func(r *http.Request) (*http.Response, error) {
require.Equal(t, mockRequest, r)
return test.response, test.responseErr
}
called := false
subjectRt := logAuditIDTransportWrapper(mockRt, test.want(t, func() {
called = true
}))
actualResponse, err := subjectRt.RoundTrip(mockRequest) //nolint:bodyclose // there is no Body.
require.Equal(t, test.responseErr, err) // This roundtripper only returns mocked errors.
require.Equal(t, test.response, actualResponse)
require.Equal(t, test.wantCalled, called,
"want logFunc to be called: %t, actually was called: %t", test.wantCalled, called)
})
}
}

View File

@@ -4058,7 +4058,8 @@ func TestMaybePrintAuditID(t *testing.T) {
actualResponse, err := subjectRt.RoundTrip(mockRequest) //nolint:bodyclose // there is no Body.
require.Equal(t, test.responseErr, err) // This roundtripper only returns mocked errors.
require.Equal(t, test.response, actualResponse)
require.Equal(t, test.wantCalled, called, "expected logFunc to be called")
require.Equal(t, test.wantCalled, called,
"want logFunc to be called: %t, actually was called: %t", test.wantCalled, called)
})
}
}