From 09ca7920ea8df8a002a2f08663dee051a8431951 Mon Sep 17 00:00:00 2001 From: Joshua Casey Date: Fri, 1 Nov 2024 13:55:29 -0500 Subject: [PATCH] Extract testutil helper function --- .../endpoints/callback/callback_handler_test.go | 10 +--------- .../endpoints/login/login_handler_test.go | 10 +--------- internal/testutil/log_lines.go | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/internal/federationdomain/endpoints/callback/callback_handler_test.go b/internal/federationdomain/endpoints/callback/callback_handler_test.go index 5950f798b..19ae26a39 100644 --- a/internal/federationdomain/endpoints/callback/callback_handler_test.go +++ b/internal/federationdomain/endpoints/callback/callback_handler_test.go @@ -1955,15 +1955,7 @@ func TestCallbackEndpoint(t *testing.T) { } if test.wantAuditLogs != nil { - var encodedStateParam stateparam.Encoded - if test.path != "" { - var path *url.URL - path, err = url.Parse(test.path) - require.NoError(t, err) - encodedStateParam = stateparam.Encoded(path.Query().Get("state")) - } - - wantAuditLogs := test.wantAuditLogs(encodedStateParam, sessionID) + wantAuditLogs := test.wantAuditLogs(testutil.GetStateParam(t, test.path), sessionID) testutil.CompareAuditLogs(t, wantAuditLogs, log.String()) } }) diff --git a/internal/federationdomain/endpoints/login/login_handler_test.go b/internal/federationdomain/endpoints/login/login_handler_test.go index 5d3aceb40..776b93ca5 100644 --- a/internal/federationdomain/endpoints/login/login_handler_test.go +++ b/internal/federationdomain/endpoints/login/login_handler_test.go @@ -464,15 +464,7 @@ func TestLoginEndpoint(t *testing.T) { require.Equal(t, test.wantBody, rsp.Body.String()) if test.wantAuditLogs != nil { - var encodedStateParam stateparam.Encoded - if test.path != "" { - var path *url.URL - path, err = url.Parse(test.path) - require.NoError(t, err) - encodedStateParam = stateparam.Encoded(path.Query().Get("state")) - } - - wantAuditLogs := test.wantAuditLogs(encodedStateParam) + wantAuditLogs := test.wantAuditLogs(testutil.GetStateParam(t, test.path)) testutil.CompareAuditLogs(t, wantAuditLogs, log.String()) } }) diff --git a/internal/testutil/log_lines.go b/internal/testutil/log_lines.go index 02ce4f243..d60ae0bbf 100644 --- a/internal/testutil/log_lines.go +++ b/internal/testutil/log_lines.go @@ -6,10 +6,13 @@ package testutil import ( "bytes" "encoding/json" + "net/url" "strings" "testing" "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/federationdomain/stateparam" ) func RequireLogLines(t *testing.T, wantLogs []string, log *bytes.Buffer) { @@ -38,6 +41,17 @@ func WantAuditLog(message string, params map[string]any, auditID ...string) Want return result } +func GetStateParam(t *testing.T, fullURL string) stateparam.Encoded { + var encodedStateParam stateparam.Encoded + if fullURL != "" { + path, err := url.Parse(fullURL) + require.NoError(t, err) + encodedStateParam = stateparam.Encoded(path.Query().Get("state")) + } + + return encodedStateParam +} + func CompareAuditLogs(t *testing.T, wantAuditLogs []WantedAuditLog, actualAuditLogsOneLiner string) { t.Helper()