audit log request params on GET and POST login handlers

This commit is contained in:
Ryan Richard
2024-11-13 13:34:45 -08:00
committed by Joshua Casey
parent 51d1cc7a96
commit b54365c199
3 changed files with 73 additions and 8 deletions

View File

@@ -6,6 +6,8 @@ package login
import (
"net/http"
"k8s.io/apimachinery/pkg/util/sets"
idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1"
"go.pinniped.dev/internal/auditevent"
"go.pinniped.dev/internal/federationdomain/endpoints/login/loginhtml"
@@ -25,6 +27,13 @@ type HandlerFunc func(
decodedState *oidc.UpstreamStateParamData,
) error
func paramsSafeToLog() sets.Set[string] {
return sets.New[string](
// This param is sometimes added by the POST login handler when redirecting back to the GET login handler.
"err",
)
}
// NewHandler returns a http.Handler that serves the login endpoint for IDPs that don't have their own web UI for login.
//
// This handler takes care of the shared concerns between the GET and POST methods of the login endpoint:
@@ -43,6 +52,11 @@ func NewHandler(
auditLogger plog.AuditLogger,
) http.Handler {
loginHandler := httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if err := auditLogger.AuditRequestParams(r, paramsSafeToLog()); err != nil {
plog.DebugErr("error parsing callback request params", err)
return httperr.New(http.StatusBadRequest, "error parsing request params")
}
var handler HandlerFunc
switch r.Method {
case http.MethodGet:

View File

@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/require"
"go.pinniped.dev/internal/auditid"
"go.pinniped.dev/internal/federationdomain/endpoints/loginurl"
"go.pinniped.dev/internal/federationdomain/oidc"
"go.pinniped.dev/internal/federationdomain/stateparam"
"go.pinniped.dev/internal/httputil/httperr"
@@ -134,7 +135,11 @@ func TestLoginEndpoint(t *testing.T) {
wantContentType: htmlContentType,
wantBody: "Method Not Allowed: PUT (try GET or POST)\n",
wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog {
return []testutil.WantedAuditLog{}
return []testutil.WantedAuditLog{
testutil.WantAuditLog("HTTP Request Parameters", map[string]any{
"params": map[string]any{"state": "redacted"},
}),
}
},
},
{
@@ -200,7 +205,11 @@ func TestLoginEndpoint(t *testing.T) {
wantContentType: htmlContentType,
wantBody: "Bad Request: state param not found\n",
wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog {
return []testutil.WantedAuditLog{}
return []testutil.WantedAuditLog{
testutil.WantAuditLog("HTTP Request Parameters", map[string]any{
"params": map[string]any{},
}),
}
},
},
{
@@ -295,6 +304,17 @@ func TestLoginEndpoint(t *testing.T) {
wantContentType: htmlContentType,
wantBody: "Bad Request: not a supported upstream IDP type for this endpoint: \"oidc\"\n",
},
{
name: "GET request with invalid form",
method: http.MethodGet,
path: newRequestPath().WithState(
happyUpstreamStateParam().WithUpstreamIDPType("oidc").Build(t, happyStateCodec),
).String() + "&invalid;;param",
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantContentType: htmlContentType,
wantBody: "Bad Request: error parsing request params\n",
},
{
name: "POST request when upstream IDP type in state param is not supported by this endpoint",
method: http.MethodPost,
@@ -330,6 +350,27 @@ func TestLoginEndpoint(t *testing.T) {
wantEncodedState: happyState,
wantDecodedState: expectedHappyDecodedUpstreamStateParam(),
},
{
name: "happy GET request with err param which can be set by the real POST handler on redirects back to the GET handler",
method: http.MethodGet,
path: happyPathWithState + "&" + loginurl.ErrParamName + "=" + string(loginurl.ShowBadUserPassErr),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusOK,
wantContentType: htmlContentType,
wantBody: happyGetResult,
wantEncodedState: happyState,
wantDecodedState: expectedHappyDecodedUpstreamStateParam(),
wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog {
return []testutil.WantedAuditLog{
testutil.WantAuditLog("HTTP Request Parameters", map[string]any{
"params": map[string]any{"state": "redacted", "err": "login_error"},
}),
testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{
"authorizeID": encodedStateParam.AuthorizeID(),
}),
}
},
},
{
name: "happy GET request for LDAP upstream",
method: http.MethodGet,
@@ -342,6 +383,9 @@ func TestLoginEndpoint(t *testing.T) {
wantDecodedState: expectedHappyDecodedUpstreamStateParam(),
wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog {
return []testutil.WantedAuditLog{
testutil.WantAuditLog("HTTP Request Parameters", map[string]any{
"params": map[string]any{"state": "redacted"},
}),
testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{
"authorizeID": encodedStateParam.AuthorizeID(),
}),
@@ -360,6 +404,9 @@ func TestLoginEndpoint(t *testing.T) {
wantDecodedState: expectedHappyDecodedUpstreamStateParam(),
wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog {
return []testutil.WantedAuditLog{
testutil.WantAuditLog("HTTP Request Parameters", map[string]any{
"params": map[string]any{"state": "redacted"},
}),
testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{
"authorizeID": encodedStateParam.AuthorizeID(),
}),
@@ -378,6 +425,9 @@ func TestLoginEndpoint(t *testing.T) {
wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(),
wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog {
return []testutil.WantedAuditLog{
testutil.WantAuditLog("HTTP Request Parameters", map[string]any{
"params": map[string]any{"state": "redacted"},
}),
testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{
"authorizeID": encodedStateParam.AuthorizeID(),
}),
@@ -396,6 +446,9 @@ func TestLoginEndpoint(t *testing.T) {
wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(),
wantAuditLogs: func(encodedStateParam stateparam.Encoded) []testutil.WantedAuditLog {
return []testutil.WantedAuditLog{
testutil.WantAuditLog("HTTP Request Parameters", map[string]any{
"params": map[string]any{"state": "redacted"},
}),
testutil.WantAuditLog("AuthorizeID From Parameters", map[string]any{
"authorizeID": encodedStateParam.AuthorizeID(),
}),

View File

@@ -218,7 +218,6 @@ func TestAuditRequestParams(t *testing.T) {
"baz": []string{"baz1", "baz2"},
}
req := httptest.NewRequestWithContext(context.Background(), "GET", "/?"+params.Encode(), nil)
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
return req
},
paramsSafeToLog: sets.New("foo"),
@@ -234,7 +233,6 @@ func TestAuditRequestParams(t *testing.T) {
"baz": []string{"baz1", "baz2"},
}
req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader(params.Encode()))
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req
},
@@ -247,7 +245,6 @@ func TestAuditRequestParams(t *testing.T) {
name: "get request with bad form",
req: func() *http.Request {
req := httptest.NewRequestWithContext(context.Background(), "GET", "/?invalid;;;form", nil)
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
return req
},
paramsSafeToLog: sets.New("foo"),
@@ -263,7 +260,6 @@ func TestAuditRequestParams(t *testing.T) {
name: "post request with bad urlencoded form in body",
req: func() *http.Request {
req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader("invalid;;;form"))
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req
},
@@ -280,7 +276,6 @@ func TestAuditRequestParams(t *testing.T) {
name: "post request with bad multipart form in body",
req: func() *http.Request {
req := httptest.NewRequestWithContext(context.Background(), "POST", "/", strings.NewReader("this is not a valid multipart form"))
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
req.Header.Set("Content-Type", "multipart/form-data")
return req
},
@@ -301,7 +296,10 @@ func TestAuditRequestParams(t *testing.T) {
l, actualAuditLogs := TestAuditLogger(t)
rawErr := l.AuditRequestParams(test.req(), test.paramsSafeToLog)
req := test.req()
req, _ = auditid.NewRequestWithAuditID(req, func() string { return "some-audit-id" })
rawErr := l.AuditRequestParams(req, test.paramsSafeToLog)
if test.wantErr == nil {
require.NoError(t, rawErr)