diff --git a/internal/federationdomain/endpoints/login/login_handler.go b/internal/federationdomain/endpoints/login/login_handler.go index 288694db1..04e077c38 100644 --- a/internal/federationdomain/endpoints/login/login_handler.go +++ b/internal/federationdomain/endpoints/login/login_handler.go @@ -6,6 +6,8 @@ package login import ( "net/http" + "k8s.io/apimachinery/pkg/util/sets" + idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1" "go.pinniped.dev/internal/auditevent" "go.pinniped.dev/internal/federationdomain/endpoints/login/loginhtml" @@ -25,6 +27,13 @@ type HandlerFunc func( decodedState *oidc.UpstreamStateParamData, ) error +func paramsSafeToLog() sets.Set[string] { + return sets.New[string]( + // This param is sometimes added by the POST login handler when redirecting back to the GET login handler. + "err", + ) +} + // NewHandler returns a http.Handler that serves the login endpoint for IDPs that don't have their own web UI for login. // // This handler takes care of the shared concerns between the GET and POST methods of the login endpoint: @@ -43,6 +52,11 @@ func NewHandler( auditLogger plog.AuditLogger, ) http.Handler { loginHandler := httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + if err := auditLogger.AuditRequestParams(r, paramsSafeToLog()); err != nil { + plog.DebugErr("error parsing callback request params", err) + return httperr.New(http.StatusBadRequest, "error parsing request params") + } + var handler HandlerFunc switch r.Method { case http.MethodGet: diff --git a/internal/federationdomain/endpoints/login/login_handler_test.go b/internal/federationdomain/endpoints/login/login_handler_test.go index e3a463a76..9b7fd877a 100644 --- a/internal/federationdomain/endpoints/login/login_handler_test.go +++ b/internal/federationdomain/endpoints/login/login_handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "go.pinniped.dev/internal/auditid" + "go.pinniped.dev/internal/federationdomain/endpoints/loginurl" "go.pinniped.dev/internal/federationdomain/oidc" "go.pinniped.dev/internal/federationdomain/stateparam" "go.pinniped.dev/internal/httputil/httperr" @@ -134,7 +135,11 @@ func TestLoginEndpoint(t *testing.T) { wantContentType: htmlContentType, wantBody: "Method Not Allowed: PUT (try GET or POST)\n", wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { - return []testutil.WantedAuditLog{} + return []testutil.WantedAuditLog{ + testutil.WantAuditLog("HTTP Request Parameters", map[string]any{ + "params": map[string]any{"state": "redacted"}, + }), + } }, }, { @@ -200,7 +205,11 @@ func TestLoginEndpoint(t *testing.T) { wantContentType: htmlContentType, wantBody: "Bad Request: state param not found\n", wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { - return []testutil.WantedAuditLog{} + return []testutil.WantedAuditLog{ + testutil.WantAuditLog("HTTP Request Parameters", map[string]any{ + "params": map[string]any{}, + }), + } }, }, { @@ -295,6 +304,17 @@ func TestLoginEndpoint(t *testing.T) { wantContentType: htmlContentType, wantBody: "Bad Request: not a supported upstream IDP type for this endpoint: \"oidc\"\n", }, + { + name: "GET request with invalid form", + method: http.MethodGet, + path: newRequestPath().WithState( + happyUpstreamStateParam().WithUpstreamIDPType("oidc").Build(t, happyStateCodec), + ).String() + "&invalid;;param", + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: error parsing request params\n", + }, { name: "POST request when upstream IDP type in state param is not supported by this endpoint", method: http.MethodPost, @@ -330,6 +350,27 @@ func TestLoginEndpoint(t *testing.T) { wantEncodedState: happyState, wantDecodedState: expectedHappyDecodedUpstreamStateParam(), }, + { + name: "happy GET request with err param which can be set by the real POST handler on redirects back to the GET handler", + method: http.MethodGet, + path: happyPathWithState + "&" + loginurl.ErrParamName + "=" + string(loginurl.ShowBadUserPassErr), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusOK, + wantContentType: htmlContentType, + wantBody: happyGetResult, + wantEncodedState: happyState, + wantDecodedState: expectedHappyDecodedUpstreamStateParam(), + wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { + return []testutil.WantedAuditLog{ + testutil.WantAuditLog("HTTP Request Parameters", map[string]any{ + "params": map[string]any{"state": "redacted", "err": "login_error"}, + }), + testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{ + "authorizeID": encodedStateParam.AuthorizeID(), + }), + } + }, + }, { name: "happy GET request for LDAP upstream", method: http.MethodGet, @@ -342,6 +383,9 @@ func TestLoginEndpoint(t *testing.T) { wantDecodedState: expectedHappyDecodedUpstreamStateParam(), wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { return []testutil.WantedAuditLog{ + testutil.WantAuditLog("HTTP Request Parameters", map[string]any{ + "params": map[string]any{"state": "redacted"}, + }), testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{ "authorizeID": encodedStateParam.AuthorizeID(), }), @@ -360,6 +404,9 @@ func TestLoginEndpoint(t *testing.T) { wantDecodedState: expectedHappyDecodedUpstreamStateParam(), wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { return []testutil.WantedAuditLog{ + testutil.WantAuditLog("HTTP Request Parameters", map[string]any{ + "params": map[string]any{"state": "redacted"}, + }), testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{ "authorizeID": encodedStateParam.AuthorizeID(), }), @@ -378,6 +425,9 @@ func TestLoginEndpoint(t *testing.T) { wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(), wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { return []testutil.WantedAuditLog{ + testutil.WantAuditLog("HTTP Request Parameters", map[string]any{ + "params": map[string]any{"state": "redacted"}, + }), testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{ "authorizeID": encodedStateParam.AuthorizeID(), }), @@ -396,6 +446,9 @@ func TestLoginEndpoint(t *testing.T) { wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(), wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog { return []testutil.WantedAuditLog{ + testutil.WantAuditLog("HTTP Request Parameters", map[string]any{ + "params": map[string]any{"state": "redacted"}, + }), testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{ "authorizeID": encodedStateParam.AuthorizeID(), }), diff --git a/internal/plog/plog_test.go b/internal/plog/plog_test.go index e087e531c..bb296ea95 100644 --- a/internal/plog/plog_test.go +++ b/internal/plog/plog_test.go @@ -218,7 +218,6 @@ func TestAuditRequestParams(t *testing.T) { "baz": []string{"baz1", "baz2"}, } req := httptest.NewRequestWithContext(context.Background(), "GET", "/?"+params.Encode(), nil) - req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) return req }, paramsSafeToLog: sets.New("foo"), @@ -234,7 +233,6 @@ func TestAuditRequestParams(t *testing.T) { "baz": []string{"baz1", "baz2"}, } req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader(params.Encode())) - req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") return req }, @@ -247,7 +245,6 @@ func TestAuditRequestParams(t *testing.T) { name: "get request with bad form", req: func() *http.Request { req := httptest.NewRequestWithContext(context.Background(), "GET", "/?invalid;;;form", nil) - req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) return req }, paramsSafeToLog: sets.New("foo"), @@ -263,7 +260,6 @@ func TestAuditRequestParams(t *testing.T) { name: "post request with bad urlencoded form in body", req: func() *http.Request { req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader("invalid;;;form")) - req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") return req }, @@ -280,7 +276,6 @@ func TestAuditRequestParams(t *testing.T) { name: "post request with bad multipart form in body", req: func() *http.Request { req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader("this is not a valid multipart form")) - req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) req.Header.Set("Content-Type", "multipart/form-data") return req }, @@ -301,7 +296,10 @@ func TestAuditRequestParams(t *testing.T) { l, actualAuditLogs := TestAuditLogger(t) - rawErr := l.AuditRequestParams(test.req(), test.paramsSafeToLog) + req := test.req() + req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" }) + + rawErr := l.AuditRequestParams(req, test.paramsSafeToLog) if test.wantErr == nil { require.NoError(t, rawErr)