diff --git a/internal/federationdomain/endpoints/login/login_handler_test.go b/internal/federationdomain/endpoints/login/login_handler_test.go index cf187192c..5d3aceb40 100644 --- a/internal/federationdomain/endpoints/login/login_handler_test.go +++ b/internal/federationdomain/endpoints/login/login_handler_test.go @@ -122,6 +122,7 @@ func TestLoginEndpoint(t *testing.T) { wantBody string wantEncodedState stateparam.Encoded wantDecodedState *oidc.UpstreamStateParamData + wantAuditLogs func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog }{ { name: "PUT method is invalid", @@ -131,6 +132,9 @@ func TestLoginEndpoint(t *testing.T) { wantStatus: http.StatusMethodNotAllowed, wantContentType: htmlContentType, wantBody: "Method Not Allowed: PUT (try GET or POST)\n", + wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { + return []testutil.WantedAuditLog{} + }, }, { name: "PATCH method is invalid", @@ -194,6 +198,9 @@ func TestLoginEndpoint(t *testing.T) { wantStatus: http.StatusBadRequest, wantContentType: htmlContentType, wantBody: "Bad Request: state param not found\n", + wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { + return []testutil.WantedAuditLog{} + }, }, { name: "state param was not included on POST request", @@ -332,6 +339,13 @@ func TestLoginEndpoint(t *testing.T) { wantBody: happyGetResult, wantEncodedState: happyState, wantDecodedState: expectedHappyDecodedUpstreamStateParam(), + wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { + return []testutil.WantedAuditLog{ + testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{ + "authorizeID": encodedStateParam.AuthorizeID(), + }), + } + }, }, { name: "happy POST request for LDAP upstream", @@ -343,6 +357,13 @@ func TestLoginEndpoint(t *testing.T) { wantBody: happyPostResult, wantEncodedState: happyState, wantDecodedState: expectedHappyDecodedUpstreamStateParam(), + wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { + return []testutil.WantedAuditLog{ + testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{ + "authorizeID": encodedStateParam.AuthorizeID(), + }), + } + }, }, { name: "happy GET request for ActiveDirectory upstream", @@ -354,6 +375,13 @@ func TestLoginEndpoint(t *testing.T) { wantBody: happyGetResult, wantEncodedState: happyActiveDirectoryState, wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(), + wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { + return []testutil.WantedAuditLog{ + testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{ + "authorizeID": encodedStateParam.AuthorizeID(), + }), + } + }, }, { name: "happy POST request for ActiveDirectory upstream", @@ -365,6 +393,13 @@ func TestLoginEndpoint(t *testing.T) { wantBody: happyPostResult, wantEncodedState: happyActiveDirectoryState, wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(), + wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { + return []testutil.WantedAuditLog{ + testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{ + "authorizeID": encodedStateParam.AuthorizeID(), + }), + } + }, }, } @@ -412,7 +447,9 @@ func TestLoginEndpoint(t *testing.T) { return test.postHandlerErr } - subject := NewHandler(happyStateCodec, happyCookieCodec, testGetHandler, testPostHandler, plog.New()) + logger, log := plog.TestLogger(t) + + subject := NewHandler(happyStateCodec, happyCookieCodec, testGetHandler, testPostHandler, logger) subject.ServeHTTP(rsp, req) @@ -425,6 +462,19 @@ func TestLoginEndpoint(t *testing.T) { require.Equal(t, test.wantStatus, rsp.Code) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) require.Equal(t, test.wantBody, rsp.Body.String()) + + if test.wantAuditLogs != nil { + var encodedStateParam stateparam.Encoded + if test.path != "" { + var path *url.URL + path, err = url.Parse(test.path) + require.NoError(t, err) + encodedStateParam = stateparam.Encoded(path.Query().Get("state")) + } + + wantAuditLogs := test.wantAuditLogs(encodedStateParam) + testutil.CompareAuditLogs(t, wantAuditLogs, log.String()) + } }) } }