From 81b9a484377f10a6e87f944870f8e7242d17de12 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Fri, 13 Nov 2020 12:31:39 -0500 Subject: [PATCH 01/57] callback_handler.go: initial API/test shape with 1 test Signed-off-by: Andrew Keesler --- internal/oidc/callback/callback_handler.go | 21 +++++ .../oidc/callback/callback_handler_test.go | 79 +++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 internal/oidc/callback/callback_handler.go create mode 100644 internal/oidc/callback/callback_handler_test.go diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go new file mode 100644 index 000000000..837c38c65 --- /dev/null +++ b/internal/oidc/callback/callback_handler.go @@ -0,0 +1,21 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package callback provides a handler for the OIDC callback endpoint. +package callback + +import ( + "net/http" + + "go.pinniped.dev/internal/httputil/httperr" +) + +func NewHandler() http.Handler { + return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + if r.Method != http.MethodGet { + return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) + } + + return nil + }) +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go new file mode 100644 index 000000000..2f88aa303 --- /dev/null +++ b/internal/oidc/callback/callback_handler_test.go @@ -0,0 +1,79 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package callback + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCallbackEndpoint(t *testing.T) { + tests := []struct { + name string + + method string + + wantStatus int + wantBody string + }{ + // Happy path + // TODO: GET with good state and cookie and successful upstream token exchange and 302 to downstream client callback with its state and code + + // Pre-upstream-exchange verification + { + name: "PUT method is invalid", + method: http.MethodPut, + wantStatus: http.StatusMethodNotAllowed, + wantBody: "Method Not Allowed: PUT (try GET)\n", + }, + // TODO: POST/PATCH/DELETE is invalid + // TODO: request has body? maybe we don't need to do anything... + // TODO: code does not exist + // TODO: we got called twice with the same state and cookie...is this bad? might be ok if the client's first roundtrip failed + // TODO: we got called twice with the same state and cookie and the UpstreamOIDCProvider CRD has been deleted + // TODO: state does not exist + // TODO: invalid signature on state + // TODO: state is expired (the expiration is encoded in the state itself) + // TODO: state csrf value does not match csrf cookie + // TODO: cookie does not exist + // TODO: invalid signature on cookie + // TODO: state version does not match what we want + + // Upstream exchange + // TODO: we can't figure out what the upstream token endpoint is (do we get this UpstreamOIDCProvider name from the path?) + // TODO: network call to upstream token endpoint fails + // TODO: the upstream token endpoint returns an error + + // Post-upstream-exchange verification + // TODO: returned tokens are invalid (all the stuff from the spec...) + // TODO: there + // TODO: are + // TODO: probably + // TODO: a + // TODO: lot + // TODO: of + // TODO: test + // TODO: cases + // TODO: here (e.g., id jwt cannot be verified, nonce is wrong, we didn't get refresh token, we didn't get access token, we didn't get id token, access token expires too quickly) + + // Downstream redirect + // TODO: cannot generate auth code + // TODO: cannot persist downstream state + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + subject := NewHandler() + req := httptest.NewRequest(test.method, "/path-is-not-yet-tested", nil /* body not yet tested */) + rsp := httptest.NewRecorder() + subject.ServeHTTP(rsp, req) + + require.Equal(t, test.wantStatus, rsp.Code) + require.Equal(t, test.wantBody, rsp.Body.String()) + }) + } +} From 3ef1171667d7333f97be1ccf054b72bed93c3746 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Fri, 13 Nov 2020 15:59:51 -0800 Subject: [PATCH 02/57] Tiny bit more code for Supervisor's callback_handler.go Signed-off-by: Ryan Richard --- internal/oidc/auth/auth_handler.go | 9 +- internal/oidc/auth/auth_handler_test.go | 81 +++----- internal/oidc/callback/callback_handler.go | 31 ++- .../oidc/callback/callback_handler_test.go | 186 ++++++++++++++++-- internal/oidc/oidc.go | 6 + internal/oidc/provider/manager/manager.go | 4 +- internal/testutil/oidc.go | 26 +++ 7 files changed, 268 insertions(+), 75 deletions(-) create mode 100644 internal/testutil/oidc.go diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index de3c7f714..a22e50289 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -17,6 +17,7 @@ import ( "golang.org/x/oauth2" "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidcclient/nonce" @@ -42,10 +43,6 @@ const ( csrfCookieEncodingName = "csrf" ) -type IDPListGetter interface { - GetIDPList() []provider.UpstreamOIDCIdentityProvider -} - // This is the encoding side of the securecookie.Codec interface. type Encoder interface { Encode(name string, value interface{}) (string, error) @@ -53,7 +50,7 @@ type Encoder interface { func NewHandler( issuer string, - idpListGetter IDPListGetter, + idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, generateCSRF func() (csrftoken.CSRFToken, error), generatePKCE func() (pkce.Code, error), @@ -178,7 +175,7 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { } } -func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { +func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { allUpstreamIDPs := idpListGetter.GetIDPList() if len(allUpstreamIDPs) == 0 { return nil, httperr.New( diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index e0fdbacd4..173021d84 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -23,6 +23,7 @@ import ( "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/pkce" + "go.pinniped.dev/internal/testutil" ) func TestAuthorizationEndpoint(t *testing.T) { @@ -210,7 +211,7 @@ func TestAuthorizationEndpoint(t *testing.T) { csrf = csrfValueOverride } encoded, err := happyStateEncoder.Encode("s", - expectedUpstreamStateParamFormat{ + testutil.ExpectedUpstreamStateParamFormat{ P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)), N: happyNonce, C: csrf, @@ -270,7 +271,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using GET without a CSRF cookie", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -288,7 +289,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using GET with a CSRF cookie", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -306,7 +307,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using POST", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -326,7 +327,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path when downstream redirect uri matches what is configured for client except for the port number", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -348,7 +349,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream redirect uri does not match what is configured for client", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -365,7 +366,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream client does not exist", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -380,7 +381,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "response type is unsupported", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -396,7 +397,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream scopes do not match what is configured for client", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -412,7 +413,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing response type in request", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -428,7 +429,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing client id in request", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -443,7 +444,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -459,7 +460,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -475,7 +476,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -491,7 +492,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -509,7 +510,7 @@ func TestAuthorizationEndpoint(t *testing.T) { // through that part of the fosite library. name: "prompt param is not allowed to have none and another legal value at the same time", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -525,7 +526,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "OIDC validations are skipped when the openid scope was not requested", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -546,7 +547,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "state does not have enough entropy", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -562,7 +563,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while encoding upstream state param", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -577,7 +578,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while encoding CSRF cookie value for new cookie", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -592,7 +593,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating CSRF token", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -607,7 +608,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating nonce", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, @@ -622,7 +623,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating PKCE", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generateNonce: happyNonceGenerator, @@ -637,7 +638,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while decoding CSRF cookie", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -653,7 +654,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "no upstream providers are configured", issuer: issuer, - idpListGetter: newIDPListGetter(), // empty + idpListGetter: testutil.NewIDPListGetter(), // empty method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -663,7 +664,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "too many upstream providers are configured", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -673,7 +674,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "PUT is a bad method", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodPut, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -683,7 +684,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "PATCH is a bad method", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodPatch, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -693,7 +694,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "DELETE is a bad method", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodDelete, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -805,18 +806,6 @@ func TestAuthorizationEndpoint(t *testing.T) { }) } -// Declare a separate type from the production code to ensure that the state param's contents was serialized -// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of -// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality -// assertions about the redirect URL in this test. -type expectedUpstreamStateParamFormat struct { - P string `json:"p"` - N string `json:"n"` - C string `json:"c"` - K string `json:"k"` - V string `json:"v"` -} - type errorReturningEncoder struct { securecookie.Codec } @@ -850,13 +839,13 @@ func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL expectedQueryStateParam := expectedLocationURL.Query().Get("state") require.NotEmpty(t, expectedQueryStateParam) - var expectedDecodedStateParam expectedUpstreamStateParamFormat + var expectedDecodedStateParam testutil.ExpectedUpstreamStateParamFormat err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam) require.NoError(t, err) actualQueryStateParam := actualLocationURL.Query().Get("state") require.NotEmpty(t, actualQueryStateParam) - var actualDecodedStateParam expectedUpstreamStateParamFormat + var actualDecodedStateParam testutil.ExpectedUpstreamStateParamFormat err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam) require.NoError(t, err) @@ -884,9 +873,3 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignore } require.Equal(t, expectedLocationQuery, actualLocationQuery) } - -func newIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { - idpProvider := provider.NewDynamicUpstreamIDPProvider() - idpProvider.SetIDPList(upstreamOIDCIdentityProviders) - return idpProvider -} diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 837c38c65..475008e5e 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -6,16 +6,43 @@ package callback import ( "net/http" + "path" "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/oidc/provider" ) -func NewHandler() http.Handler { +func NewHandler( + idpListGetter oidc.IDPListGetter, +) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { if r.Method != http.MethodGet { return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) } - return nil + if r.FormValue("code") == "" { + return httperr.New(http.StatusBadRequest, "code param not found") + } + + if r.FormValue("state") == "" { + return httperr.New(http.StatusBadRequest, "state param not found") + } + + if findUpstreamIDPConfig(r, idpListGetter) == nil { + return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") + } + + return httperr.New(http.StatusBadRequest, "state param not valid") }) } + +func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider { + _, lastPathComponent := path.Split(r.URL.Path) + for _, p := range idpListGetter.GetIDPList() { + if p.Name == lastPathComponent { + return &p + } + } + return nil +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 2f88aa303..05635016f 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -4,18 +4,76 @@ package callback import ( + "fmt" "net/http" "net/http/httptest" + "net/url" "testing" + "github.com/gorilla/securecookie" "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/testutil" +) + +const ( + happyUpstreamIDPName = "upstream-idp-name" ) func TestCallbackEndpoint(t *testing.T) { + upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") + require.NoError(t, err) + otherUpstreamAuthURL, err := url.Parse("https://some-other-upstream-idp:8443/auth") + require.NoError(t, err) + + upstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ + Name: happyUpstreamIDPName, + ClientID: "some-client-id", + AuthorizationURL: *upstreamAuthURL, + Scopes: []string{"scope1", "scope2"}, + } + + otherUpstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ + Name: "other-upstream-idp-name", + ClientID: "other-some-client-id", + AuthorizationURL: *otherUpstreamAuthURL, + Scopes: []string{"other-scope1", "other-scope2"}, + } + + var stateEncoderHashKey = []byte("fake-hash-secret") + var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES + var cookieEncoderHashKey = []byte("fake-hash-secret2") + var cookieEncoderBlockKey = []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES + require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey) + require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey) + + var happyStateEncoder = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey) + happyStateEncoder.SetSerializer(securecookie.JSONEncoder{}) + var happyCookieEncoder = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) + happyCookieEncoder.SetSerializer(securecookie.JSONEncoder{}) + + //happyCSRF := "test-csrf" + //happyPKCE := "test-pkce" + //happyNonce := "test-nonce" + // + //happyEncodedState, err := happyStateEncoder.Encode("s", + // testutil.ExpectedUpstreamStateParamFormat{ + // P: "todo query goes here", + // N: happyNonce, + // C: happyCSRF, + // K: happyPKCE, + // V: "1", + // }, + //) + //require.NoError(t, err) + tests := []struct { name string - method string + method string + path string + idpListGetter provider.DynamicUpstreamIDPProvider wantStatus int wantBody string @@ -27,24 +85,67 @@ func TestCallbackEndpoint(t *testing.T) { { name: "PUT method is invalid", method: http.MethodPut, + path: newRequestPath().String(), wantStatus: http.StatusMethodNotAllowed, wantBody: "Method Not Allowed: PUT (try GET)\n", }, - // TODO: POST/PATCH/DELETE is invalid - // TODO: request has body? maybe we don't need to do anything... - // TODO: code does not exist - // TODO: we got called twice with the same state and cookie...is this bad? might be ok if the client's first roundtrip failed - // TODO: we got called twice with the same state and cookie and the UpstreamOIDCProvider CRD has been deleted - // TODO: state does not exist - // TODO: invalid signature on state - // TODO: state is expired (the expiration is encoded in the state itself) - // TODO: state csrf value does not match csrf cookie - // TODO: cookie does not exist - // TODO: invalid signature on cookie - // TODO: state version does not match what we want + { + name: "POST method is invalid", + method: http.MethodPost, + path: newRequestPath().String(), + wantStatus: http.StatusMethodNotAllowed, + wantBody: "Method Not Allowed: POST (try GET)\n", + }, + { + name: "PATCH method is invalid", + method: http.MethodPatch, + path: newRequestPath().String(), + wantStatus: http.StatusMethodNotAllowed, + wantBody: "Method Not Allowed: PATCH (try GET)\n", + }, + { + name: "DELETE method is invalid", + method: http.MethodDelete, + path: newRequestPath().String(), + wantStatus: http.StatusMethodNotAllowed, + wantBody: "Method Not Allowed: DELETE (try GET)\n", + }, + { + name: "code param was not included on request", + method: http.MethodGet, + path: newRequestPath().WithoutCode().String(), + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: code param not found\n", + }, + { + name: "state param was not included on request", + method: http.MethodGet, + path: newRequestPath().WithoutState().String(), + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: state param not found\n", + }, + { + name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", + method: http.MethodGet, + path: newRequestPath().WithState("this-will-not-decode").String(), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: state param not valid\n", + }, + { + name: "the UpstreamOIDCProvider CRD has been deleted", + method: http.MethodGet, + path: newRequestPath().String(), + idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: upstream provider not found\n", + }, + // TODO: csrf cookie does not exist on request + // TODO: csrf cookie value cannot be decoded (e.g. invalid signture or any other decoding problem) + // TODO: csrf value from inside state param does not match csrf cookie value + // TODO: state's internal version does not match what we want // Upstream exchange - // TODO: we can't figure out what the upstream token endpoint is (do we get this UpstreamOIDCProvider name from the path?) // TODO: network call to upstream token endpoint fails // TODO: the upstream token endpoint returns an error @@ -61,14 +162,15 @@ func TestCallbackEndpoint(t *testing.T) { // TODO: here (e.g., id jwt cannot be verified, nonce is wrong, we didn't get refresh token, we didn't get access token, we didn't get id token, access token expires too quickly) // Downstream redirect + // TODO: we grant the openid scope if it was requested, similar to what we did in auth_handler.go // TODO: cannot generate auth code // TODO: cannot persist downstream state } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - subject := NewHandler() - req := httptest.NewRequest(test.method, "/path-is-not-yet-tested", nil /* body not yet tested */) + subject := NewHandler(test.idpListGetter) + req := httptest.NewRequest(test.method, test.path, nil) rsp := httptest.NewRecorder() subject.ServeHTTP(rsp, req) @@ -77,3 +179,55 @@ func TestCallbackEndpoint(t *testing.T) { }) } } + +type requestPath struct { + upstreamIDPName, code, state *string +} + +func newRequestPath() *requestPath { + n := happyUpstreamIDPName + c := "1234" + s := "4321" + return &requestPath{ + upstreamIDPName: &n, + code: &c, + state: &s, + } +} + +func (r *requestPath) WithUpstreamIDPName(name string) *requestPath { + r.upstreamIDPName = &name + return r +} + +func (r *requestPath) WithCode(code string) *requestPath { + r.code = &code + return r +} + +func (r *requestPath) WithoutCode() *requestPath { + r.code = nil + return r +} + +func (r *requestPath) WithState(state string) *requestPath { + r.state = &state + return r +} + +func (r *requestPath) WithoutState() *requestPath { + r.state = nil + return r +} + +func (r *requestPath) String() string { + path := fmt.Sprintf("/downstream-provider-name/callback/%s?", *r.upstreamIDPName) + params := url.Values{} + if r.code != nil { + params.Add("code", *r.code) + } + if r.state != nil { + params.Add("state", *r.state) + } + return path + params.Encode() +} diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 8d319c1d1..f0f74970f 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -7,6 +7,8 @@ package oidc import ( "github.com/ory/fosite" "github.com/ory/fosite/compose" + + "go.pinniped.dev/internal/oidc/provider" ) const ( @@ -49,3 +51,7 @@ func FositeOauth2Helper(oauthStore interface{}, hmacSecretOfLengthAtLeast32 []by compose.OAuth2PKCEFactory, ) } + +type IDPListGetter interface { + GetIDPList() []provider.UpstreamOIDCIdentityProvider +} diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index 0687ba22d..f009693d0 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -30,14 +30,14 @@ type Manager struct { providerHandlers map[string]http.Handler // map of all routes for all providers nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data - idpListGetter auth.IDPListGetter // in-memory cache of upstream IDPs + idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs } // NewManager returns an empty Manager. // nextHandler will be invoked for any requests that could not be handled by this manager's providers. // dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data. // idpListGetter will be used as an in-memory cache of currently configured upstream IDPs. -func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter auth.IDPListGetter) *Manager { +func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter oidc.IDPListGetter) *Manager { return &Manager{ providerHandlers: make(map[string]http.Handler), nextHandler: nextHandler, diff --git a/internal/testutil/oidc.go b/internal/testutil/oidc.go new file mode 100644 index 000000000..0552ea95f --- /dev/null +++ b/internal/testutil/oidc.go @@ -0,0 +1,26 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testutil + +import "go.pinniped.dev/internal/oidc/provider" + +// Test helpers for the OIDC package. + +func NewIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { + idpProvider := provider.NewDynamicUpstreamIDPProvider() + idpProvider.SetIDPList(upstreamOIDCIdentityProviders) + return idpProvider +} + +// Declare a separate type from the production code to ensure that the state param's contents was serialized +// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of +// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality +// assertions about the redirect URL in this test. +type ExpectedUpstreamStateParamFormat struct { + P string `json:"p"` + N string `json:"n"` + C string `json:"c"` + K string `json:"k"` + V string `json:"v"` +} From 4138c9244fb25afa9b7809598a25c90f0483ee91 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Mon, 16 Nov 2020 11:47:49 -0500 Subject: [PATCH 03/57] callback_handler.go: write 2 invalid cookie tests Also common-ize some more constants shared between the auth and callback endpoints. Signed-off-by: Andrew Keesler --- internal/oidc/auth/auth_handler.go | 17 +--- internal/oidc/callback/callback_handler.go | 28 ++++++ .../oidc/callback/callback_handler_test.go | 93 +++++++++++++------ internal/oidc/oidc.go | 11 +++ 4 files changed, 108 insertions(+), 41 deletions(-) diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index a22e50289..96cb44645 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -34,16 +34,9 @@ const ( // The `name` passed to the encoder for encoding the upstream state param value. This name is short // because it will be encoded into the upstream state param value and we're trying to keep that small. upstreamStateParamEncodingName = "s" - - // The name of the browser cookie which shall hold our CSRF value. - // `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes - csrfCookieName = "__Host-pinniped-csrf" - - // The `name` passed to the encoder for encoding and decoding the CSRF cookie contents. - csrfCookieEncodingName = "csrf" ) -// This is the encoding side of the securecookie.Codec interface. +// Encoder is the encoding side of the securecookie.Codec interface. type Encoder interface { Encode(name string, value interface{}) (string, error) } @@ -152,14 +145,14 @@ func NewHandler( } func readCSRFCookie(r *http.Request, codec securecookie.Codec) (csrftoken.CSRFToken, error) { - receivedCSRFCookie, err := r.Cookie(csrfCookieName) + receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName) if err != nil { // Error means that the cookie was not found return "", nil } var csrfFromCookie csrftoken.CSRFToken - err = codec.Decode(csrfCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) + err = codec.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) if err != nil { return "", httperr.Wrap(http.StatusUnprocessableEntity, "error reading CSRF cookie", err) } @@ -242,13 +235,13 @@ func upstreamStateParam( } func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec securecookie.Codec) error { - encodedCSRFValue, err := codec.Encode(csrfCookieEncodingName, csrfValue) + encodedCSRFValue, err := codec.Encode(oidc.CSRFCookieEncodingName, csrfValue) if err != nil { return httperr.Wrap(http.StatusInternalServerError, "error encoding CSRF cookie", err) } http.SetCookie(w, &http.Cookie{ - Name: csrfCookieName, + Name: oidc.CSRFCookieName, Value: encodedCSRFValue, HttpOnly: true, SameSite: http.SameSiteStrictMode, diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 475008e5e..9c5ad6fd1 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -10,17 +10,29 @@ import ( "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/provider" ) +// Decoder is the decoding side of the securecookie.Codec interface. +type Decoder interface { + Decode(name, value string, into interface{}) error +} + func NewHandler( idpListGetter oidc.IDPListGetter, + cookieDecoder Decoder, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { if r.Method != http.MethodGet { return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) } + _, err := readCSRFCookie(r, cookieDecoder) + if err != nil { + return err + } + if r.FormValue("code") == "" { return httperr.New(http.StatusBadRequest, "code param not found") } @@ -46,3 +58,19 @@ func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *p } return nil } + +func readCSRFCookie(r *http.Request, cookieDecoder Decoder) (csrftoken.CSRFToken, error) { + receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName) + if err != nil { + // Error means that the cookie was not found + return "", httperr.Wrap(http.StatusForbidden, "unauthorized request", err) + } + + var csrfFromCookie csrftoken.CSRFToken + err = cookieDecoder.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) + if err != nil { + return "", httperr.Wrap(http.StatusForbidden, "unauthorized request", err) + } + + return csrfFromCookie, nil +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 05635016f..8d966db41 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -53,27 +53,34 @@ func TestCallbackEndpoint(t *testing.T) { var happyCookieEncoder = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) happyCookieEncoder.SetSerializer(securecookie.JSONEncoder{}) - //happyCSRF := "test-csrf" - //happyPKCE := "test-pkce" - //happyNonce := "test-nonce" + // happyCSRF := "test-csrf" + // happyPKCE := "test-pkce" + // happyNonce := "test-nonce" // - //happyEncodedState, err := happyStateEncoder.Encode("s", - // testutil.ExpectedUpstreamStateParamFormat{ - // P: "todo query goes here", - // N: happyNonce, - // C: happyCSRF, - // K: happyPKCE, - // V: "1", - // }, - //) - //require.NoError(t, err) + // happyEncodedState, err := happyStateEncoder.Encode("s", + // testutil.ExpectedUpstreamStateParamFormat{ + // P: "todo query goes here", + // N: happyNonce, + // C: happyCSRF, + // K: happyPKCE, + // V: "1", + // }, + // ) + // require.NoError(t, err) + + incomingCookieCSRFValue := "csrf-value-from-cookie" + encodedIncomingCookieCSRFValue, err := happyCookieEncoder.Encode("csrf", incomingCookieCSRFValue) + require.NoError(t, err) + happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue tests := []struct { name string + idpListGetter provider.DynamicUpstreamIDPProvider + cookieDecoder Decoder method string path string - idpListGetter provider.DynamicUpstreamIDPProvider + csrfCookie string wantStatus int wantBody string @@ -111,37 +118,62 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "Method Not Allowed: DELETE (try GET)\n", }, { - name: "code param was not included on request", - method: http.MethodGet, - path: newRequestPath().WithoutCode().String(), - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: code param not found\n", + name: "code param was not included on request", + cookieDecoder: happyCookieEncoder, + method: http.MethodGet, + path: newRequestPath().WithoutCode().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: code param not found\n", }, { - name: "state param was not included on request", - method: http.MethodGet, - path: newRequestPath().WithoutState().String(), - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: state param not found\n", + name: "state param was not included on request", + cookieDecoder: happyCookieEncoder, + method: http.MethodGet, + path: newRequestPath().WithoutState().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: state param not found\n", }, { name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + cookieDecoder: happyCookieEncoder, method: http.MethodGet, path: newRequestPath().WithState("this-will-not-decode").String(), - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + csrfCookie: happyCSRFCookie, wantStatus: http.StatusBadRequest, wantBody: "Bad Request: state param not valid\n", }, { name: "the UpstreamOIDCProvider CRD has been deleted", + idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + cookieDecoder: happyCookieEncoder, method: http.MethodGet, path: newRequestPath().String(), - idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + csrfCookie: happyCSRFCookie, wantStatus: http.StatusUnprocessableEntity, wantBody: "Unprocessable Entity: upstream provider not found\n", }, - // TODO: csrf cookie does not exist on request - // TODO: csrf cookie value cannot be decoded (e.g. invalid signture or any other decoding problem) + { + name: "the CSRF cookie does not exist on request", + idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + cookieDecoder: happyCookieEncoder, + method: http.MethodGet, + path: newRequestPath().String(), + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: unauthorized request\n", + }, + { + name: "the CSRF cookie cannot be decoded", + idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + cookieDecoder: happyCookieEncoder, + method: http.MethodGet, + path: newRequestPath().String(), + csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: unauthorized request\n", + }, // TODO: csrf value from inside state param does not match csrf cookie value // TODO: state's internal version does not match what we want @@ -169,8 +201,11 @@ func TestCallbackEndpoint(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - subject := NewHandler(test.idpListGetter) + subject := NewHandler(test.idpListGetter, test.cookieDecoder) req := httptest.NewRequest(test.method, test.path, nil) + if test.csrfCookie != "" { + req.Header.Set("Cookie", test.csrfCookie) + } rsp := httptest.NewRecorder() subject.ServeHTTP(rsp, req) diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index f0f74970f..3b5dcb60c 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -18,6 +18,17 @@ const ( JWKSEndpointPath = "/jwks.json" ) +const ( + // CSRFCookieName is the name of the browser cookie which shall hold our CSRF value. + // The `__Host` prefix has a special meaning. See + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes. + CSRFCookieName = "__Host-pinniped-csrf" + + // CSRFCookieEncodingName is the `name` passed to the encoder for encoding and decoding the CSRF + // cookie contents. + CSRFCookieEncodingName = "csrf" +) + func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient { return &fosite.DefaultOpenIDConnectClient{ DefaultClient: &fosite.DefaultClient{ From 052cdc40dcdbbdad644b9e745980996c4fc429de Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Mon, 16 Nov 2020 14:41:00 -0500 Subject: [PATCH 04/57] callback_handler.go: add CSRF and version state validations Signed-off-by: Andrew Keesler --- internal/oidc/auth/auth_handler.go | 51 ++----- internal/oidc/auth/auth_handler_test.go | 8 +- internal/oidc/callback/callback_handler.go | 84 ++++++++---- .../oidc/callback/callback_handler_test.go | 128 +++++++++++------- internal/oidc/oidc.go | 42 ++++++ 5 files changed, 197 insertions(+), 116 deletions(-) diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index 96cb44645..18c5e98ed 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -9,8 +9,6 @@ import ( "net/http" "time" - "github.com/gorilla/securecookie" - "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/token/jwt" @@ -25,22 +23,6 @@ import ( "go.pinniped.dev/internal/plog" ) -const ( - // Just in case we need to make a breaking change to the format of the upstream state param, - // we are including a format version number. This gives the opportunity for a future version of Pinniped - // to have the consumer of this format decide to reject versions that it doesn't understand. - upstreamStateParamFormatVersion = "1" - - // The `name` passed to the encoder for encoding the upstream state param value. This name is short - // because it will be encoded into the upstream state param value and we're trying to keep that small. - upstreamStateParamEncodingName = "s" -) - -// Encoder is the encoding side of the securecookie.Codec interface. -type Encoder interface { - Encode(name string, value interface{}) (string, error) -} - func NewHandler( issuer string, idpListGetter oidc.IDPListGetter, @@ -48,8 +30,8 @@ func NewHandler( generateCSRF func() (csrftoken.CSRFToken, error), generatePKCE func() (pkce.Code, error), generateNonce func() (nonce.Nonce, error), - upstreamStateEncoder Encoder, - cookieCodec securecookie.Codec, + upstreamStateEncoder oidc.Encoder, + cookieCodec oidc.Codec, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { if r.Method != http.MethodPost && r.Method != http.MethodGet { @@ -144,7 +126,7 @@ func NewHandler( }) } -func readCSRFCookie(r *http.Request, codec securecookie.Codec) (csrftoken.CSRFToken, error) { +func readCSRFCookie(r *http.Request, codec oidc.Codec) (csrftoken.CSRFToken, error) { receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName) if err != nil { // Error means that the cookie was not found @@ -204,37 +186,28 @@ func generateValues( return csrfValue, nonceValue, pkceValue, nil } -// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param. -type upstreamStateParamData struct { - AuthParams string `json:"p"` - Nonce nonce.Nonce `json:"n"` - CSRFToken csrftoken.CSRFToken `json:"c"` - PKCECode pkce.Code `json:"k"` - StateParamFormatVersion string `json:"v"` -} - func upstreamStateParam( authorizeRequester fosite.AuthorizeRequester, nonceValue nonce.Nonce, csrfValue csrftoken.CSRFToken, pkceValue pkce.Code, - encoder Encoder, + encoder oidc.Encoder, ) (string, error) { - stateParamData := upstreamStateParamData{ - AuthParams: authorizeRequester.GetRequestForm().Encode(), - Nonce: nonceValue, - CSRFToken: csrfValue, - PKCECode: pkceValue, - StateParamFormatVersion: upstreamStateParamFormatVersion, + stateParamData := oidc.UpstreamStateParamData{ + AuthParams: authorizeRequester.GetRequestForm().Encode(), + Nonce: nonceValue, + CSRFToken: csrfValue, + PKCECode: pkceValue, + FormatVersion: oidc.UpstreamStateParamFormatVersion, } - encodedStateParamValue, err := encoder.Encode(upstreamStateParamEncodingName, stateParamData) + encodedStateParamValue, err := encoder.Encode(oidc.UpstreamStateParamEncodingName, stateParamData) if err != nil { return "", httperr.Wrap(http.StatusInternalServerError, "error encoding upstream state param", err) } return encodedStateParamValue, nil } -func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec securecookie.Codec) error { +func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec oidc.Codec) error { encodedCSRFValue, err := codec.Encode(oidc.CSRFCookieEncodingName, csrfValue) if err != nil { return httperr.Wrap(http.StatusInternalServerError, "error encoding CSRF cookie", err) diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 173021d84..6f38a8d95 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -249,8 +249,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF func() (csrftoken.CSRFToken, error) generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) - stateEncoder securecookie.Codec - cookieEncoder securecookie.Codec + stateEncoder oidc.Codec + cookieEncoder oidc.Codec method string path string contentType string @@ -807,7 +807,7 @@ func TestAuthorizationEndpoint(t *testing.T) { } type errorReturningEncoder struct { - securecookie.Codec + oidc.Codec } func (*errorReturningEncoder) Encode(_ string, _ interface{}) (string, error) { @@ -830,7 +830,7 @@ func requireEqualContentType(t *testing.T, actual string, expected string) { require.Equal(t, actualContentTypeParams, expectedContentTypeParams) } -func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder securecookie.Codec) { +func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder oidc.Codec) { t.Helper() actualLocationURL, err := url.Parse(actualURL) require.NoError(t, err) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 9c5ad6fd1..5bd04127f 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -12,43 +12,62 @@ import ( "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/plog" ) -// Decoder is the decoding side of the securecookie.Codec interface. -type Decoder interface { - Decode(name, value string, into interface{}) error -} - func NewHandler( idpListGetter oidc.IDPListGetter, - cookieDecoder Decoder, + stateDecoder, cookieDecoder oidc.Decoder, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - if r.Method != http.MethodGet { - return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) - } - - _, err := readCSRFCookie(r, cookieDecoder) - if err != nil { + if err := validateRequest(r, stateDecoder, cookieDecoder); err != nil { return err } - if r.FormValue("code") == "" { - return httperr.New(http.StatusBadRequest, "code param not found") - } - - if r.FormValue("state") == "" { - return httperr.New(http.StatusBadRequest, "state param not found") - } - if findUpstreamIDPConfig(r, idpListGetter) == nil { + plog.Warning("upstream provider not found") return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") } - return httperr.New(http.StatusBadRequest, "state param not valid") + return nil }) } +func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) error { + if r.Method != http.MethodGet { + return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) + } + + csrfValue, err := readCSRFCookie(r, cookieDecoder) + if err != nil { + plog.InfoErr("error reading CSRF cookie", err) + return err + } + + if r.FormValue("code") == "" { + plog.Info("code param not found") + return httperr.New(http.StatusBadRequest, "code param not found") + } + + if r.FormValue("state") == "" { + plog.Info("state param not found") + return httperr.New(http.StatusBadRequest, "state param not found") + } + + state, err := readState(r, stateDecoder) + if err != nil { + plog.InfoErr("error reading state", err) + return err + } + + if state.CSRFToken != csrfValue { + plog.InfoErr("CSRF value does not match", err) + return httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) + } + + return nil +} + func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider { _, lastPathComponent := path.Split(r.URL.Path) for _, p := range idpListGetter.GetIDPList() { @@ -59,18 +78,35 @@ func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *p return nil } -func readCSRFCookie(r *http.Request, cookieDecoder Decoder) (csrftoken.CSRFToken, error) { +func readCSRFCookie(r *http.Request, cookieDecoder oidc.Decoder) (csrftoken.CSRFToken, error) { receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName) if err != nil { // Error means that the cookie was not found - return "", httperr.Wrap(http.StatusForbidden, "unauthorized request", err) + return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err) } var csrfFromCookie csrftoken.CSRFToken err = cookieDecoder.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) if err != nil { - return "", httperr.Wrap(http.StatusForbidden, "unauthorized request", err) + return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", err) } return csrfFromCookie, nil } + +func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) { + var state oidc.UpstreamStateParamData + if err := stateDecoder.Decode( + oidc.UpstreamStateParamEncodingName, + r.FormValue("state"), + &state, + ); err != nil { + return nil, httperr.New(http.StatusBadRequest, "error reading state") + } + + if state.FormatVersion != oidc.UpstreamStateParamFormatVersion { + return nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid") + } + + return &state, nil +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 8d966db41..46dff9e27 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -48,28 +48,49 @@ func TestCallbackEndpoint(t *testing.T) { require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey) require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey) - var happyStateEncoder = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey) - happyStateEncoder.SetSerializer(securecookie.JSONEncoder{}) - var happyCookieEncoder = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) - happyCookieEncoder.SetSerializer(securecookie.JSONEncoder{}) + var happyStateCodec = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey) + happyStateCodec.SetSerializer(securecookie.JSONEncoder{}) + var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) + happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) - // happyCSRF := "test-csrf" - // happyPKCE := "test-pkce" - // happyNonce := "test-nonce" - // - // happyEncodedState, err := happyStateEncoder.Encode("s", - // testutil.ExpectedUpstreamStateParamFormat{ - // P: "todo query goes here", - // N: happyNonce, - // C: happyCSRF, - // K: happyPKCE, - // V: "1", - // }, - // ) - // require.NoError(t, err) + happyCSRF := "test-csrf" + happyPKCE := "test-pkce" + happyNonce := "test-nonce" - incomingCookieCSRFValue := "csrf-value-from-cookie" - encodedIncomingCookieCSRFValue, err := happyCookieEncoder.Encode("csrf", incomingCookieCSRFValue) + happyState, err := happyStateCodec.Encode("s", + testutil.ExpectedUpstreamStateParamFormat{ + P: "todo query goes here", + N: happyNonce, + C: happyCSRF, + K: happyPKCE, + V: "1", + }, + ) + require.NoError(t, err) + + wrongCSRFValueState, err := happyStateCodec.Encode("s", + testutil.ExpectedUpstreamStateParamFormat{ + P: "todo query goes here", + N: happyNonce, + C: "wrong-csrf-value", + K: happyPKCE, + V: "1", + }, + ) + require.NoError(t, err) + + wrongVersionState, err := happyStateCodec.Encode("s", + testutil.ExpectedUpstreamStateParamFormat{ + P: "todo query goes here", + N: happyNonce, + C: happyCSRF, + K: happyPKCE, + V: "wrong-version", + }, + ) + require.NoError(t, err) + + encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyCSRF) require.NoError(t, err) happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue @@ -77,7 +98,6 @@ func TestCallbackEndpoint(t *testing.T) { name string idpListGetter provider.DynamicUpstreamIDPProvider - cookieDecoder Decoder method string path string csrfCookie string @@ -118,39 +138,44 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "Method Not Allowed: DELETE (try GET)\n", }, { - name: "code param was not included on request", - cookieDecoder: happyCookieEncoder, - method: http.MethodGet, - path: newRequestPath().WithoutCode().String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: code param not found\n", + name: "code param was not included on request", + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithoutCode().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: code param not found\n", }, { - name: "state param was not included on request", - cookieDecoder: happyCookieEncoder, - method: http.MethodGet, - path: newRequestPath().WithoutState().String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: state param not found\n", + name: "state param was not included on request", + method: http.MethodGet, + path: newRequestPath().WithoutState().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: state param not found\n", }, { name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - cookieDecoder: happyCookieEncoder, method: http.MethodGet, path: newRequestPath().WithState("this-will-not-decode").String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: state param not valid\n", + wantBody: "Bad Request: error reading state\n", + }, + { + name: "state's internal version does not match what we want", + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(wrongVersionState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: state format version is invalid\n", }, { name: "the UpstreamOIDCProvider CRD has been deleted", idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), - cookieDecoder: happyCookieEncoder, method: http.MethodGet, - path: newRequestPath().String(), + path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusUnprocessableEntity, wantBody: "Unprocessable Entity: upstream provider not found\n", @@ -158,24 +183,29 @@ func TestCallbackEndpoint(t *testing.T) { { name: "the CSRF cookie does not exist on request", idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), - cookieDecoder: happyCookieEncoder, method: http.MethodGet, - path: newRequestPath().String(), + path: newRequestPath().WithState(happyState).String(), wantStatus: http.StatusForbidden, - wantBody: "Forbidden: unauthorized request\n", + wantBody: "Forbidden: CSRF cookie is missing\n", }, { - name: "the CSRF cookie cannot be decoded", + name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), - cookieDecoder: happyCookieEncoder, method: http.MethodGet, - path: newRequestPath().String(), + path: newRequestPath().WithState(happyState).String(), csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", wantStatus: http.StatusForbidden, - wantBody: "Forbidden: unauthorized request\n", + wantBody: "Forbidden: error reading CSRF cookie\n", + }, + { + name: "cookie csrf value does not match state csrf value", + idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(wrongCSRFValueState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: CSRF value does not match\n", }, - // TODO: csrf value from inside state param does not match csrf cookie value - // TODO: state's internal version does not match what we want // Upstream exchange // TODO: network call to upstream token endpoint fails @@ -201,7 +231,7 @@ func TestCallbackEndpoint(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - subject := NewHandler(test.idpListGetter, test.cookieDecoder) + subject := NewHandler(test.idpListGetter, happyStateCodec, happyCookieCodec) req := httptest.NewRequest(test.method, test.path, nil) if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 3b5dcb60c..afa06cc42 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -8,7 +8,10 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/compose" + "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/oidcclient/nonce" + "go.pinniped.dev/internal/oidcclient/pkce" ) const ( @@ -19,6 +22,15 @@ const ( ) const ( + // Just in case we need to make a breaking change to the format of the upstream state param, + // we are including a format version number. This gives the opportunity for a future version of Pinniped + // to have the consumer of this format decide to reject versions that it doesn't understand. + UpstreamStateParamFormatVersion = "1" + + // The `name` passed to the encoder for encoding the upstream state param value. This name is short + // because it will be encoded into the upstream state param value and we're trying to keep that small. + UpstreamStateParamEncodingName = "s" + // CSRFCookieName is the name of the browser cookie which shall hold our CSRF value. // The `__Host` prefix has a special meaning. See // https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes. @@ -29,6 +41,36 @@ const ( CSRFCookieEncodingName = "csrf" ) +// Encoder is the encoding side of the securecookie.Codec interface. +type Encoder interface { + Encode(name string, value interface{}) (string, error) +} + +// Decoder is the decoding side of the securecookie.Codec interface. +type Decoder interface { + Decode(name, value string, into interface{}) error +} + +// Codec is both the encoding and decoding sides of the securecookie.Codec interface. It is +// interface'd here so that we properly wrap the securecookie dependency. +type Codec interface { + Encoder + Decoder +} + +// UpstreamStateParamData is the format of the state parameter that we use when we communicate to an +// upstream OIDC provider. +// +// Keep the JSON to a minimal size because the upstream provider could impose size limitations on +// the state param. +type UpstreamStateParamData struct { + AuthParams string `json:"p"` + Nonce nonce.Nonce `json:"n"` + CSRFToken csrftoken.CSRFToken `json:"c"` + PKCECode pkce.Code `json:"k"` + FormatVersion string `json:"v"` +} + func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient { return &fosite.DefaultOpenIDConnectClient{ DefaultClient: &fosite.DefaultClient{ From 1c7601a2b53300a4ba729fe32c603625b56020b3 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Mon, 16 Nov 2020 17:07:34 -0500 Subject: [PATCH 05/57] callback_handler.go: start happy path test with redirect Next steps: fosite storage? Signed-off-by: Ryan Richard --- internal/oidc/callback/callback_handler.go | 34 +++++++--- .../oidc/callback/callback_handler_test.go | 65 +++++++++++++++++-- 2 files changed, 84 insertions(+), 15 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 5bd04127f..d204d361c 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -5,7 +5,9 @@ package callback import ( + "fmt" "net/http" + "net/url" "path" "go.pinniped.dev/internal/httputil/httperr" @@ -20,7 +22,8 @@ func NewHandler( stateDecoder, cookieDecoder oidc.Decoder, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - if err := validateRequest(r, stateDecoder, cookieDecoder); err != nil { + state, err := validateRequest(r, stateDecoder, cookieDecoder) + if err != nil { return err } @@ -29,43 +32,56 @@ func NewHandler( return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") } + downstreamAuthParams, err := url.ParseQuery(state.AuthParams) + if err != nil { + panic(err) + } + + downstreamCallbackURL := fmt.Sprintf( + "%s?code=%s&state=%s", + downstreamAuthParams.Get("redirect_uri"), + url.QueryEscape("some-code"), + url.QueryEscape(downstreamAuthParams.Get("state")), + ) + http.Redirect(w, r, downstreamCallbackURL, 302) + return nil }) } -func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) error { +func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) { if r.Method != http.MethodGet { - return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) + return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) } csrfValue, err := readCSRFCookie(r, cookieDecoder) if err != nil { plog.InfoErr("error reading CSRF cookie", err) - return err + return nil, err } if r.FormValue("code") == "" { plog.Info("code param not found") - return httperr.New(http.StatusBadRequest, "code param not found") + return nil, httperr.New(http.StatusBadRequest, "code param not found") } if r.FormValue("state") == "" { plog.Info("state param not found") - return httperr.New(http.StatusBadRequest, "state param not found") + return nil, httperr.New(http.StatusBadRequest, "state param not found") } state, err := readState(r, stateDecoder) if err != nil { plog.InfoErr("error reading state", err) - return err + return nil, err } if state.CSRFToken != csrfValue { plog.InfoErr("CSRF value does not match", err) - return httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) + return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) } - return nil + return state, nil } func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider { diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 46dff9e27..29147a08b 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -5,9 +5,11 @@ package callback import ( "fmt" + "html" "net/http" "net/http/httptest" "net/url" + "regexp" "testing" "github.com/gorilla/securecookie" @@ -22,6 +24,10 @@ const ( ) func TestCallbackEndpoint(t *testing.T) { + const ( + downstreamRedirectURI = "http://127.0.0.1/callback" + ) + upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") require.NoError(t, err) otherUpstreamAuthURL, err := url.Parse("https://some-other-upstream-idp:8443/auth") @@ -53,13 +59,25 @@ func TestCallbackEndpoint(t *testing.T) { var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) + happyDownstreamState := "some-downstream-state" + + happyOrignalRequestParams := url.Values{ + "response_type": []string{"code"}, + "scope": []string{"openid profile email"}, + "client_id": []string{"pinniped-cli"}, + "state": []string{happyDownstreamState}, + "nonce": []string{"some-nonce-value"}, + "code_challenge": []string{"some-challenge"}, + "code_challenge_method": []string{"S256"}, + "redirect_uri": []string{downstreamRedirectURI}, + }.Encode() happyCSRF := "test-csrf" happyPKCE := "test-pkce" happyNonce := "test-nonce" happyState, err := happyStateCodec.Encode("s", testutil.ExpectedUpstreamStateParamFormat{ - P: "todo query goes here", + P: happyOrignalRequestParams, N: happyNonce, C: happyCSRF, K: happyPKCE, @@ -70,7 +88,7 @@ func TestCallbackEndpoint(t *testing.T) { wrongCSRFValueState, err := happyStateCodec.Encode("s", testutil.ExpectedUpstreamStateParamFormat{ - P: "todo query goes here", + P: happyOrignalRequestParams, N: happyNonce, C: "wrong-csrf-value", K: happyPKCE, @@ -81,7 +99,7 @@ func TestCallbackEndpoint(t *testing.T) { wrongVersionState, err := happyStateCodec.Encode("s", testutil.ExpectedUpstreamStateParamFormat{ - P: "todo query goes here", + P: happyOrignalRequestParams, N: happyNonce, C: happyCSRF, K: happyPKCE, @@ -102,11 +120,22 @@ func TestCallbackEndpoint(t *testing.T) { path string csrfCookie string - wantStatus int - wantBody string + wantStatus int + wantBody string + wantRedirectLocationRegexp string }{ // Happy path // TODO: GET with good state and cookie and successful upstream token exchange and 302 to downstream client callback with its state and code + { + name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&state=` + happyDownstreamState, + }, + // TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes) // Pre-upstream-exchange verification { @@ -240,7 +269,31 @@ func TestCallbackEndpoint(t *testing.T) { subject.ServeHTTP(rsp, req) require.Equal(t, test.wantStatus, rsp.Code) - require.Equal(t, test.wantBody, rsp.Body.String()) + + require.False(t, test.wantBody != "" && test.wantRedirectLocationRegexp != "", "test cannot set both body and redirect assertions") + switch { + case test.wantBody != "": + require.Empty(t, rsp.Header().Values("Location")) + require.Equal(t, test.wantBody, rsp.Body.String()) + case test.wantRedirectLocationRegexp != "": + // Assert that Location header matches regular expression. + require.Len(t, rsp.Header().Values("Location"), 1) + actualLocation := rsp.Header().Get("Location") + regex := regexp.MustCompile(test.wantRedirectLocationRegexp) + submatches := regex.FindStringSubmatch(actualLocation) + require.Lenf(t, submatches, 2, "no regexp match in actualLocation: %q", actualLocation) + capturedAuthCode := submatches[1] + _ = capturedAuthCode + + // Assert capturedAuthCode storage stuff... + + // Assert that body contains anchor tag with redirect location. + anchorTagWithLocationHref := fmt.Sprintf("Found.\n\n", html.EscapeString(actualLocation)) + require.Equal(t, anchorTagWithLocationHref, rsp.Body.String()) + default: + require.Empty(t, rsp.Header().Values("Location")) + require.Empty(t, rsp.Body.String()) + } }) } } From 227fbd63aaa7334d1dd51ac4f1e24ef14f4185cf Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Wed, 18 Nov 2020 13:38:13 -0800 Subject: [PATCH 06/57] Use an interface instead of a concrete type for UpstreamOIDCIdentityProvider Because we want it to implement an AuthcodeExchanger interface and do it in a way that will be more unit test-friendly than the underlying library that we intend to use inside its implementation. --- .../upstreamwatcher/upstreamwatcher.go | 8 +- .../upstreamwatcher/upstreamwatcher_test.go | 12 ++- internal/oidc/auth/auth_handler.go | 12 +-- internal/oidc/auth/auth_handler_test.go | 8 +- internal/oidc/callback/callback_handler.go | 11 ++- .../oidc/callback/callback_handler_test.go | 61 ++++++++----- internal/oidc/oidc.go | 2 +- .../provider/dynamic_upstream_idp_provider.go | 90 ++++++++++++++++--- .../oidc/provider/manager/manager_test.go | 15 ++-- internal/testutil/oidc.go | 66 +++++++++++++- 10 files changed, 220 insertions(+), 65 deletions(-) diff --git a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go index bc3db3bf7..9c8952e9b 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go @@ -62,7 +62,7 @@ const ( // IDPCache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations. type IDPCache interface { - SetIDPList([]provider.UpstreamOIDCIdentityProvider) + SetIDPList([]provider.UpstreamOIDCIdentityProviderI) } // lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration. @@ -132,13 +132,13 @@ func (c *controller) Sync(ctx controllerlib.Context) error { } requeue := false - validatedUpstreams := make([]provider.UpstreamOIDCIdentityProvider, 0, len(actualUpstreams)) + validatedUpstreams := make([]provider.UpstreamOIDCIdentityProviderI, 0, len(actualUpstreams)) for _, upstream := range actualUpstreams { valid := c.validateUpstream(ctx, upstream) if valid == nil { requeue = true } else { - validatedUpstreams = append(validatedUpstreams, *valid) + validatedUpstreams = append(validatedUpstreams, provider.UpstreamOIDCIdentityProviderI(valid)) } } c.cache.SetIDPList(validatedUpstreams) @@ -258,6 +258,8 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst c.validatorCache.putProvider(&upstream.Spec, discoveredProvider) } + // TODO also parse the token endpoint from the discovery info and put it onto the `result` + // Parse out and validate the discovered authorize endpoint. authURL, err := url.Parse(discoveredProvider.Endpoint().AuthURL) if err != nil { diff --git a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go index 949effaf7..4891ae8a3 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go @@ -527,7 +527,9 @@ func TestController(t *testing.T) { kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0) testLog := testlogger.New(t) cache := provider.NewDynamicUpstreamIDPProvider() - cache.SetIDPList([]provider.UpstreamOIDCIdentityProvider{{Name: "initial-entry"}}) + initialProviderList := make([]provider.UpstreamOIDCIdentityProviderI, 1) + initialProviderList[0] = &provider.UpstreamOIDCIdentityProvider{Name: "initial-entry"} + cache.SetIDPList(initialProviderList) controller := New( cache, @@ -551,7 +553,13 @@ func TestController(t *testing.T) { require.NoError(t, err) } require.Equal(t, strings.Join(tt.wantLogs, "\n"), strings.Join(testLog.Lines(), "\n")) - require.ElementsMatch(t, tt.wantResultingCache, cache.GetIDPList()) + + actualIDPList := cache.GetIDPList() + require.Equal(t, len(tt.wantResultingCache), len(actualIDPList)) + for i := range actualIDPList { + actualIDP := actualIDPList[i].(*provider.UpstreamOIDCIdentityProvider) + require.Equal(t, tt.wantResultingCache[i], *actualIDP) + } actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().UpstreamOIDCProviders(testNamespace).List(ctx, metav1.ListOptions{}) require.NoError(t, err) diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index 18c5e98ed..a6fa63b4f 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -88,12 +88,12 @@ func NewHandler( } upstreamOAuthConfig := oauth2.Config{ - ClientID: upstreamIDP.ClientID, + ClientID: upstreamIDP.GetClientID(), Endpoint: oauth2.Endpoint{ - AuthURL: upstreamIDP.AuthorizationURL.String(), + AuthURL: upstreamIDP.GetAuthorizationURL().String(), }, - RedirectURL: fmt.Sprintf("%s/callback/%s", issuer, upstreamIDP.Name), - Scopes: upstreamIDP.Scopes, + RedirectURL: fmt.Sprintf("%s/callback/%s", issuer, upstreamIDP.GetName()), + Scopes: upstreamIDP.GetScopes(), } encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, upstreamStateEncoder) @@ -150,7 +150,7 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { } } -func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { +func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (provider.UpstreamOIDCIdentityProviderI, error) { allUpstreamIDPs := idpListGetter.GetIDPList() if len(allUpstreamIDPs) == 0 { return nil, httperr.New( @@ -163,7 +163,7 @@ func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (*provider.UpstreamOIDC "Too many upstream providers are configured (support for multiple upstreams is not yet implemented)", ) } - return &allUpstreamIDPs[0], nil + return allUpstreamIDPs[0], nil } func generateValues( diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 6f38a8d95..8d7448683 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -113,7 +113,7 @@ func TestAuthorizationEndpoint(t *testing.T) { upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") require.NoError(t, err) - upstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ + upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ Name: "some-idp", ClientID: "some-client-id", AuthorizationURL: *upstreamAuthURL, @@ -122,7 +122,7 @@ func TestAuthorizationEndpoint(t *testing.T) { issuer := "https://my-issuer.com/some-path" - // Configure fosite the same way that the production code would, except use in-memory storage. + // Configure fosite the same way that the production code would, using NullStorage to turn off storage. oauthStore := oidc.NullStorage{} hmacSecret := []byte("some secret - must have at least 32 bytes") require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") @@ -771,13 +771,13 @@ func TestAuthorizationEndpoint(t *testing.T) { runOneTestCase(t, test, subject) // Call the setter to change the upstream IDP settings. - newProviderSettings := provider.UpstreamOIDCIdentityProvider{ + newProviderSettings := testutil.TestUpstreamOIDCIdentityProvider{ Name: "some-other-idp", ClientID: "some-other-client-id", AuthorizationURL: *upstreamAuthURL, Scopes: []string{"other-scope1", "other-scope2"}, } - test.idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProvider{newProviderSettings}) + test.idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProviderI{provider.UpstreamOIDCIdentityProviderI(&newProviderSettings)}) // Update the expectations of the test case to match the new upstream IDP settings. test.wantLocationHeader = urlWithQuery(upstreamAuthURL.String(), diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index d204d361c..6d9803f4d 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -10,6 +10,8 @@ import ( "net/url" "path" + "github.com/ory/fosite" + "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/csrftoken" @@ -17,10 +19,7 @@ import ( "go.pinniped.dev/internal/plog" ) -func NewHandler( - idpListGetter oidc.IDPListGetter, - stateDecoder, cookieDecoder oidc.Decoder, -) http.Handler { +func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, stateDecoder, cookieDecoder oidc.Decoder) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { state, err := validateRequest(r, stateDecoder, cookieDecoder) if err != nil { @@ -84,10 +83,10 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) return state, nil } -func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider { +func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProviderI { _, lastPathComponent := path.Split(r.URL.Path) for _, p := range idpListGetter.GetIDPList() { - if p.Name == lastPathComponent { + if p.GetName() == lastPathComponent { return &p } } diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 29147a08b..8315b9100 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -13,8 +13,11 @@ import ( "testing" "github.com/gorilla/securecookie" + "github.com/ory/fosite" + "github.com/ory/fosite/storage" "github.com/stretchr/testify/require" + "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/testutil" ) @@ -28,23 +31,41 @@ func TestCallbackEndpoint(t *testing.T) { downstreamRedirectURI = "http://127.0.0.1/callback" ) - upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") - require.NoError(t, err) - otherUpstreamAuthURL, err := url.Parse("https://some-other-upstream-idp:8443/auth") - require.NoError(t, err) + // TODO use a fosite memory store and pass in a fostite oauthHelper + // TODO write a test double for UpstreamOIDCIdentityProviderI ID token with a claim called "the-user-claim" and put a username as the value of that claim + // TODO assert that after the callback request, the fosite storage has 1 authcode key saved, + // and it is the same key that was returned in the redirect, + // and the value in storage includes the username in the fosite session + // TODO do the same thing with the groups list (store it in the fosite session as JWT claim) + // TODO test for when UpstreamOIDCIdentityProviderI authcode exchange fails + // TODO wire in the callback endpoint into the oidc manager request router + // TODO update the upstream watcher controller to also populate the new fields + // TODO update the integration test + // TODO DO NOT store the upstream tokens (or maybe just the refresh token) for this story. In a future story, we can store them/it in some other storage interface indexed by the same authcode hash that fosite used for storage. + // TODO grab the upstream config name from the state param instead of the URL path - upstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ - Name: happyUpstreamIDPName, - ClientID: "some-client-id", - AuthorizationURL: *upstreamAuthURL, - Scopes: []string{"scope1", "scope2"}, + // Configure fosite the same way that the production code would, except use in-memory storage. + oauthStore := &storage.MemoryStore{ + Clients: map[string]fosite.Client{oidc.PinnipedCLIOIDCClient().ID: oidc.PinnipedCLIOIDCClient()}, + AuthorizeCodes: map[string]storage.StoreAuthorizeCode{}, + PKCES: map[string]fosite.Requester{}, + IDSessions: map[string]fosite.Requester{}, + } + hmacSecret := []byte("some secret - must have at least 32 bytes") + require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") + oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) + + upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ + Name: happyUpstreamIDPName, + ClientID: "some-client-id", + UsernameClaim: "the-user-claim", + Scopes: []string{"scope1", "scope2"}, } - otherUpstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ - Name: "other-upstream-idp-name", - ClientID: "other-some-client-id", - AuthorizationURL: *otherUpstreamAuthURL, - Scopes: []string{"other-scope1", "other-scope2"}, + otherUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ + Name: "other-upstream-idp-name", + ClientID: "other-some-client-id", + Scopes: []string{"other-scope1", "other-scope2"}, } var stateEncoderHashKey = []byte("fake-hash-secret") @@ -61,7 +82,7 @@ func TestCallbackEndpoint(t *testing.T) { happyDownstreamState := "some-downstream-state" - happyOrignalRequestParams := url.Values{ + happyOriginalRequestParams := url.Values{ "response_type": []string{"code"}, "scope": []string{"openid profile email"}, "client_id": []string{"pinniped-cli"}, @@ -77,7 +98,7 @@ func TestCallbackEndpoint(t *testing.T) { happyState, err := happyStateCodec.Encode("s", testutil.ExpectedUpstreamStateParamFormat{ - P: happyOrignalRequestParams, + P: happyOriginalRequestParams, N: happyNonce, C: happyCSRF, K: happyPKCE, @@ -88,7 +109,7 @@ func TestCallbackEndpoint(t *testing.T) { wrongCSRFValueState, err := happyStateCodec.Encode("s", testutil.ExpectedUpstreamStateParamFormat{ - P: happyOrignalRequestParams, + P: happyOriginalRequestParams, N: happyNonce, C: "wrong-csrf-value", K: happyPKCE, @@ -99,7 +120,7 @@ func TestCallbackEndpoint(t *testing.T) { wrongVersionState, err := happyStateCodec.Encode("s", testutil.ExpectedUpstreamStateParamFormat{ - P: happyOrignalRequestParams, + P: happyOriginalRequestParams, N: happyNonce, C: happyCSRF, K: happyPKCE, @@ -260,7 +281,7 @@ func TestCallbackEndpoint(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - subject := NewHandler(test.idpListGetter, happyStateCodec, happyCookieCodec) + subject := NewHandler(test.idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) req := httptest.NewRequest(test.method, test.path, nil) if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) @@ -285,7 +306,7 @@ func TestCallbackEndpoint(t *testing.T) { capturedAuthCode := submatches[1] _ = capturedAuthCode - // Assert capturedAuthCode storage stuff... + // TODO Assert capturedAuthCode storage stuff... // Assert that body contains anchor tag with redirect location. anchorTagWithLocationHref := fmt.Sprintf("Found.\n\n", html.EscapeString(actualLocation)) diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index afa06cc42..4da9cfcdf 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -106,5 +106,5 @@ func FositeOauth2Helper(oauthStore interface{}, hmacSecretOfLengthAtLeast32 []by } type IDPListGetter interface { - GetIDPList() []provider.UpstreamOIDCIdentityProvider + GetIDPList() []provider.UpstreamOIDCIdentityProviderI } diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index bb26cef2b..41ffb7d19 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -4,48 +4,114 @@ package provider import ( + "context" "net/url" "sync" + + "go.pinniped.dev/internal/oidcclient" + "go.pinniped.dev/internal/oidcclient/nonce" + "go.pinniped.dev/internal/oidcclient/pkce" ) -type UpstreamOIDCIdentityProvider struct { +type UpstreamOIDCIdentityProviderI interface { // A name for this upstream provider, which will be used as a component of the path for the callback endpoint // hosted by the Supervisor. - Name string + GetName() string - // The Oauth client ID registered with the upstream provider to be used in the authorization flow. - ClientID string + // The Oauth client ID registered with the upstream provider to be used in the authorization code flow. + GetClientID() string // The Authorization Endpoint fetched from discovery. - AuthorizationURL url.URL + GetAuthorizationURL() *url.URL // Scopes to request in authorization flow. - Scopes []string + GetScopes() []string + + // ID Token username claim name. May return empty string, in which case we will use some reasonable defaults. + GetUsernameClaim() string + + // ID Token groups claim name. May return empty string, in which case we won't try to read groups from the upstream provider. + GetGroupsClaim() string + + AuthcodeExchanger +} + +// Performs upstream OIDC authorization code exchange and token validation. +// Returns the validated raw tokens as well as the parsed claims of the ID token. +type AuthcodeExchanger interface { + ExchangeAuthcodeAndValidateTokens( + ctx context.Context, + authcode string, + pkceCodeVerifier pkce.Code, + expectedIDTokenNonce nonce.Nonce, + ) (tokens oidcclient.Token, parsedIDTokenClaims map[string]interface{}, err error) +} + +type UpstreamOIDCIdentityProvider struct { + Name string + ClientID string + AuthorizationURL url.URL + UsernameClaim string + GroupsClaim string + Scopes []string +} + +func (u *UpstreamOIDCIdentityProvider) GetName() string { + return u.Name +} + +func (u *UpstreamOIDCIdentityProvider) GetClientID() string { + return u.ClientID +} + +func (u *UpstreamOIDCIdentityProvider) GetAuthorizationURL() *url.URL { + return &u.AuthorizationURL +} + +func (u *UpstreamOIDCIdentityProvider) GetScopes() []string { + return u.Scopes +} + +func (u *UpstreamOIDCIdentityProvider) GetUsernameClaim() string { + return u.UsernameClaim +} + +func (u *UpstreamOIDCIdentityProvider) GetGroupsClaim() string { + return u.GroupsClaim +} + +func (u *UpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( + ctx context.Context, + authcode string, + pkceCodeVerifier pkce.Code, + expectedIDTokenNonce nonce.Nonce, +) (oidcclient.Token, map[string]interface{}, error) { + panic("TODO implement me") // TODO } type DynamicUpstreamIDPProvider interface { - SetIDPList(oidcIDPs []UpstreamOIDCIdentityProvider) - GetIDPList() []UpstreamOIDCIdentityProvider + SetIDPList(oidcIDPs []UpstreamOIDCIdentityProviderI) + GetIDPList() []UpstreamOIDCIdentityProviderI } type dynamicUpstreamIDPProvider struct { - oidcProviders []UpstreamOIDCIdentityProvider + oidcProviders []UpstreamOIDCIdentityProviderI mutex sync.RWMutex } func NewDynamicUpstreamIDPProvider() DynamicUpstreamIDPProvider { return &dynamicUpstreamIDPProvider{ - oidcProviders: []UpstreamOIDCIdentityProvider{}, + oidcProviders: []UpstreamOIDCIdentityProviderI{}, } } -func (p *dynamicUpstreamIDPProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProvider) { +func (p *dynamicUpstreamIDPProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProviderI) { p.mutex.Lock() // acquire a write lock defer p.mutex.Unlock() p.oidcProviders = oidcIDPs } -func (p *dynamicUpstreamIDPProvider) GetIDPList() []UpstreamOIDCIdentityProvider { +func (p *dynamicUpstreamIDPProvider) GetIDPList() []UpstreamOIDCIdentityProviderI { p.mutex.RLock() // acquire a read lock defer p.mutex.RUnlock() return p.oidcProviders diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index e2d55f322..e9e175d8a 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -12,6 +12,8 @@ import ( "strings" "testing" + "go.pinniped.dev/internal/testutil" + "github.com/sclevine/spec" "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" @@ -107,14 +109,11 @@ func TestManager(t *testing.T) { parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL) r.NoError(err) - idpListGetter := provider.NewDynamicUpstreamIDPProvider() - idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProvider{ - { - Name: "test-idp", - ClientID: "test-client-id", - AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, - Scopes: []string{"test-scope"}, - }, + idpListGetter := testutil.NewIDPListGetter(testutil.TestUpstreamOIDCIdentityProvider{ + Name: "test-idp", + ClientID: "test-client-id", + AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, + Scopes: []string{"test-scope"}, }) subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter) diff --git a/internal/testutil/oidc.go b/internal/testutil/oidc.go index 0552ea95f..14bdb92ba 100644 --- a/internal/testutil/oidc.go +++ b/internal/testutil/oidc.go @@ -3,13 +3,73 @@ package testutil -import "go.pinniped.dev/internal/oidc/provider" +import ( + "context" + "net/url" + + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/oidcclient" + "go.pinniped.dev/internal/oidcclient/nonce" + "go.pinniped.dev/internal/oidcclient/pkce" +) // Test helpers for the OIDC package. -func NewIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { +type TestUpstreamOIDCIdentityProvider struct { + Name string + ClientID string + AuthorizationURL url.URL + UsernameClaim string + GroupsClaim string + Scopes []string + ExchangeAuthcodeAndValidateTokensFunc func( + ctx context.Context, + authcode string, + pkceCodeVerifier pkce.Code, + expectedIDTokenNonce nonce.Nonce, + ) (oidcclient.Token, map[string]interface{}, error) +} + +func (u *TestUpstreamOIDCIdentityProvider) GetName() string { + return u.Name +} + +func (u *TestUpstreamOIDCIdentityProvider) GetClientID() string { + return u.ClientID +} + +func (u *TestUpstreamOIDCIdentityProvider) GetAuthorizationURL() *url.URL { + return &u.AuthorizationURL +} + +func (u *TestUpstreamOIDCIdentityProvider) GetScopes() []string { + return u.Scopes +} + +func (u *TestUpstreamOIDCIdentityProvider) GetUsernameClaim() string { + return u.UsernameClaim +} + +func (u *TestUpstreamOIDCIdentityProvider) GetGroupsClaim() string { + return u.GroupsClaim +} + +func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( + ctx context.Context, + authcode string, + pkceCodeVerifier pkce.Code, + expectedIDTokenNonce nonce.Nonce, +) (oidcclient.Token, map[string]interface{}, error) { + return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce) +} + +func NewIDPListGetter(upstreamOIDCIdentityProviders ...TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { idpProvider := provider.NewDynamicUpstreamIDPProvider() - idpProvider.SetIDPList(upstreamOIDCIdentityProviders) + upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders)) + for i := range upstreamOIDCIdentityProviders { + upstreams[i] = provider.UpstreamOIDCIdentityProviderI(&upstreamOIDCIdentityProviders[i]) + } + idpProvider.SetIDPList(upstreams) return idpProvider } From 652ea6bd2a46444996217baa29b14882d2b013bf Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Wed, 18 Nov 2020 17:15:01 -0800 Subject: [PATCH 07/57] Start using fosite in the Supervisor's callback handler --- internal/oidc/callback/callback_handler.go | 65 +++++++++-- .../oidc/callback/callback_handler_test.go | 107 +++++++++--------- 2 files changed, 106 insertions(+), 66 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 6d9803f4d..cd66f646b 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -5,12 +5,14 @@ package callback import ( - "fmt" "net/http" "net/url" "path" + "time" "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "github.com/ory/fosite/token/jwt" "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidc" @@ -26,23 +28,64 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi return err } - if findUpstreamIDPConfig(r, idpListGetter) == nil { + upstreamIDPConfig := findUpstreamIDPConfig(r, idpListGetter) + if upstreamIDPConfig == nil { plog.Warning("upstream provider not found") return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") } downstreamAuthParams, err := url.ParseQuery(state.AuthParams) if err != nil { - panic(err) + panic(err) // TODO } - downstreamCallbackURL := fmt.Sprintf( - "%s?code=%s&state=%s", - downstreamAuthParams.Get("redirect_uri"), - url.QueryEscape("some-code"), - url.QueryEscape(downstreamAuthParams.Get("state")), + // Recreate enough of the original authorize request so we can pass it to NewAuthorizeRequest(). + reconstitutedAuthRequest := &http.Request{Form: downstreamAuthParams} + authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), reconstitutedAuthRequest) + if err != nil { + panic(err) // TODO + } + + // TODO: grant the openid scope only if it was requested, similar to what we did in auth_handler.go + authorizeRequester.GrantScope("openid") + + _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( + r.Context(), + "TODO", // TODO use the upstream authcode (code param) here + "TODO", // TODO use the pkce value from the decoded state param here + "TODO", // TODO use the nonce value from the decoded state param here ) - http.Redirect(w, r, downstreamCallbackURL, 302) + if err != nil { + panic(err) // TODO + } + + var username string + // TODO handle the case when upstreamIDPConfig.GetUsernameClaim() is the empty string by defaulting to something reasonable + usernameAsInterface := idTokenClaims[upstreamIDPConfig.GetUsernameClaim()] + username, ok := usernameAsInterface.(string) + if !ok { + panic(err) // TODO + } + + // TODO also look at the upstream ID token's groups claim and store that value as a downstream ID token claim + + now := time.Now() + authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Issuer: "https://fosite.my-application.com", // TODO use the right value here + Subject: username, + Audience: []string{"my-client"}, // TODO use the right value here + ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here + IssuedAt: now, // TODO test this + RequestedAt: now, // TODO test this + AuthTime: now, // TODO test this + }, + }) + if err != nil { + panic(err) // TODO + } + + oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder) return nil }) @@ -83,11 +126,11 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) return state, nil } -func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProviderI { +func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI { _, lastPathComponent := path.Split(r.URL.Path) for _, p := range idpListGetter.GetIDPList() { if p.GetName() == lastPathComponent { - return &p + return p } } return nil diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 8315b9100..632ef32b4 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -4,21 +4,26 @@ package callback import ( + "context" "fmt" - "html" "net/http" "net/http/httptest" "net/url" "regexp" + "strings" "testing" "github.com/gorilla/securecookie" "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/storage" "github.com/stretchr/testify/require" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/oidcclient" + "go.pinniped.dev/internal/oidcclient/nonce" + "go.pinniped.dev/internal/oidcclient/pkce" "go.pinniped.dev/internal/testutil" ) @@ -31,19 +36,6 @@ func TestCallbackEndpoint(t *testing.T) { downstreamRedirectURI = "http://127.0.0.1/callback" ) - // TODO use a fosite memory store and pass in a fostite oauthHelper - // TODO write a test double for UpstreamOIDCIdentityProviderI ID token with a claim called "the-user-claim" and put a username as the value of that claim - // TODO assert that after the callback request, the fosite storage has 1 authcode key saved, - // and it is the same key that was returned in the redirect, - // and the value in storage includes the username in the fosite session - // TODO do the same thing with the groups list (store it in the fosite session as JWT claim) - // TODO test for when UpstreamOIDCIdentityProviderI authcode exchange fails - // TODO wire in the callback endpoint into the oidc manager request router - // TODO update the upstream watcher controller to also populate the new fields - // TODO update the integration test - // TODO DO NOT store the upstream tokens (or maybe just the refresh token) for this story. In a future story, we can store them/it in some other storage interface indexed by the same authcode hash that fosite used for storage. - // TODO grab the upstream config name from the state param instead of the URL path - // Configure fosite the same way that the production code would, except use in-memory storage. oauthStore := &storage.MemoryStore{ Clients: map[string]fosite.Client{oidc.PinnipedCLIOIDCClient().ID: oidc.PinnipedCLIOIDCClient()}, @@ -59,7 +51,16 @@ func TestCallbackEndpoint(t *testing.T) { Name: happyUpstreamIDPName, ClientID: "some-client-id", UsernameClaim: "the-user-claim", + GroupsClaim: "the-groups-claim", Scopes: []string{"scope1", "scope2"}, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { + return oidcclient.Token{}, + map[string]interface{}{ + "the-user-claim": "test-pinniped-username", + "other-claim": "should be ignored", + }, + nil + }, } otherUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ @@ -144,17 +145,19 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus int wantBody string wantRedirectLocationRegexp string + wantAuthcodeStored bool }{ - // Happy path - // TODO: GET with good state and cookie and successful upstream token exchange and 302 to downstream client callback with its state and code { - name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusFound, - wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&state=` + happyDownstreamState, + name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it + wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState, + wantAuthcodeStored: true, + wantBody: "", }, // TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes) @@ -256,27 +259,6 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusForbidden, wantBody: "Forbidden: CSRF value does not match\n", }, - - // Upstream exchange - // TODO: network call to upstream token endpoint fails - // TODO: the upstream token endpoint returns an error - - // Post-upstream-exchange verification - // TODO: returned tokens are invalid (all the stuff from the spec...) - // TODO: there - // TODO: are - // TODO: probably - // TODO: a - // TODO: lot - // TODO: of - // TODO: test - // TODO: cases - // TODO: here (e.g., id jwt cannot be verified, nonce is wrong, we didn't get refresh token, we didn't get access token, we didn't get id token, access token expires too quickly) - - // Downstream redirect - // TODO: we grant the openid scope if it was requested, similar to what we did in auth_handler.go - // TODO: cannot generate auth code - // TODO: cannot persist downstream state } for _, test := range tests { test := test @@ -292,11 +274,14 @@ func TestCallbackEndpoint(t *testing.T) { require.Equal(t, test.wantStatus, rsp.Code) require.False(t, test.wantBody != "" && test.wantRedirectLocationRegexp != "", "test cannot set both body and redirect assertions") - switch { - case test.wantBody != "": - require.Empty(t, rsp.Header().Values("Location")) + + if test.wantBody != "" { require.Equal(t, test.wantBody, rsp.Body.String()) - case test.wantRedirectLocationRegexp != "": + } else { + require.Empty(t, rsp.Body.String()) + } + + if test.wantRedirectLocationRegexp != "" { // Assert that Location header matches regular expression. require.Len(t, rsp.Header().Values("Location"), 1) actualLocation := rsp.Header().Get("Location") @@ -304,16 +289,28 @@ func TestCallbackEndpoint(t *testing.T) { submatches := regex.FindStringSubmatch(actualLocation) require.Lenf(t, submatches, 2, "no regexp match in actualLocation: %q", actualLocation) capturedAuthCode := submatches[1] - _ = capturedAuthCode - // TODO Assert capturedAuthCode storage stuff... + // One authcode should have been stored. + require.Len(t, oauthStore.AuthorizeCodes, 1) - // Assert that body contains anchor tag with redirect location. - anchorTagWithLocationHref := fmt.Sprintf("Found.\n\n", html.EscapeString(actualLocation)) - require.Equal(t, anchorTagWithLocationHref, rsp.Body.String()) - default: + // fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface + authcodeDataAndSignature := strings.Split(capturedAuthCode, ".") + require.Len(t, authcodeDataAndSignature, 2) + + // Get the authcode session back from storage so we can require that it was stored correctly. + storedAuthorizeRequest, err := oauthStore.GetAuthorizeCodeSession(context.Background(), authcodeDataAndSignature[1], nil) + require.NoError(t, err) + + // Check that storage returned the expected concrete data types. + _, ok := storedAuthorizeRequest.(*fosite.Request) + require.True(t, ok) + storedSession, ok := storedAuthorizeRequest.GetSession().(*openid.DefaultSession) + require.True(t, ok) + + // Check various fields of the stored data. + require.Equal(t, "test-pinniped-username", storedSession.Claims.Subject) + } else { require.Empty(t, rsp.Header().Values("Location")) - require.Empty(t, rsp.Body.String()) } }) } From ffdb7fa79501f249fda26b2e95e2f19d07991795 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Thu, 19 Nov 2020 08:41:44 -0500 Subject: [PATCH 08/57] callback_handler.go: add a test for invalid state auth params Signed-off-by: Andrew Keesler --- internal/oidc/callback/callback_handler.go | 2 +- .../oidc/callback/callback_handler_test.go | 27 ++++++++++++++++--- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index cd66f646b..2491f20cb 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -36,7 +36,7 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi downstreamAuthParams, err := url.ParseQuery(state.AuthParams) if err != nil { - panic(err) // TODO + return httperr.New(http.StatusBadRequest, "error reading state's downstream auth params") } // Recreate enough of the original authorize request so we can pass it to NewAuthorizeRequest(). diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 632ef32b4..652b87854 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -96,6 +96,7 @@ func TestCallbackEndpoint(t *testing.T) { happyCSRF := "test-csrf" happyPKCE := "test-pkce" happyNonce := "test-nonce" + happyStateVersion := "1" happyState, err := happyStateCodec.Encode("s", testutil.ExpectedUpstreamStateParamFormat{ @@ -103,7 +104,7 @@ func TestCallbackEndpoint(t *testing.T) { N: happyNonce, C: happyCSRF, K: happyPKCE, - V: "1", + V: happyStateVersion, }, ) require.NoError(t, err) @@ -114,7 +115,7 @@ func TestCallbackEndpoint(t *testing.T) { N: happyNonce, C: "wrong-csrf-value", K: happyPKCE, - V: "1", + V: happyStateVersion, }, ) require.NoError(t, err) @@ -125,7 +126,18 @@ func TestCallbackEndpoint(t *testing.T) { N: happyNonce, C: happyCSRF, K: happyPKCE, - V: "wrong-version", + V: "wrong-state-version", + }, + ) + require.NoError(t, err) + + wrongDownstreamAuthParamsState, err := happyStateCodec.Encode("s", + testutil.ExpectedUpstreamStateParamFormat{ + P: "these-is-not-a-valid-url-query-%z", + N: happyNonce, + C: happyCSRF, + K: happyPKCE, + V: happyStateVersion, }, ) require.NoError(t, err) @@ -224,6 +236,15 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantBody: "Unprocessable Entity: state format version is invalid\n", }, + { + name: "state's downstream auth params element is invalid", + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: error reading state's downstream auth params\n", + }, { name: "the UpstreamOIDCProvider CRD has been deleted", idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), From 63b8c6e4b2779a4f7e078de2399776d6ff52b512 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Thu, 19 Nov 2020 08:51:23 -0500 Subject: [PATCH 09/57] callback_handler.go: test when state missing a needed param Signed-off-by: Andrew Keesler --- internal/oidc/callback/callback_handler.go | 4 +- .../oidc/callback/callback_handler_test.go | 46 +++++++++++++++++-- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 2491f20cb..a1cc49b4e 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -36,14 +36,14 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi downstreamAuthParams, err := url.ParseQuery(state.AuthParams) if err != nil { - return httperr.New(http.StatusBadRequest, "error reading state's downstream auth params") + return httperr.New(http.StatusBadRequest, "error reading state downstream auth params") } // Recreate enough of the original authorize request so we can pass it to NewAuthorizeRequest(). reconstitutedAuthRequest := &http.Request{Form: downstreamAuthParams} authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), reconstitutedAuthRequest) if err != nil { - panic(err) // TODO + return httperr.New(http.StatusBadRequest, "error using state downstream auth params") } // TODO: grant the openid scope only if it was requested, similar to what we did in auth_handler.go diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 652b87854..eb6de5751 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -83,7 +83,7 @@ func TestCallbackEndpoint(t *testing.T) { happyDownstreamState := "some-downstream-state" - happyOriginalRequestParams := url.Values{ + happyOriginalRequestParamsQuery := url.Values{ "response_type": []string{"code"}, "scope": []string{"openid profile email"}, "client_id": []string{"pinniped-cli"}, @@ -92,7 +92,8 @@ func TestCallbackEndpoint(t *testing.T) { "code_challenge": []string{"some-challenge"}, "code_challenge_method": []string{"S256"}, "redirect_uri": []string{downstreamRedirectURI}, - }.Encode() + } + happyOriginalRequestParams := happyOriginalRequestParamsQuery.Encode() happyCSRF := "test-csrf" happyPKCE := "test-pkce" happyNonce := "test-nonce" @@ -142,6 +143,17 @@ func TestCallbackEndpoint(t *testing.T) { ) require.NoError(t, err) + missingClientIDState, err := happyStateCodec.Encode("s", + testutil.ExpectedUpstreamStateParamFormat{ + P: shallowCopyQueryExceptFor(happyOriginalRequestParamsQuery, "client_id").Encode(), + N: happyNonce, + C: happyCSRF, + K: happyPKCE, + V: happyStateVersion, + }, + ) + require.NoError(t, err) + encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyCSRF) require.NoError(t, err) happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue @@ -243,7 +255,16 @@ func TestCallbackEndpoint(t *testing.T) { path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: error reading state's downstream auth params\n", + wantBody: "Bad Request: error reading state downstream auth params\n", + }, + { + name: "state's downstream auth params are missing required value (e.g., client_id)", + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(missingClientIDState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: error using state downstream auth params\n", }, { name: "the UpstreamOIDCProvider CRD has been deleted", @@ -388,3 +409,22 @@ func (r *requestPath) String() string { } return path + params.Encode() } + +func shallowCopyQueryExceptFor(query url.Values, keys ...string) url.Values { + copied := url.Values{} + for key, value := range query { + if !contains(keys, key) { + copied[key] = value + } + } + return copied +} + +func contains(haystack []string, needle string) bool { + for _, hay := range haystack { + if hay == needle { + return true + } + } + return false +} From 6c72507bcae43d0078a24278a2f9f05c9949e46e Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Thu, 19 Nov 2020 09:00:41 -0500 Subject: [PATCH 10/57] callback_handler.go: add test for failed upstream exchange/validation Signed-off-by: Andrew Keesler --- internal/oidc/callback/callback_handler.go | 2 +- .../oidc/callback/callback_handler_test.go | 29 +++++++++++++++++-- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index a1cc49b4e..7d1f555c6 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -56,7 +56,7 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi "TODO", // TODO use the nonce value from the decoded state param here ) if err != nil { - panic(err) // TODO + return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") } var username string diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index eb6de5751..afa88c4b8 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -5,6 +5,7 @@ package callback import ( "context" + "errors" "fmt" "net/http" "net/http/httptest" @@ -69,6 +70,17 @@ func TestCallbackEndpoint(t *testing.T) { Scopes: []string{"other-scope1", "other-scope2"}, } + failedExchangeUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ + Name: happyUpstreamIDPName, + ClientID: upstreamOIDCIdentityProvider.ClientID, + UsernameClaim: upstreamOIDCIdentityProvider.UsernameClaim, + GroupsClaim: upstreamOIDCIdentityProvider.GroupsClaim, + Scopes: upstreamOIDCIdentityProvider.Scopes, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { + return oidcclient.Token{}, nil, errors.New("some exchange error") + }, + } + var stateEncoderHashKey = []byte("fake-hash-secret") var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES var cookieEncoderHashKey = []byte("fake-hash-secret2") @@ -277,7 +289,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "the CSRF cookie does not exist on request", - idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), wantStatus: http.StatusForbidden, @@ -285,7 +297,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", @@ -294,13 +306,24 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "cookie csrf value does not match state csrf value", - idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodGet, path: newRequestPath().WithState(wrongCSRFValueState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusForbidden, wantBody: "Forbidden: CSRF value does not match\n", }, + + // Upstream exchange + { + name: "upstream auth code exchange fails", + idpListGetter: testutil.NewIDPListGetter(failedExchangeUpstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadGateway, + wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", + }, } for _, test := range tests { test := test From 48e02506495d2bda5e68a006d0f6724327c69245 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Thu, 19 Nov 2020 09:28:56 -0500 Subject: [PATCH 11/57] callback_handler.go: test that we request openid scope correctly Also add some testing.T.Log() calls to make debugging handler test failures easier. Signed-off-by: Andrew Keesler --- internal/oidc/auth/auth_handler_test.go | 2 + internal/oidc/callback/callback_handler.go | 12 ++- .../oidc/callback/callback_handler_test.go | 76 +++++++++++++------ 3 files changed, 63 insertions(+), 27 deletions(-) diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 8d7448683..1ac7cb873 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -711,6 +711,8 @@ func TestAuthorizationEndpoint(t *testing.T) { } rsp := httptest.NewRecorder() subject.ServeHTTP(rsp, req) + t.Logf("response: %#v", rsp) + t.Logf("response body: %q", rsp.Body.String()) require.Equal(t, test.wantStatus, rsp.Code) requireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 7d1f555c6..4208a7468 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -46,8 +46,8 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi return httperr.New(http.StatusBadRequest, "error using state downstream auth params") } - // TODO: grant the openid scope only if it was requested, similar to what we did in auth_handler.go - authorizeRequester.GrantScope("openid") + // Grant the openid scope only if it was requested. + grantOpenIDScopeIfRequested(authorizeRequester) _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( r.Context(), @@ -168,3 +168,11 @@ func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateP return &state, nil } + +func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { + for _, scope := range authorizeRequester.GetRequestedScopes() { + if scope == "openid" { + authorizeRequester.GrantScope(scope) + } + } +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index afa88c4b8..4b47e96bd 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -37,17 +37,6 @@ func TestCallbackEndpoint(t *testing.T) { downstreamRedirectURI = "http://127.0.0.1/callback" ) - // Configure fosite the same way that the production code would, except use in-memory storage. - oauthStore := &storage.MemoryStore{ - Clients: map[string]fosite.Client{oidc.PinnipedCLIOIDCClient().ID: oidc.PinnipedCLIOIDCClient()}, - AuthorizeCodes: map[string]storage.StoreAuthorizeCode{}, - PKCES: map[string]fosite.Requester{}, - IDSessions: map[string]fosite.Requester{}, - } - hmacSecret := []byte("some secret - must have at least 32 bytes") - require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") - oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) - upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ Name: happyUpstreamIDPName, ClientID: "some-client-id", @@ -157,7 +146,18 @@ func TestCallbackEndpoint(t *testing.T) { missingClientIDState, err := happyStateCodec.Encode("s", testutil.ExpectedUpstreamStateParamFormat{ - P: shallowCopyQueryExceptFor(happyOriginalRequestParamsQuery, "client_id").Encode(), + P: shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"client_id": ""}).Encode(), + N: happyNonce, + C: happyCSRF, + K: happyPKCE, + V: happyStateVersion, + }, + ) + require.NoError(t, err) + + noOpenidScopeState, err := happyStateCodec.Encode("s", + testutil.ExpectedUpstreamStateParamFormat{ + P: shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"scope": "profile email"}).Encode(), N: happyNonce, C: happyCSRF, K: happyPKCE, @@ -181,7 +181,9 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus int wantBody string wantRedirectLocationRegexp string - wantAuthcodeStored bool + // TODO: I am unused... + wantAuthcodeStored bool + wantGrantedOpenidScope bool }{ { name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", @@ -193,6 +195,7 @@ func TestCallbackEndpoint(t *testing.T) { // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState, wantAuthcodeStored: true, + wantGrantedOpenidScope: true, wantBody: "", }, // TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes) @@ -278,6 +281,15 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusBadRequest, wantBody: "Bad Request: error using state downstream auth params\n", }, + { + name: "state's downstream auth params does not contain openid scope", + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(noOpenidScopeState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, + }, { name: "the UpstreamOIDCProvider CRD has been deleted", idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), @@ -328,6 +340,18 @@ func TestCallbackEndpoint(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { + // Configure fosite the same way that the production code would, except use in-memory storage. + // Inject this into our test subject at the last second so we get a fresh storage for every test. + oauthStore := &storage.MemoryStore{ + Clients: map[string]fosite.Client{oidc.PinnipedCLIOIDCClient().ID: oidc.PinnipedCLIOIDCClient()}, + AuthorizeCodes: map[string]storage.StoreAuthorizeCode{}, + PKCES: map[string]fosite.Requester{}, + IDSessions: map[string]fosite.Requester{}, + } + hmacSecret := []byte("some secret - must have at least 32 bytes") + require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") + oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) + subject := NewHandler(test.idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) req := httptest.NewRequest(test.method, test.path, nil) if test.csrfCookie != "" { @@ -335,6 +359,8 @@ func TestCallbackEndpoint(t *testing.T) { } rsp := httptest.NewRecorder() subject.ServeHTTP(rsp, req) + t.Logf("response: %#v", rsp) + t.Logf("response body: %q", rsp.Body.String()) require.Equal(t, test.wantStatus, rsp.Code) @@ -367,12 +393,17 @@ func TestCallbackEndpoint(t *testing.T) { require.NoError(t, err) // Check that storage returned the expected concrete data types. - _, ok := storedAuthorizeRequest.(*fosite.Request) + storedRequest, ok := storedAuthorizeRequest.(*fosite.Request) require.True(t, ok) storedSession, ok := storedAuthorizeRequest.GetSession().(*openid.DefaultSession) require.True(t, ok) // Check various fields of the stored data. + if test.wantGrantedOpenidScope { + require.Contains(t, storedRequest.GetGrantedScopes(), "openid") + } else { + require.NotContains(t, storedRequest.GetGrantedScopes(), "openid") + } require.Equal(t, "test-pinniped-username", storedSession.Claims.Subject) } else { require.Empty(t, rsp.Header().Values("Location")) @@ -433,21 +464,16 @@ func (r *requestPath) String() string { return path + params.Encode() } -func shallowCopyQueryExceptFor(query url.Values, keys ...string) url.Values { +func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string) url.Values { copied := url.Values{} for key, value := range query { - if !contains(keys, key) { + if modification, ok := modifications[key]; ok { + if modification != "" { + copied[key] = []string{modification} + } + } else { copied[key] = value } } return copied } - -func contains(haystack []string, needle string) bool { - for _, hay := range haystack { - if hay == needle { - return true - } - } - return false -} From 2e62be3ebbbcca4399df22c8256009792c50225a Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Thu, 19 Nov 2020 10:20:46 -0500 Subject: [PATCH 12/57] callback_handler.go: assert correct args are passed to token exchange Signed-off-by: Andrew Keesler --- internal/oidc/auth/auth_handler_test.go | 54 ++--- internal/oidc/callback/callback_handler.go | 7 +- .../oidc/callback/callback_handler_test.go | 191 ++++++++++-------- .../oidc/provider/manager/manager_test.go | 2 +- internal/testutil/oidc.go | 37 +++- 5 files changed, 173 insertions(+), 118 deletions(-) diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 1ac7cb873..df90bde51 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -271,7 +271,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using GET without a CSRF cookie", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -289,7 +289,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using GET with a CSRF cookie", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -307,7 +307,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using POST", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -327,7 +327,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path when downstream redirect uri matches what is configured for client except for the port number", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -349,7 +349,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream redirect uri does not match what is configured for client", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -366,7 +366,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream client does not exist", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -381,7 +381,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "response type is unsupported", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -397,7 +397,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream scopes do not match what is configured for client", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -413,7 +413,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing response type in request", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -429,7 +429,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing client id in request", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -444,7 +444,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -460,7 +460,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -476,7 +476,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -492,7 +492,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -510,7 +510,7 @@ func TestAuthorizationEndpoint(t *testing.T) { // through that part of the fosite library. name: "prompt param is not allowed to have none and another legal value at the same time", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -526,7 +526,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "OIDC validations are skipped when the openid scope was not requested", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -547,7 +547,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "state does not have enough entropy", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -563,7 +563,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while encoding upstream state param", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -578,7 +578,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while encoding CSRF cookie value for new cookie", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -593,7 +593,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating CSRF token", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -608,7 +608,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating nonce", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, @@ -623,7 +623,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating PKCE", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generateNonce: happyNonceGenerator, @@ -638,7 +638,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while decoding CSRF cookie", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -664,7 +664,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "too many upstream providers are configured", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -674,7 +674,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "PUT is a bad method", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodPut, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -684,7 +684,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "PATCH is a bad method", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodPatch, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -694,7 +694,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "DELETE is a bad method", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodDelete, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 4208a7468..985ade317 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -47,13 +47,14 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi } // Grant the openid scope only if it was requested. + // TODO: shouldn't we be potentially granting more scopes than just openid... grantOpenIDScopeIfRequested(authorizeRequester) _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( r.Context(), - "TODO", // TODO use the upstream authcode (code param) here - "TODO", // TODO use the pkce value from the decoded state param here - "TODO", // TODO use the nonce value from the decoded state param here + r.URL.Query().Get("code"), // TODO: do we need to validate this? + state.PKCECode, + state.Nonce, ) if err != nil { return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 4b47e96bd..b6585a3c9 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -21,7 +21,6 @@ import ( "github.com/stretchr/testify/require" "go.pinniped.dev/internal/oidc" - "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidcclient" "go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/pkce" @@ -35,6 +34,8 @@ const ( func TestCallbackEndpoint(t *testing.T) { const ( downstreamRedirectURI = "http://127.0.0.1/callback" + + happyUpstreamAuthcode = "upstream-auth-code" ) upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ @@ -170,13 +171,19 @@ func TestCallbackEndpoint(t *testing.T) { require.NoError(t, err) happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue + happyExchangeAndValidateTokensArgs := &testutil.ExchangeAuthcodeAndValidateTokenArgs{ + Authcode: happyUpstreamAuthcode, + PKCECodeVerifier: pkce.Code(happyPKCE), + ExpectedIDTokenNonce: nonce.Nonce(happyNonce), + } + tests := []struct { name string - idpListGetter provider.DynamicUpstreamIDPProvider - method string - path string - csrfCookie string + idp testutil.TestUpstreamOIDCIdentityProvider + method string + path string + csrfCookie string wantStatus int wantBody string @@ -184,19 +191,22 @@ func TestCallbackEndpoint(t *testing.T) { // TODO: I am unused... wantAuthcodeStored bool wantGrantedOpenidScope bool + + wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs }{ { - name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusFound, + name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it - wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState, - wantAuthcodeStored: true, - wantGrantedOpenidScope: true, - wantBody: "", + wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState, + wantAuthcodeStored: true, + wantGrantedOpenidScope: true, + wantBody: "", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, // TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes) @@ -246,95 +256,97 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "Bad Request: state param not found\n", }, { - name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState("this-will-not-decode").String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: error reading state\n", + name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState("this-will-not-decode").String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: error reading state\n", }, { - name: "state's internal version does not match what we want", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(wrongVersionState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantBody: "Unprocessable Entity: state format version is invalid\n", + name: "state's internal version does not match what we want", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(wrongVersionState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: state format version is invalid\n", }, { - name: "state's downstream auth params element is invalid", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: error reading state downstream auth params\n", + name: "state's downstream auth params element is invalid", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: error reading state downstream auth params\n", }, { - name: "state's downstream auth params are missing required value (e.g., client_id)", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(missingClientIDState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: error using state downstream auth params\n", + name: "state's downstream auth params are missing required value (e.g., client_id)", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(missingClientIDState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: error using state downstream auth params\n", }, { - name: "state's downstream auth params does not contain openid scope", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(noOpenidScopeState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusFound, - wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, + name: "state's downstream auth params does not contain openid scope", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(noOpenidScopeState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { - name: "the UpstreamOIDCProvider CRD has been deleted", - idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantBody: "Unprocessable Entity: upstream provider not found\n", + name: "the UpstreamOIDCProvider CRD has been deleted", + idp: otherUpstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: upstream provider not found\n", }, { - name: "the CSRF cookie does not exist on request", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - wantStatus: http.StatusForbidden, - wantBody: "Forbidden: CSRF cookie is missing\n", + name: "the CSRF cookie does not exist on request", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: CSRF cookie is missing\n", }, { - name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", - wantStatus: http.StatusForbidden, - wantBody: "Forbidden: error reading CSRF cookie\n", + name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: error reading CSRF cookie\n", }, { - name: "cookie csrf value does not match state csrf value", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(wrongCSRFValueState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusForbidden, - wantBody: "Forbidden: CSRF value does not match\n", + name: "cookie csrf value does not match state csrf value", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(wrongCSRFValueState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: CSRF value does not match\n", }, // Upstream exchange { - name: "upstream auth code exchange fails", - idpListGetter: testutil.NewIDPListGetter(failedExchangeUpstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadGateway, - wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", + name: "upstream auth code exchange fails", + idp: failedExchangeUpstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadGateway, + wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, } for _, test := range tests { @@ -352,7 +364,8 @@ func TestCallbackEndpoint(t *testing.T) { require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) - subject := NewHandler(test.idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) + idpListGetter := testutil.NewIDPListGetter(&test.idp) + subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) req := httptest.NewRequest(test.method, test.path, nil) if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) @@ -408,6 +421,14 @@ func TestCallbackEndpoint(t *testing.T) { } else { require.Empty(t, rsp.Header().Values("Location")) } + + if test.wantExchangeAndValidateTokensCall != nil { + require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) + test.wantExchangeAndValidateTokensCall.Ctx = req.Context() + require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0)) + } else { + require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) + } }) } } diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index e9e175d8a..fce748bb7 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -109,7 +109,7 @@ func TestManager(t *testing.T) { parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL) r.NoError(err) - idpListGetter := testutil.NewIDPListGetter(testutil.TestUpstreamOIDCIdentityProvider{ + idpListGetter := testutil.NewIDPListGetter(&testutil.TestUpstreamOIDCIdentityProvider{ Name: "test-idp", ClientID: "test-client-id", AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, diff --git a/internal/testutil/oidc.go b/internal/testutil/oidc.go index 14bdb92ba..7cbfcf812 100644 --- a/internal/testutil/oidc.go +++ b/internal/testutil/oidc.go @@ -15,6 +15,15 @@ import ( // Test helpers for the OIDC package. +// ExchangeAuthcodeAndValidateTokenArgs is a POGO (plain old go object?) used to spy on calls to +// TestUpstreamOIDCIdentityProvider.ExchangeAuthcodeAndValidateTokensFunc(). +type ExchangeAuthcodeAndValidateTokenArgs struct { + Ctx context.Context + Authcode string + PKCECodeVerifier pkce.Code + ExpectedIDTokenNonce nonce.Nonce +} + type TestUpstreamOIDCIdentityProvider struct { Name string ClientID string @@ -28,6 +37,9 @@ type TestUpstreamOIDCIdentityProvider struct { pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, ) (oidcclient.Token, map[string]interface{}, error) + + exchangeAuthcodeAndValidateTokensCallCount int + exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs } func (u *TestUpstreamOIDCIdentityProvider) GetName() string { @@ -60,14 +72,35 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, ) (oidcclient.Token, map[string]interface{}, error) { + if u.exchangeAuthcodeAndValidateTokensArgs == nil { + u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) + } + u.exchangeAuthcodeAndValidateTokensCallCount++ + u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeAndValidateTokenArgs{ + Ctx: ctx, + Authcode: authcode, + PKCECodeVerifier: pkceCodeVerifier, + ExpectedIDTokenNonce: expectedIDTokenNonce, + }) return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce) } -func NewIDPListGetter(upstreamOIDCIdentityProviders ...TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { +func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensCallCount() int { + return u.exchangeAuthcodeAndValidateTokensCallCount +} + +func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeAndValidateTokenArgs { + if u.exchangeAuthcodeAndValidateTokensArgs == nil { + u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) + } + return u.exchangeAuthcodeAndValidateTokensArgs[call] +} + +func NewIDPListGetter(upstreamOIDCIdentityProviders ...*TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { idpProvider := provider.NewDynamicUpstreamIDPProvider() upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders)) for i := range upstreamOIDCIdentityProviders { - upstreams[i] = provider.UpstreamOIDCIdentityProviderI(&upstreamOIDCIdentityProviders[i]) + upstreams[i] = provider.UpstreamOIDCIdentityProviderI(upstreamOIDCIdentityProviders[i]) } idpProvider.SetIDPList(upstreams) return idpProvider From ace861f722632a6ff599aecd6983dba2cc3362d3 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Thu, 19 Nov 2020 11:08:21 -0500 Subject: [PATCH 13/57] callback_handler.go: get some thoughts down about default upstream claims Signed-off-by: Andrew Keesler --- internal/oidc/callback/callback_handler.go | 44 ++++++++++++++-- .../oidc/callback/callback_handler_test.go | 52 +++++++++++++++---- 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 985ade317..93073cd6b 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -21,6 +21,21 @@ import ( "go.pinniped.dev/internal/plog" ) +const ( + // defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC + // ID token if the upstream OIDC IDP did not tell us to use another claim. + defaultUpstreamUsernameClaim = "sub" + + // defaultUpstreamGroupsClaim is what we will use to extract the groups from an upstream OIDC ID + // token if the upstream OIDC IDP did not tell us to use another claim. + defaultUpstreamGroupsClaim = "groups" + + // downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token + // information. + // TODO: should this be per-issuer? Or per version? + downstreamGroupsClaim = "oidc.pinniped.dev/groups" +) + func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, stateDecoder, cookieDecoder oidc.Decoder) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { state, err := validateRequest(r, stateDecoder, cookieDecoder) @@ -61,14 +76,32 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi } var username string - // TODO handle the case when upstreamIDPConfig.GetUsernameClaim() is the empty string by defaulting to something reasonable - usernameAsInterface := idTokenClaims[upstreamIDPConfig.GetUsernameClaim()] - username, ok := usernameAsInterface.(string) + usernameClaim := upstreamIDPConfig.GetUsernameClaim() + if usernameClaim == "" { + usernameClaim = defaultUpstreamUsernameClaim + } + usernameAsInterface, ok := idTokenClaims[usernameClaim] + if !ok { + panic(err) // TODO + } + username, ok = usernameAsInterface.(string) if !ok { panic(err) // TODO } - // TODO also look at the upstream ID token's groups claim and store that value as a downstream ID token claim + var groups []string + groupsClaim := upstreamIDPConfig.GetGroupsClaim() + if groupsClaim == "" { + groupsClaim = defaultUpstreamGroupsClaim + } + groupsAsInterface, ok := idTokenClaims[groupsClaim] + if !ok { + panic(err) // TODO + } + groups, ok = groupsAsInterface.([]string) + if !ok { + panic(err) // TODO + } now := time.Now() authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{ @@ -80,6 +113,9 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi IssuedAt: now, // TODO test this RequestedAt: now, // TODO test this AuthTime: now, // TODO test this + Extra: map[string]interface{}{ + downstreamGroupsClaim: groups, + }, }, }) if err != nil { diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index b6585a3c9..a375f0959 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -47,8 +47,24 @@ func TestCallbackEndpoint(t *testing.T) { ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { return oidcclient.Token{}, map[string]interface{}{ - "the-user-claim": "test-pinniped-username", - "other-claim": "should be ignored", + "the-user-claim": "test-pinniped-username", + "the-groups-claim": []string{"test-pinniped-group-0", "test-pinniped-group-1"}, + "other-claim": "should be ignored", + }, + nil + }, + } + + defaultClaimsUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ + Name: happyUpstreamIDPName, + ClientID: "some-client-id", + Scopes: []string{"scope1", "scope2"}, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { + return oidcclient.Token{}, + map[string]interface{}{ + "sub": "test-pinniped-username", + "groups": []string{"test-pinniped-group-0", "test-pinniped-group-1"}, + "other-claim": "should be ignored", }, nil }, @@ -177,6 +193,9 @@ func TestCallbackEndpoint(t *testing.T) { ExpectedIDTokenNonce: nonce.Nonce(happyNonce), } + // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it + happyRedirectLocationRegexp := downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState + tests := []struct { name string @@ -195,14 +214,26 @@ func TestCallbackEndpoint(t *testing.T) { wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs }{ { - name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", - idp: upstreamOIDCIdentityProvider, - method: http.MethodGet, - path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusFound, - // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it - wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState, + name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: happyRedirectLocationRegexp, + wantAuthcodeStored: true, + wantGrantedOpenidScope: true, + wantBody: "", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream IDP uses default claims", + idp: defaultClaimsUpstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: happyRedirectLocationRegexp, wantAuthcodeStored: true, wantGrantedOpenidScope: true, wantBody: "", @@ -418,6 +449,7 @@ func TestCallbackEndpoint(t *testing.T) { require.NotContains(t, storedRequest.GetGrantedScopes(), "openid") } require.Equal(t, "test-pinniped-username", storedSession.Claims.Subject) + require.Equal(t, []string{"test-pinniped-group-0", "test-pinniped-group-1"}, storedSession.Claims.Extra["oidc.pinniped.dev/groups"]) } else { require.Empty(t, rsp.Header().Values("Location")) } From ee84f31f424aa850436c9dccb3769ec1cdb5d25f Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Thu, 19 Nov 2020 08:35:23 -0800 Subject: [PATCH 14/57] callback_handler.go: Add JWT Issuer claim to storage --- internal/oidc/auth/auth_handler.go | 4 +- internal/oidc/auth/auth_handler_test.go | 63 +++++++++---------- internal/oidc/callback/callback_handler.go | 9 ++- .../oidc/callback/callback_handler_test.go | 29 ++++----- 4 files changed, 52 insertions(+), 53 deletions(-) diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index a6fa63b4f..08c42ada8 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -24,7 +24,7 @@ import ( ) func NewHandler( - issuer string, + downstreamIssuer string, idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, generateCSRF func() (csrftoken.CSRFToken, error), @@ -92,7 +92,7 @@ func NewHandler( Endpoint: oauth2.Endpoint{ AuthURL: upstreamIDP.GetAuthorizationURL().String(), }, - RedirectURL: fmt.Sprintf("%s/callback/%s", issuer, upstreamIDP.GetName()), + RedirectURL: fmt.Sprintf("%s/callback/%s", downstreamIssuer, upstreamIDP.GetName()), Scopes: upstreamIDP.GetScopes(), } diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index df90bde51..9bfe06e54 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -28,6 +28,7 @@ import ( func TestAuthorizationEndpoint(t *testing.T) { const ( + downstreamIssuer = "https://my-downstream-issuer.com/some-path" downstreamRedirectURI = "http://127.0.0.1/callback" downstreamRedirectURIWithDifferentPort = "http://127.0.0.1:42/callback" ) @@ -120,8 +121,6 @@ func TestAuthorizationEndpoint(t *testing.T) { Scopes: []string{"scope1", "scope2"}, } - issuer := "https://my-issuer.com/some-path" - // Configure fosite the same way that the production code would, using NullStorage to turn off storage. oauthStore := oidc.NullStorage{} hmacSecret := []byte("some secret - must have at least 32 bytes") @@ -233,7 +232,7 @@ func TestAuthorizationEndpoint(t *testing.T) { "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", - "redirect_uri": issuer + "/callback/some-idp", + "redirect_uri": downstreamIssuer + "/callback/some-idp", }) } @@ -270,7 +269,7 @@ func TestAuthorizationEndpoint(t *testing.T) { tests := []testCase{ { name: "happy path using GET without a CSRF cookie", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -288,7 +287,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "happy path using GET with a CSRF cookie", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -306,7 +305,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "happy path using POST", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -326,7 +325,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "happy path when downstream redirect uri matches what is configured for client except for the port number", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -348,7 +347,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream redirect uri does not match what is configured for client", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -365,7 +364,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream client does not exist", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -380,7 +379,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "response type is unsupported", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -396,7 +395,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream scopes do not match what is configured for client", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -412,7 +411,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing response type in request", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -428,7 +427,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing client id in request", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -443,7 +442,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -459,7 +458,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -475,7 +474,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -491,7 +490,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -509,7 +508,7 @@ func TestAuthorizationEndpoint(t *testing.T) { // This is just one of the many OIDC validations run by fosite. This test is to ensure that we are running // through that part of the fosite library. name: "prompt param is not allowed to have none and another legal value at the same time", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -525,7 +524,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "OIDC validations are skipped when the openid scope was not requested", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -546,7 +545,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "state does not have enough entropy", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -562,7 +561,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while encoding upstream state param", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -577,7 +576,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while encoding CSRF cookie value for new cookie", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -592,7 +591,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while generating CSRF token", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generatePKCE: happyPKCEGenerator, @@ -607,7 +606,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while generating nonce", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -622,7 +621,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while generating PKCE", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, @@ -637,7 +636,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while decoding CSRF cookie", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, @@ -653,7 +652,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "no upstream providers are configured", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(), // empty method: http.MethodGet, path: happyGetRequestPath, @@ -663,7 +662,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "too many upstream providers are configured", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed method: http.MethodGet, path: happyGetRequestPath, @@ -673,7 +672,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "PUT is a bad method", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodPut, path: "/some/path", @@ -683,7 +682,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "PATCH is a bad method", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodPatch, path: "/some/path", @@ -693,7 +692,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "DELETE is a bad method", - issuer: issuer, + issuer: downstreamIssuer, idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodDelete, path: "/some/path", @@ -792,7 +791,7 @@ func TestAuthorizationEndpoint(t *testing.T) { "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", - "redirect_uri": issuer + "/callback/some-other-idp", + "redirect_uri": downstreamIssuer + "/callback/some-other-idp", }, ) test.wantBodyString = fmt.Sprintf(`Found.%s`, diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 93073cd6b..009fa737a 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -36,7 +36,12 @@ const ( downstreamGroupsClaim = "oidc.pinniped.dev/groups" ) -func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, stateDecoder, cookieDecoder oidc.Decoder) http.Handler { +func NewHandler( + downstreamIssuer string, + idpListGetter oidc.IDPListGetter, + oauthHelper fosite.OAuth2Provider, + stateDecoder, cookieDecoder oidc.Decoder, +) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { state, err := validateRequest(r, stateDecoder, cookieDecoder) if err != nil { @@ -106,7 +111,7 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi now := time.Now() authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ - Issuer: "https://fosite.my-application.com", // TODO use the right value here + Issuer: downstreamIssuer, Subject: username, Audience: []string{"my-client"}, // TODO use the right value here ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index a375f0959..07fc7b18b 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -33,8 +33,8 @@ const ( func TestCallbackEndpoint(t *testing.T) { const ( + downstreamIssuer = "https://my-downstream-issuer.com/path" downstreamRedirectURI = "http://127.0.0.1/callback" - happyUpstreamAuthcode = "upstream-auth-code" ) @@ -207,9 +207,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus int wantBody string wantRedirectLocationRegexp string - // TODO: I am unused... - wantAuthcodeStored bool - wantGrantedOpenidScope bool + wantGrantedOpenidScope bool wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs }{ @@ -221,7 +219,6 @@ func TestCallbackEndpoint(t *testing.T) { csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyRedirectLocationRegexp, - wantAuthcodeStored: true, wantGrantedOpenidScope: true, wantBody: "", wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, @@ -234,7 +231,6 @@ func TestCallbackEndpoint(t *testing.T) { csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyRedirectLocationRegexp, - wantAuthcodeStored: true, wantGrantedOpenidScope: true, wantBody: "", wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, @@ -396,7 +392,7 @@ func TestCallbackEndpoint(t *testing.T) { oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) idpListGetter := testutil.NewIDPListGetter(&test.idp) - subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) + subject := NewHandler(downstreamIssuer, idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) req := httptest.NewRequest(test.method, test.path, nil) if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) @@ -406,9 +402,15 @@ func TestCallbackEndpoint(t *testing.T) { t.Logf("response: %#v", rsp) t.Logf("response body: %q", rsp.Body.String()) - require.Equal(t, test.wantStatus, rsp.Code) + if test.wantExchangeAndValidateTokensCall != nil { + require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) + test.wantExchangeAndValidateTokensCall.Ctx = req.Context() + require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0)) + } else { + require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) + } - require.False(t, test.wantBody != "" && test.wantRedirectLocationRegexp != "", "test cannot set both body and redirect assertions") + require.Equal(t, test.wantStatus, rsp.Code) if test.wantBody != "" { require.Equal(t, test.wantBody, rsp.Body.String()) @@ -448,19 +450,12 @@ func TestCallbackEndpoint(t *testing.T) { } else { require.NotContains(t, storedRequest.GetGrantedScopes(), "openid") } + require.Equal(t, downstreamIssuer, storedSession.Claims.Issuer) require.Equal(t, "test-pinniped-username", storedSession.Claims.Subject) require.Equal(t, []string{"test-pinniped-group-0", "test-pinniped-group-1"}, storedSession.Claims.Extra["oidc.pinniped.dev/groups"]) } else { require.Empty(t, rsp.Header().Values("Location")) } - - if test.wantExchangeAndValidateTokensCall != nil { - require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) - test.wantExchangeAndValidateTokensCall.Ctx = req.Context() - require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0)) - } else { - require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) - } }) } } From a47617cad07d43462afed231b3770b40971022c5 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Thu, 19 Nov 2020 08:53:53 -0800 Subject: [PATCH 15/57] callback_handler.go: Add JWT Audience claim to storage --- internal/oidc/callback/callback_handler.go | 10 ++++++--- .../oidc/callback/callback_handler_test.go | 21 ++++++++++++------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 009fa737a..a2521488a 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -72,7 +72,7 @@ func NewHandler( _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( r.Context(), - r.URL.Query().Get("code"), // TODO: do we need to validate this? + authcode(r), state.PKCECode, state.Nonce, ) @@ -113,7 +113,7 @@ func NewHandler( Claims: &jwt.IDTokenClaims{ Issuer: downstreamIssuer, Subject: username, - Audience: []string{"my-client"}, // TODO use the right value here + Audience: []string{downstreamAuthParams.Get("client_id")}, ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here IssuedAt: now, // TODO test this RequestedAt: now, // TODO test this @@ -133,6 +133,10 @@ func NewHandler( }) } +func authcode(r *http.Request) string { + return r.FormValue("code") +} + func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) { if r.Method != http.MethodGet { return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) @@ -144,7 +148,7 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) return nil, err } - if r.FormValue("code") == "" { + if authcode(r) == "" { plog.Info("code param not found") return nil, httperr.New(http.StatusBadRequest, "code param not found") } diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 07fc7b18b..a37133718 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -36,6 +36,12 @@ func TestCallbackEndpoint(t *testing.T) { downstreamIssuer = "https://my-downstream-issuer.com/path" downstreamRedirectURI = "http://127.0.0.1/callback" happyUpstreamAuthcode = "upstream-auth-code" + upstreamUsername = "test-pinniped-username" + downstreamClientID = "pinniped-cli" + ) + + var ( + upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} ) upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ @@ -47,8 +53,8 @@ func TestCallbackEndpoint(t *testing.T) { ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { return oidcclient.Token{}, map[string]interface{}{ - "the-user-claim": "test-pinniped-username", - "the-groups-claim": []string{"test-pinniped-group-0", "test-pinniped-group-1"}, + "the-user-claim": upstreamUsername, + "the-groups-claim": upstreamGroupMembership, "other-claim": "should be ignored", }, nil @@ -62,8 +68,8 @@ func TestCallbackEndpoint(t *testing.T) { ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { return oidcclient.Token{}, map[string]interface{}{ - "sub": "test-pinniped-username", - "groups": []string{"test-pinniped-group-0", "test-pinniped-group-1"}, + "sub": upstreamUsername, + "groups": upstreamGroupMembership, "other-claim": "should be ignored", }, nil @@ -104,7 +110,7 @@ func TestCallbackEndpoint(t *testing.T) { happyOriginalRequestParamsQuery := url.Values{ "response_type": []string{"code"}, "scope": []string{"openid profile email"}, - "client_id": []string{"pinniped-cli"}, + "client_id": []string{downstreamClientID}, "state": []string{happyDownstreamState}, "nonce": []string{"some-nonce-value"}, "code_challenge": []string{"some-challenge"}, @@ -451,8 +457,9 @@ func TestCallbackEndpoint(t *testing.T) { require.NotContains(t, storedRequest.GetGrantedScopes(), "openid") } require.Equal(t, downstreamIssuer, storedSession.Claims.Issuer) - require.Equal(t, "test-pinniped-username", storedSession.Claims.Subject) - require.Equal(t, []string{"test-pinniped-group-0", "test-pinniped-group-1"}, storedSession.Claims.Extra["oidc.pinniped.dev/groups"]) + require.Equal(t, upstreamUsername, storedSession.Claims.Subject) + require.Equal(t, []string{downstreamClientID}, storedSession.Claims.Audience) + require.Equal(t, upstreamGroupMembership, storedSession.Claims.Extra["oidc.pinniped.dev/groups"]) } else { require.Empty(t, rsp.Header().Values("Location")) } From 83101eefce044944a92a7041ca6ea26d1f55d1d1 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Thu, 19 Nov 2020 14:19:01 -0500 Subject: [PATCH 16/57] callback_handler.go: start to test upstream token corner cases Also refactor to get rid of duplicate test structs. Also also don't default groups ID token claim because there is no standard one. Also also also add some logging that will hopefully help us in debugging in the future. Signed-off-by: Andrew Keesler --- internal/oidc/callback/callback_handler.go | 131 ++++++++----- .../oidc/callback/callback_handler_test.go | 172 +++++++++++------- 2 files changed, 189 insertions(+), 114 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index a2521488a..ee9a56216 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -26,14 +26,9 @@ const ( // ID token if the upstream OIDC IDP did not tell us to use another claim. defaultUpstreamUsernameClaim = "sub" - // defaultUpstreamGroupsClaim is what we will use to extract the groups from an upstream OIDC ID - // token if the upstream OIDC IDP did not tell us to use another claim. - defaultUpstreamGroupsClaim = "groups" - // downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token // information. - // TODO: should this be per-issuer? Or per version? - downstreamGroupsClaim = "oidc.pinniped.dev/groups" + downstreamGroupsClaim = "groups" ) func NewHandler( @@ -56,6 +51,7 @@ func NewHandler( downstreamAuthParams, err := url.ParseQuery(state.AuthParams) if err != nil { + plog.Error("error reading state downstream auth params", err) return httperr.New(http.StatusBadRequest, "error reading state downstream auth params") } @@ -63,11 +59,11 @@ func NewHandler( reconstitutedAuthRequest := &http.Request{Form: downstreamAuthParams} authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), reconstitutedAuthRequest) if err != nil { + plog.Error("error using state downstream auth params", err) return httperr.New(http.StatusBadRequest, "error using state downstream auth params") } // Grant the openid scope only if it was requested. - // TODO: shouldn't we be potentially granting more scopes than just openid... grantOpenIDScopeIfRequested(authorizeRequester) _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( @@ -77,52 +73,18 @@ func NewHandler( state.Nonce, ) if err != nil { + plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", upstreamIDPConfig.GetName()) return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") } - var username string - usernameClaim := upstreamIDPConfig.GetUsernameClaim() - if usernameClaim == "" { - usernameClaim = defaultUpstreamUsernameClaim - } - usernameAsInterface, ok := idTokenClaims[usernameClaim] - if !ok { - panic(err) // TODO - } - username, ok = usernameAsInterface.(string) - if !ok { - panic(err) // TODO + username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + if err != nil { + return err } - var groups []string - groupsClaim := upstreamIDPConfig.GetGroupsClaim() - if groupsClaim == "" { - groupsClaim = defaultUpstreamGroupsClaim - } - groupsAsInterface, ok := idTokenClaims[groupsClaim] - if !ok { - panic(err) // TODO - } - groups, ok = groupsAsInterface.([]string) - if !ok { - panic(err) // TODO - } - - now := time.Now() - authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{ - Claims: &jwt.IDTokenClaims{ - Issuer: downstreamIssuer, - Subject: username, - Audience: []string{downstreamAuthParams.Get("client_id")}, - ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here - IssuedAt: now, // TODO test this - RequestedAt: now, // TODO test this - AuthTime: now, // TODO test this - Extra: map[string]interface{}{ - downstreamGroupsClaim: groups, - }, - }, - }) + groups := getGroupsFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + openIDSession := makeDownstreamSession(downstreamIssuer, downstreamAuthParams.Get("client_id"), username, groups) + authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession) if err != nil { panic(err) // TODO } @@ -222,3 +184,76 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { } } } + +func getUsernameFromUpstreamIDToken( + upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, + idTokenClaims map[string]interface{}, +) (string, error) { + usernameClaim := upstreamIDPConfig.GetUsernameClaim() + if usernameClaim == "" { + // TODO: if we use the default "sub" claim, maybe we should create the username with the issuer + // since the spec says the "sub" claim is only unique per issuer. + usernameClaim = defaultUpstreamUsernameClaim + } + + usernameAsInterface, ok := idTokenClaims[usernameClaim] + if !ok { + plog.Warning( + "no username claim in upstream ID token", + "upstreamName", upstreamIDPConfig.GetName(), + "configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(), + "usernameClaim", usernameClaim, + ) + return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token") + } + + username, ok := usernameAsInterface.(string) + if !ok { + panic("todo bbb") // TODO + } + + return username, nil +} + +func getGroupsFromUpstreamIDToken( + upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, + idTokenClaims map[string]interface{}, +) []string { + groupsClaim := upstreamIDPConfig.GetGroupsClaim() + if groupsClaim == "" { + return nil + } + + groupsAsInterface, ok := idTokenClaims[groupsClaim] + if !ok { + panic("todo ccc") // TODO + } + + groups, ok := groupsAsInterface.([]string) + if !ok { + panic("todo ddd") // TODO + } + + return groups +} + +func makeDownstreamSession(issuer, clientID, username string, groups []string) *openid.DefaultSession { + now := time.Now() + openIDSession := &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Issuer: issuer, + Subject: username, + Audience: []string{clientID}, + ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here + IssuedAt: now, // TODO test this + RequestedAt: now, // TODO test this + AuthTime: now, // TODO test this + }, + } + if groups != nil { + openIDSession.Claims.Extra = map[string]interface{}{ + downstreamGroupsClaim: groups, + } + } + return openIDSession +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index a37133718..6902773bb 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -29,6 +29,16 @@ import ( const ( happyUpstreamIDPName = "upstream-idp-name" + + upstreamSubject = "abc123-some-guid" + upstreamUsername = "test-pinniped-username" + + upstreamUsernameClaim = "the-user-claim" + upstreamGroupsClaim = "the-groups-claim" +) + +var ( + upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} ) func TestCallbackEndpoint(t *testing.T) { @@ -36,63 +46,15 @@ func TestCallbackEndpoint(t *testing.T) { downstreamIssuer = "https://my-downstream-issuer.com/path" downstreamRedirectURI = "http://127.0.0.1/callback" happyUpstreamAuthcode = "upstream-auth-code" - upstreamUsername = "test-pinniped-username" downstreamClientID = "pinniped-cli" ) - var ( - upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} - ) - - upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ - Name: happyUpstreamIDPName, - ClientID: "some-client-id", - UsernameClaim: "the-user-claim", - GroupsClaim: "the-groups-claim", - Scopes: []string{"scope1", "scope2"}, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { - return oidcclient.Token{}, - map[string]interface{}{ - "the-user-claim": upstreamUsername, - "the-groups-claim": upstreamGroupMembership, - "other-claim": "should be ignored", - }, - nil - }, - } - - defaultClaimsUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ - Name: happyUpstreamIDPName, - ClientID: "some-client-id", - Scopes: []string{"scope1", "scope2"}, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { - return oidcclient.Token{}, - map[string]interface{}{ - "sub": upstreamUsername, - "groups": upstreamGroupMembership, - "other-claim": "should be ignored", - }, - nil - }, - } - otherUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ Name: "other-upstream-idp-name", ClientID: "other-some-client-id", Scopes: []string{"other-scope1", "other-scope2"}, } - failedExchangeUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ - Name: happyUpstreamIDPName, - ClientID: upstreamOIDCIdentityProvider.ClientID, - UsernameClaim: upstreamOIDCIdentityProvider.UsernameClaim, - GroupsClaim: upstreamOIDCIdentityProvider.GroupsClaim, - Scopes: upstreamOIDCIdentityProvider.Scopes, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { - return oidcclient.Token{}, nil, errors.New("some exchange error") - }, - } - var stateEncoderHashKey = []byte("fake-hash-secret") var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES var cookieEncoderHashKey = []byte("fake-hash-secret2") @@ -210,16 +172,18 @@ func TestCallbackEndpoint(t *testing.T) { path string csrfCookie string - wantStatus int - wantBody string - wantRedirectLocationRegexp string - wantGrantedOpenidScope bool + wantStatus int + wantBody string + wantRedirectLocationRegexp string + wantGrantedOpenidScope bool + wantDownstreamIDTokenSubject string + wantDownstreamIDTokenGroups []string wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs }{ { name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), csrfCookie: happyCSRFCookie, @@ -227,11 +191,13 @@ func TestCallbackEndpoint(t *testing.T) { wantRedirectLocationRegexp: happyRedirectLocationRegexp, wantGrantedOpenidScope: true, wantBody: "", + wantDownstreamIDTokenSubject: upstreamUsername, + wantDownstreamIDTokenGroups: upstreamGroupMembership, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { - name: "upstream IDP uses default claims", - idp: defaultClaimsUpstreamOIDCIdentityProvider, + name: "upstream IDP provides no username or group claim, so we use default username claim and skip groups", + idp: happyUpstream().WithoutUsernameClaim().WithoutGroupsClaim().Build(), method: http.MethodGet, path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), csrfCookie: happyCSRFCookie, @@ -239,6 +205,7 @@ func TestCallbackEndpoint(t *testing.T) { wantRedirectLocationRegexp: happyRedirectLocationRegexp, wantGrantedOpenidScope: true, wantBody: "", + wantDownstreamIDTokenSubject: upstreamSubject, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, // TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes) @@ -290,7 +257,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState("this-will-not-decode").String(), csrfCookie: happyCSRFCookie, @@ -299,7 +266,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state's internal version does not match what we want", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(wrongVersionState).String(), csrfCookie: happyCSRFCookie, @@ -308,7 +275,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state's downstream auth params element is invalid", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(), csrfCookie: happyCSRFCookie, @@ -317,7 +284,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state's downstream auth params are missing required value (e.g., client_id)", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(missingClientIDState).String(), csrfCookie: happyCSRFCookie, @@ -326,12 +293,14 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state's downstream auth params does not contain openid scope", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(noOpenidScopeState).WithCode(happyUpstreamAuthcode).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, + wantDownstreamIDTokenSubject: upstreamUsername, + wantDownstreamIDTokenGroups: upstreamGroupMembership, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { @@ -345,7 +314,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "the CSRF cookie does not exist on request", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), wantStatus: http.StatusForbidden, @@ -353,7 +322,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", @@ -362,7 +331,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "cookie csrf value does not match state csrf value", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(wrongCSRFValueState).String(), csrfCookie: happyCSRFCookie, @@ -373,7 +342,7 @@ func TestCallbackEndpoint(t *testing.T) { // Upstream exchange { name: "upstream auth code exchange fails", - idp: failedExchangeUpstreamOIDCIdentityProvider, + idp: happyUpstream().WithoutUpstreamAuthcodeExchangeError(errors.New("some error")).Build(), method: http.MethodGet, path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), csrfCookie: happyCSRFCookie, @@ -381,6 +350,16 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, + { + name: "upstream ID token does not contain requested username claim", + idp: happyUpstream().WithoutIDTokenClaim(upstreamUsernameClaim).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: no username claim in upstream ID token\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, } for _, test := range tests { test := test @@ -457,9 +436,13 @@ func TestCallbackEndpoint(t *testing.T) { require.NotContains(t, storedRequest.GetGrantedScopes(), "openid") } require.Equal(t, downstreamIssuer, storedSession.Claims.Issuer) - require.Equal(t, upstreamUsername, storedSession.Claims.Subject) + require.Equal(t, test.wantDownstreamIDTokenSubject, storedSession.Claims.Subject) require.Equal(t, []string{downstreamClientID}, storedSession.Claims.Audience) - require.Equal(t, upstreamGroupMembership, storedSession.Claims.Extra["oidc.pinniped.dev/groups"]) + if test.wantDownstreamIDTokenGroups != nil { + require.Equal(t, test.wantDownstreamIDTokenGroups, storedSession.Claims.Extra["groups"]) + } else { + require.NotContains(t, storedSession.Claims.Extra, "groups") + } } else { require.Empty(t, rsp.Header().Values("Location")) } @@ -519,6 +502,63 @@ func (r *requestPath) String() string { return path + params.Encode() } +type upstreamOIDCIdentityProviderBuilder struct { + idToken map[string]interface{} + usernameClaim, groupsClaim string + authcodeExchangeErr error +} + +func happyUpstream() *upstreamOIDCIdentityProviderBuilder { + return &upstreamOIDCIdentityProviderBuilder{ + usernameClaim: upstreamUsernameClaim, + groupsClaim: upstreamGroupsClaim, + idToken: map[string]interface{}{ + "sub": upstreamSubject, + upstreamUsernameClaim: upstreamUsername, + upstreamGroupsClaim: upstreamGroupMembership, + "other-claim": "should be ignored", + }, + } +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutUsernameClaim() *upstreamOIDCIdentityProviderBuilder { + u.usernameClaim = "" + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutGroupsClaim() *upstreamOIDCIdentityProviderBuilder { + u.groupsClaim = "" + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithIDTokenClaim(name, value string) *upstreamOIDCIdentityProviderBuilder { + u.idToken[name] = value + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutIDTokenClaim(claim string) *upstreamOIDCIdentityProviderBuilder { + delete(u.idToken, claim) + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutUpstreamAuthcodeExchangeError(err error) *upstreamOIDCIdentityProviderBuilder { + u.authcodeExchangeErr = err + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) Build() testutil.TestUpstreamOIDCIdentityProvider { + return testutil.TestUpstreamOIDCIdentityProvider{ + Name: happyUpstreamIDPName, + ClientID: "some-client-id", + UsernameClaim: u.usernameClaim, + GroupsClaim: u.groupsClaim, + Scopes: []string{"scope1", "scope2"}, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { + return oidcclient.Token{}, u.idToken, u.authcodeExchangeErr + }, + } +} + func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string) url.Values { copied := url.Values{} for key, value := range query { From b49d37ca5497f29ffdde3141081f91b77e4c69e0 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Thu, 19 Nov 2020 15:53:21 -0500 Subject: [PATCH 17/57] callback_handler.go: test invalid upstream ID token username/groups Signed-off-by: Ryan Richard --- internal/oidc/callback/callback_handler.go | 36 +++++++++++++++---- .../oidc/callback/callback_handler_test.go | 32 ++++++++++++++++- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index ee9a56216..6aea05c93 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -82,7 +82,11 @@ func NewHandler( return err } - groups := getGroupsFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + groups, err := getGroupsFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + if err != nil { + return err + } + openIDSession := makeDownstreamSession(downstreamIssuer, downstreamAuthParams.Get("client_id"), username, groups) authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession) if err != nil { @@ -209,7 +213,13 @@ func getUsernameFromUpstreamIDToken( username, ok := usernameAsInterface.(string) if !ok { - panic("todo bbb") // TODO + plog.Warning( + "username claim in upstream ID token has invalid format", + "upstreamName", upstreamIDPConfig.GetName(), + "configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(), + "usernameClaim", usernameClaim, + ) + return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format") } return username, nil @@ -218,23 +228,35 @@ func getUsernameFromUpstreamIDToken( func getGroupsFromUpstreamIDToken( upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, idTokenClaims map[string]interface{}, -) []string { +) ([]string, error) { groupsClaim := upstreamIDPConfig.GetGroupsClaim() if groupsClaim == "" { - return nil + return nil, nil } groupsAsInterface, ok := idTokenClaims[groupsClaim] if !ok { - panic("todo ccc") // TODO + plog.Warning( + "no groups claim in upstream ID token", + "upstreamName", upstreamIDPConfig.GetName(), + "configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(), + "groupsClaim", groupsClaim, + ) + return nil, httperr.New(http.StatusUnprocessableEntity, "no groups claim in upstream ID token") } groups, ok := groupsAsInterface.([]string) if !ok { - panic("todo ddd") // TODO + plog.Warning( + "groups claim in upstream ID token has invalid format", + "upstreamName", upstreamIDPConfig.GetName(), + "configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(), + "groupsClaim", groupsClaim, + ) + return nil, httperr.New(http.StatusUnprocessableEntity, "groups claim in upstream ID token has invalid format") } - return groups + return groups, nil } func makeDownstreamSession(issuer, clientID, username string, groups []string) *openid.DefaultSession { diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 6902773bb..f23198935 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -360,6 +360,36 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "Unprocessable Entity: no username claim in upstream ID token\n", wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, + { + name: "upstream ID token does not contain requested groups claim", + idp: happyUpstream().WithoutIDTokenClaim(upstreamGroupsClaim).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: no groups claim in upstream ID token\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream ID token contains username claim with weird format", + idp: happyUpstream().WithIDTokenClaim(upstreamUsernameClaim, 42).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: username claim in upstream ID token has invalid format\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream ID token contains groups claim with weird format", + idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, 42).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, } for _, test := range tests { test := test @@ -531,7 +561,7 @@ func (u *upstreamOIDCIdentityProviderBuilder) WithoutGroupsClaim() *upstreamOIDC return u } -func (u *upstreamOIDCIdentityProviderBuilder) WithIDTokenClaim(name, value string) *upstreamOIDCIdentityProviderBuilder { +func (u *upstreamOIDCIdentityProviderBuilder) WithIDTokenClaim(name string, value interface{}) *upstreamOIDCIdentityProviderBuilder { u.idToken[name] = value return u } From b25696a1fb1ff07bef77a335aa77666e836ec6f2 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Thu, 19 Nov 2020 17:57:07 -0800 Subject: [PATCH 18/57] callback_handler.go: Prepend iss to sub when making default username - Also handle several more error cases - Move RequireTimeInDelta to shared testutils package so other tests can also use it - Move all of the oidc test helpers into a new oidc/oidctestutils package to break a circular import dependency. The shared testutil package can't depend on any of our other packages or else we end up with circular dependencies. - Lots more assertions about what was stored at the end of the request to build confidence that we are going to pass all of the right settings over to the token endpoint through the storage, and also to avoid accidental regressions in that area in the future Signed-off-by: Ryan Richard --- internal/oidc/auth/auth_handler_test.go | 68 ++-- internal/oidc/callback/callback_handler.go | 50 ++- .../oidc/callback/callback_handler_test.go | 380 +++++++++++------- .../{testutil => oidc/oidctestutil}/oidc.go | 2 +- .../oidc/provider/manager/manager_test.go | 5 +- internal/oidcclient/login_test.go | 18 +- internal/testutil/assertions.go | 24 ++ 7 files changed, 345 insertions(+), 202 deletions(-) rename internal/{testutil => oidc/oidctestutil}/oidc.go (99%) create mode 100644 internal/testutil/assertions.go diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 9bfe06e54..6ff4ce619 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -20,10 +20,10 @@ import ( "go.pinniped.dev/internal/here" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/csrftoken" + "go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/pkce" - "go.pinniped.dev/internal/testutil" ) func TestAuthorizationEndpoint(t *testing.T) { @@ -114,7 +114,7 @@ func TestAuthorizationEndpoint(t *testing.T) { upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") require.NoError(t, err) - upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ + upstreamOIDCIdentityProvider := oidctestutil.TestUpstreamOIDCIdentityProvider{ Name: "some-idp", ClientID: "some-client-id", AuthorizationURL: *upstreamAuthURL, @@ -210,7 +210,7 @@ func TestAuthorizationEndpoint(t *testing.T) { csrf = csrfValueOverride } encoded, err := happyStateEncoder.Encode("s", - testutil.ExpectedUpstreamStateParamFormat{ + oidctestutil.ExpectedUpstreamStateParamFormat{ P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)), N: happyNonce, C: csrf, @@ -270,7 +270,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using GET without a CSRF cookie", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -288,7 +288,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using GET with a CSRF cookie", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -306,7 +306,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using POST", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -326,7 +326,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path when downstream redirect uri matches what is configured for client except for the port number", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -348,7 +348,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream redirect uri does not match what is configured for client", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -365,7 +365,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream client does not exist", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -380,7 +380,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "response type is unsupported", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -396,7 +396,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream scopes do not match what is configured for client", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -412,7 +412,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing response type in request", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -428,7 +428,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing client id in request", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -443,7 +443,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -459,7 +459,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -475,7 +475,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -491,7 +491,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -509,7 +509,7 @@ func TestAuthorizationEndpoint(t *testing.T) { // through that part of the fosite library. name: "prompt param is not allowed to have none and another legal value at the same time", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -525,7 +525,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "OIDC validations are skipped when the openid scope was not requested", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -546,7 +546,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "state does not have enough entropy", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -562,7 +562,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while encoding upstream state param", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -577,7 +577,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while encoding CSRF cookie value for new cookie", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -592,7 +592,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating CSRF token", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -607,7 +607,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating nonce", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, @@ -622,7 +622,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating PKCE", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generateNonce: happyNonceGenerator, @@ -637,7 +637,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while decoding CSRF cookie", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -653,7 +653,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "no upstream providers are configured", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(), // empty + idpListGetter: oidctestutil.NewIDPListGetter(), // empty method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -663,7 +663,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "too many upstream providers are configured", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -673,7 +673,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "PUT is a bad method", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodPut, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -683,7 +683,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "PATCH is a bad method", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodPatch, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -693,7 +693,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "DELETE is a bad method", issuer: downstreamIssuer, - idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodDelete, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -772,7 +772,7 @@ func TestAuthorizationEndpoint(t *testing.T) { runOneTestCase(t, test, subject) // Call the setter to change the upstream IDP settings. - newProviderSettings := testutil.TestUpstreamOIDCIdentityProvider{ + newProviderSettings := oidctestutil.TestUpstreamOIDCIdentityProvider{ Name: "some-other-idp", ClientID: "some-other-client-id", AuthorizationURL: *upstreamAuthURL, @@ -840,13 +840,13 @@ func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL expectedQueryStateParam := expectedLocationURL.Query().Get("state") require.NotEmpty(t, expectedQueryStateParam) - var expectedDecodedStateParam testutil.ExpectedUpstreamStateParamFormat + var expectedDecodedStateParam oidctestutil.ExpectedUpstreamStateParamFormat err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam) require.NoError(t, err) actualQueryStateParam := actualLocationURL.Query().Get("state") require.NotEmpty(t, actualQueryStateParam) - var actualDecodedStateParam testutil.ExpectedUpstreamStateParamFormat + var actualDecodedStateParam oidctestutil.ExpectedUpstreamStateParamFormat err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam) require.NoError(t, err) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 6aea05c93..a9e144b23 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -5,6 +5,7 @@ package callback import ( + "fmt" "net/http" "net/url" "path" @@ -22,13 +23,22 @@ import ( ) const ( + // The name of the issuer claim specified in the OIDC spec. + idTokenIssuerClaim = "iss" + + // The name of the subject claim specified in the OIDC spec. + idTokenSubjectClaim = "sub" + // defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC // ID token if the upstream OIDC IDP did not tell us to use another claim. - defaultUpstreamUsernameClaim = "sub" + defaultUpstreamUsernameClaim = idTokenSubjectClaim // downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token // information. downstreamGroupsClaim = "groups" + + // The lifetime of an issued downstream ID token. + downstreamIDTokenLifetime = time.Minute * 5 ) func NewHandler( @@ -90,7 +100,8 @@ func NewHandler( openIDSession := makeDownstreamSession(downstreamIssuer, downstreamAuthParams.Get("client_id"), username, groups) authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession) if err != nil { - panic(err) // TODO + plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName()) + return httperr.Wrap(http.StatusInternalServerError, "error while generating and saving authcode", err) } oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder) @@ -194,9 +205,30 @@ func getUsernameFromUpstreamIDToken( idTokenClaims map[string]interface{}, ) (string, error) { usernameClaim := upstreamIDPConfig.GetUsernameClaim() + + user := "" if usernameClaim == "" { - // TODO: if we use the default "sub" claim, maybe we should create the username with the issuer - // since the spec says the "sub" claim is only unique per issuer. + // The spec says the "sub" claim is only unique per issuer, so by default when there is + // no specific username claim configured we will prepend the issuer string to make it globally unique. + upstreamIssuer := idTokenClaims[idTokenIssuerClaim] + if upstreamIssuer == "" { + plog.Warning( + "issuer claim in upstream ID token missing", + "upstreamName", upstreamIDPConfig.GetName(), + "issClaim", upstreamIssuer, + ) + return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing") + } + upstreamIssuerAsString, ok := upstreamIssuer.(string) + if !ok { + plog.Warning( + "issuer claim in upstream ID token has invalid format", + "upstreamName", upstreamIDPConfig.GetName(), + "issClaim", upstreamIssuer, + ) + return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format") + } + user = fmt.Sprintf("%s?%s=", upstreamIssuerAsString, idTokenSubjectClaim) usernameClaim = defaultUpstreamUsernameClaim } @@ -222,7 +254,7 @@ func getUsernameFromUpstreamIDToken( return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format") } - return username, nil + return fmt.Sprintf("%s%s", user, username), nil } func getGroupsFromUpstreamIDToken( @@ -266,10 +298,10 @@ func makeDownstreamSession(issuer, clientID, username string, groups []string) * Issuer: issuer, Subject: username, Audience: []string{clientID}, - ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here - IssuedAt: now, // TODO test this - RequestedAt: now, // TODO test this - AuthTime: now, // TODO test this + ExpiresAt: now.Add(downstreamIDTokenLifetime), + IssuedAt: now, + RequestedAt: now, + AuthTime: now, }, } if groups != nil { diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index f23198935..36f51b21d 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -13,6 +13,7 @@ import ( "regexp" "strings" "testing" + "time" "github.com/gorilla/securecookie" "github.com/ory/fosite" @@ -21,6 +22,7 @@ import ( "github.com/stretchr/testify/require" "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidcclient" "go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/pkce" @@ -30,26 +32,46 @@ import ( const ( happyUpstreamIDPName = "upstream-idp-name" + upstreamIssuer = "https://my-upstream-issuer.com" upstreamSubject = "abc123-some-guid" upstreamUsername = "test-pinniped-username" upstreamUsernameClaim = "the-user-claim" upstreamGroupsClaim = "the-groups-claim" + + happyDownstreamState = "some-downstream-state" + happyCSRF = "test-csrf" + happyPKCE = "test-pkce" + happyNonce = "test-nonce" + happyStateVersion = "1" + + downstreamIssuer = "https://my-downstream-issuer.com/path" + happyUpstreamAuthcode = "upstream-auth-code" + downstreamRedirectURI = "http://127.0.0.1/callback" + downstreamClientID = "pinniped-cli" + + timeComparisonFudgeFactor = time.Second * 15 ) var ( - upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} + upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} + happyDownstreamScopesRequested = []string{"openid", "profile", "email"} + + happyOriginalRequestParamsQuery = url.Values{ + "response_type": []string{"code"}, + "scope": []string{strings.Join(happyDownstreamScopesRequested, " ")}, + "client_id": []string{downstreamClientID}, + "state": []string{happyDownstreamState}, + "nonce": []string{"some-nonce-value"}, + "code_challenge": []string{"some-challenge"}, + "code_challenge_method": []string{"S256"}, + "redirect_uri": []string{downstreamRedirectURI}, + } + happyOriginalRequestParams = happyOriginalRequestParamsQuery.Encode() ) func TestCallbackEndpoint(t *testing.T) { - const ( - downstreamIssuer = "https://my-downstream-issuer.com/path" - downstreamRedirectURI = "http://127.0.0.1/callback" - happyUpstreamAuthcode = "upstream-auth-code" - downstreamClientID = "pinniped-cli" - ) - - otherUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ + otherUpstreamOIDCIdentityProvider := oidctestutil.TestUpstreamOIDCIdentityProvider{ Name: "other-upstream-idp-name", ClientID: "other-some-client-id", Scopes: []string{"other-scope1", "other-scope2"}, @@ -67,95 +89,13 @@ func TestCallbackEndpoint(t *testing.T) { var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) - happyDownstreamState := "some-downstream-state" - - happyOriginalRequestParamsQuery := url.Values{ - "response_type": []string{"code"}, - "scope": []string{"openid profile email"}, - "client_id": []string{downstreamClientID}, - "state": []string{happyDownstreamState}, - "nonce": []string{"some-nonce-value"}, - "code_challenge": []string{"some-challenge"}, - "code_challenge_method": []string{"S256"}, - "redirect_uri": []string{downstreamRedirectURI}, - } - happyOriginalRequestParams := happyOriginalRequestParamsQuery.Encode() - happyCSRF := "test-csrf" - happyPKCE := "test-pkce" - happyNonce := "test-nonce" - happyStateVersion := "1" - - happyState, err := happyStateCodec.Encode("s", - testutil.ExpectedUpstreamStateParamFormat{ - P: happyOriginalRequestParams, - N: happyNonce, - C: happyCSRF, - K: happyPKCE, - V: happyStateVersion, - }, - ) - require.NoError(t, err) - - wrongCSRFValueState, err := happyStateCodec.Encode("s", - testutil.ExpectedUpstreamStateParamFormat{ - P: happyOriginalRequestParams, - N: happyNonce, - C: "wrong-csrf-value", - K: happyPKCE, - V: happyStateVersion, - }, - ) - require.NoError(t, err) - - wrongVersionState, err := happyStateCodec.Encode("s", - testutil.ExpectedUpstreamStateParamFormat{ - P: happyOriginalRequestParams, - N: happyNonce, - C: happyCSRF, - K: happyPKCE, - V: "wrong-state-version", - }, - ) - require.NoError(t, err) - - wrongDownstreamAuthParamsState, err := happyStateCodec.Encode("s", - testutil.ExpectedUpstreamStateParamFormat{ - P: "these-is-not-a-valid-url-query-%z", - N: happyNonce, - C: happyCSRF, - K: happyPKCE, - V: happyStateVersion, - }, - ) - require.NoError(t, err) - - missingClientIDState, err := happyStateCodec.Encode("s", - testutil.ExpectedUpstreamStateParamFormat{ - P: shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"client_id": ""}).Encode(), - N: happyNonce, - C: happyCSRF, - K: happyPKCE, - V: happyStateVersion, - }, - ) - require.NoError(t, err) - - noOpenidScopeState, err := happyStateCodec.Encode("s", - testutil.ExpectedUpstreamStateParamFormat{ - P: shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"scope": "profile email"}).Encode(), - N: happyNonce, - C: happyCSRF, - K: happyPKCE, - V: happyStateVersion, - }, - ) - require.NoError(t, err) + happyState := happyUpstreamStateParam().Build(t, happyStateCodec) encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyCSRF) require.NoError(t, err) happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue - happyExchangeAndValidateTokensArgs := &testutil.ExchangeAuthcodeAndValidateTokenArgs{ + happyExchangeAndValidateTokensArgs := &oidctestutil.ExchangeAuthcodeAndValidateTokenArgs{ Authcode: happyUpstreamAuthcode, PKCECodeVerifier: pkce.Code(happyPKCE), ExpectedIDTokenNonce: nonce.Nonce(happyNonce), @@ -167,25 +107,26 @@ func TestCallbackEndpoint(t *testing.T) { tests := []struct { name string - idp testutil.TestUpstreamOIDCIdentityProvider + idp oidctestutil.TestUpstreamOIDCIdentityProvider method string path string csrfCookie string - wantStatus int - wantBody string - wantRedirectLocationRegexp string - wantGrantedOpenidScope bool - wantDownstreamIDTokenSubject string - wantDownstreamIDTokenGroups []string + wantStatus int + wantBody string + wantRedirectLocationRegexp string + wantGrantedOpenidScope bool + wantDownstreamIDTokenSubject string + wantDownstreamIDTokenGroups []string + wantDownstreamRequestedScopes []string - wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs + wantExchangeAndValidateTokensCall *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs }{ { name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", idp: happyUpstream().Build(), method: http.MethodGet, - path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyRedirectLocationRegexp, @@ -193,22 +134,39 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "", wantDownstreamIDTokenSubject: upstreamUsername, wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { - name: "upstream IDP provides no username or group claim, so we use default username claim and skip groups", + name: "upstream IDP provides no username or group claim configuration, so we use default username claim and skip groups", idp: happyUpstream().WithoutUsernameClaim().WithoutGroupsClaim().Build(), method: http.MethodGet, - path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: happyRedirectLocationRegexp, + wantGrantedOpenidScope: true, + wantBody: "", + wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, + wantDownstreamIDTokenGroups: nil, + wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream IDP provides username claim configuration as `sub`, so the downstream token subject should be exactly what they asked for", + idp: happyUpstream().WithUsernameClaim("sub").Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyRedirectLocationRegexp, wantGrantedOpenidScope: true, wantBody: "", wantDownstreamIDTokenSubject: upstreamSubject, + wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, - // TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes) // Pre-upstream-exchange verification { @@ -264,42 +222,70 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusBadRequest, wantBody: "Bad Request: error reading state\n", }, + { + // This shouldn't happen in practice because the authorize endpoint should have already run the same + // validations, but we would like to test the error handling in this endpoint anyway. + name: "state param contains authorization request params which fail validation", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState( + happyUpstreamStateParam(). + WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"prompt": "none login"}).Encode()). + Build(t, happyStateCodec), + ).String(), + csrfCookie: happyCSRFCookie, + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantStatus: http.StatusInternalServerError, + wantBody: "Internal Server Error: error while generating and saving authcode\n", + }, { name: "state's internal version does not match what we want", idp: happyUpstream().Build(), method: http.MethodGet, - path: newRequestPath().WithState(wrongVersionState).String(), + path: newRequestPath().WithState(happyUpstreamStateParam().WithStateVersion("wrong-state-version").Build(t, happyStateCodec)).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusUnprocessableEntity, wantBody: "Unprocessable Entity: state format version is invalid\n", }, { - name: "state's downstream auth params element is invalid", - idp: happyUpstream().Build(), - method: http.MethodGet, - path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(), + name: "state's downstream auth params element is invalid", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyUpstreamStateParam(). + WithAuthorizeRequestParams("the following is an invalid url encoding token, and therefore this is an invalid param: %z"). + Build(t, happyStateCodec)).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusBadRequest, wantBody: "Bad Request: error reading state downstream auth params\n", }, { - name: "state's downstream auth params are missing required value (e.g., client_id)", - idp: happyUpstream().Build(), - method: http.MethodGet, - path: newRequestPath().WithState(missingClientIDState).String(), + name: "state's downstream auth params are missing required value (e.g., client_id)", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState( + happyUpstreamStateParam(). + WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"client_id": ""}).Encode()). + Build(t, happyStateCodec), + ).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusBadRequest, wantBody: "Bad Request: error using state downstream auth params\n", }, { - name: "state's downstream auth params does not contain openid scope", - idp: happyUpstream().Build(), - method: http.MethodGet, - path: newRequestPath().WithState(noOpenidScopeState).WithCode(happyUpstreamAuthcode).String(), + name: "state's downstream auth params does not contain openid scope", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath(). + WithState( + happyUpstreamStateParam(). + WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"scope": "profile email"}).Encode()). + Build(t, happyStateCodec), + ).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, wantDownstreamIDTokenSubject: upstreamUsername, + wantDownstreamRequestedScopes: []string{"profile", "email"}, wantDownstreamIDTokenGroups: upstreamGroupMembership, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, @@ -333,7 +319,7 @@ func TestCallbackEndpoint(t *testing.T) { name: "cookie csrf value does not match state csrf value", idp: happyUpstream().Build(), method: http.MethodGet, - path: newRequestPath().WithState(wrongCSRFValueState).String(), + path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusForbidden, wantBody: "Forbidden: CSRF value does not match\n", @@ -344,7 +330,7 @@ func TestCallbackEndpoint(t *testing.T) { name: "upstream auth code exchange fails", idp: happyUpstream().WithoutUpstreamAuthcodeExchangeError(errors.New("some error")).Build(), method: http.MethodGet, - path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusBadGateway, wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", @@ -354,7 +340,7 @@ func TestCallbackEndpoint(t *testing.T) { name: "upstream ID token does not contain requested username claim", idp: happyUpstream().WithoutIDTokenClaim(upstreamUsernameClaim).Build(), method: http.MethodGet, - path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusUnprocessableEntity, wantBody: "Unprocessable Entity: no username claim in upstream ID token\n", @@ -364,7 +350,7 @@ func TestCallbackEndpoint(t *testing.T) { name: "upstream ID token does not contain requested groups claim", idp: happyUpstream().WithoutIDTokenClaim(upstreamGroupsClaim).Build(), method: http.MethodGet, - path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusUnprocessableEntity, wantBody: "Unprocessable Entity: no groups claim in upstream ID token\n", @@ -374,17 +360,37 @@ func TestCallbackEndpoint(t *testing.T) { name: "upstream ID token contains username claim with weird format", idp: happyUpstream().WithIDTokenClaim(upstreamUsernameClaim, 42).Build(), method: http.MethodGet, - path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusUnprocessableEntity, wantBody: "Unprocessable Entity: username claim in upstream ID token has invalid format\n", wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, + { + name: "upstream ID token does not contain iss claim when using default username claim config", + idp: happyUpstream().WithIDTokenClaim("iss", "").WithoutUsernameClaim().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: issuer claim in upstream ID token missing\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream ID token has an non-string iss claim when using default username claim config", + idp: happyUpstream().WithIDTokenClaim("iss", 42).WithoutUsernameClaim().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: issuer claim in upstream ID token has invalid format\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, { name: "upstream ID token contains groups claim with weird format", idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, 42).Build(), method: http.MethodGet, - path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusUnprocessableEntity, wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", @@ -393,6 +399,7 @@ func TestCallbackEndpoint(t *testing.T) { } for _, test := range tests { test := test + t.Run(test.name, func(t *testing.T) { // Configure fosite the same way that the production code would, except use in-memory storage. // Inject this into our test subject at the last second so we get a fresh storage for every test. @@ -406,7 +413,7 @@ func TestCallbackEndpoint(t *testing.T) { require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) - idpListGetter := testutil.NewIDPListGetter(&test.idp) + idpListGetter := oidctestutil.NewIDPListGetter(&test.idp) subject := NewHandler(downstreamIssuer, idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) req := httptest.NewRequest(test.method, test.path, nil) if test.csrfCookie != "" { @@ -433,7 +440,7 @@ func TestCallbackEndpoint(t *testing.T) { require.Empty(t, rsp.Body.String()) } - if test.wantRedirectLocationRegexp != "" { + if test.wantRedirectLocationRegexp != "" { //nolint:nestif // don't mind have several sequential if statements in this test // Assert that Location header matches regular expression. require.Len(t, rsp.Header().Values("Location"), 1) actualLocation := rsp.Header().Get("Location") @@ -459,20 +466,63 @@ func TestCallbackEndpoint(t *testing.T) { storedSession, ok := storedAuthorizeRequest.GetSession().(*openid.DefaultSession) require.True(t, ok) - // Check various fields of the stored data. + // Check which scopes were granted. if test.wantGrantedOpenidScope { require.Contains(t, storedRequest.GetGrantedScopes(), "openid") } else { require.NotContains(t, storedRequest.GetGrantedScopes(), "openid") } - require.Equal(t, downstreamIssuer, storedSession.Claims.Issuer) - require.Equal(t, test.wantDownstreamIDTokenSubject, storedSession.Claims.Subject) - require.Equal(t, []string{downstreamClientID}, storedSession.Claims.Audience) + + // Check all the other fields of the stored request. + require.NotEmpty(t, storedRequest.ID) + require.Equal(t, downstreamClientID, storedRequest.Client.GetID()) + require.ElementsMatch(t, test.wantDownstreamRequestedScopes, storedRequest.RequestedScope) + require.Nil(t, storedRequest.RequestedAudience) + require.Empty(t, storedRequest.GrantedAudience) + require.Equal(t, url.Values{"redirect_uri": []string{downstreamRedirectURI}}, storedRequest.Form) + testutil.RequireTimeInDelta(t, time.Now(), storedRequest.RequestedAt, timeComparisonFudgeFactor) + + // We're not using these fields yet, so confirm that we did not set them (for now). + require.Empty(t, storedSession.Subject) + require.Empty(t, storedSession.Username) + require.Empty(t, storedSession.Headers) + + // The authcode that we are issuing should be good for 15 minutes, which is default for fosite. + testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*15), storedSession.ExpiresAt[fosite.AuthorizeCode], timeComparisonFudgeFactor) + require.Len(t, storedSession.ExpiresAt, 1) + + // Now confirm the ID token claims. + actualClaims := storedSession.Claims + + // Check the user's identity, which are put into the downstream ID token's subject and groups claims. + require.Equal(t, test.wantDownstreamIDTokenSubject, actualClaims.Subject) if test.wantDownstreamIDTokenGroups != nil { - require.Equal(t, test.wantDownstreamIDTokenGroups, storedSession.Claims.Extra["groups"]) + require.Len(t, actualClaims.Extra, 1) + require.Equal(t, test.wantDownstreamIDTokenGroups, actualClaims.Extra["groups"]) } else { - require.NotContains(t, storedSession.Claims.Extra, "groups") + require.Empty(t, actualClaims.Extra) + require.NotContains(t, actualClaims.Extra, "groups") } + + // Check the rest of the downstream ID token's claims. + require.Equal(t, downstreamIssuer, actualClaims.Issuer) + require.Equal(t, []string{downstreamClientID}, actualClaims.Audience) + testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*5), actualClaims.ExpiresAt, timeComparisonFudgeFactor) + testutil.RequireTimeInDelta(t, time.Now(), actualClaims.IssuedAt, timeComparisonFudgeFactor) + testutil.RequireTimeInDelta(t, time.Now(), actualClaims.RequestedAt, timeComparisonFudgeFactor) + testutil.RequireTimeInDelta(t, time.Now(), actualClaims.AuthTime, timeComparisonFudgeFactor) + + // These are not needed yet. + require.Empty(t, actualClaims.JTI) + require.Empty(t, actualClaims.CodeHash) + require.Empty(t, actualClaims.AccessTokenHash) + require.Empty(t, actualClaims.AuthenticationContextClassReference) + require.Empty(t, actualClaims.AuthenticationMethodsReference) + + // TODO we should put the downstream request's nonce into the ID token, but maybe the token endpoint is responsible for that? + require.Empty(t, actualClaims.Nonce) + + // TODO add thorough tests about what should be stored for PKCES and IDSessions } else { require.Empty(t, rsp.Header().Values("Location")) } @@ -486,7 +536,7 @@ type requestPath struct { func newRequestPath() *requestPath { n := happyUpstreamIDPName - c := "1234" + c := happyUpstreamAuthcode s := "4321" return &requestPath{ upstreamIDPName: &n, @@ -532,6 +582,49 @@ func (r *requestPath) String() string { return path + params.Encode() } +type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat + +func happyUpstreamStateParam() *upstreamStateParamBuilder { + return &upstreamStateParamBuilder{ + P: happyOriginalRequestParams, + N: happyNonce, + C: happyCSRF, + K: happyPKCE, + V: happyStateVersion, + } +} + +func (b upstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string { + state, err := stateEncoder.Encode("s", b) + require.NoError(t, err) + return state +} + +func (b *upstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *upstreamStateParamBuilder { + b.P = params + return b +} + +func (b *upstreamStateParamBuilder) WithNonce(nonce string) *upstreamStateParamBuilder { + b.N = nonce + return b +} + +func (b *upstreamStateParamBuilder) WithCSRF(csrf string) *upstreamStateParamBuilder { + b.C = csrf + return b +} + +func (b *upstreamStateParamBuilder) WithPKCVE(pkce string) *upstreamStateParamBuilder { + b.K = pkce + return b +} + +func (b *upstreamStateParamBuilder) WithStateVersion(version string) *upstreamStateParamBuilder { + b.V = version + return b +} + type upstreamOIDCIdentityProviderBuilder struct { idToken map[string]interface{} usernameClaim, groupsClaim string @@ -543,6 +636,7 @@ func happyUpstream() *upstreamOIDCIdentityProviderBuilder { usernameClaim: upstreamUsernameClaim, groupsClaim: upstreamGroupsClaim, idToken: map[string]interface{}{ + "iss": upstreamIssuer, "sub": upstreamSubject, upstreamUsernameClaim: upstreamUsername, upstreamGroupsClaim: upstreamGroupMembership, @@ -551,6 +645,11 @@ func happyUpstream() *upstreamOIDCIdentityProviderBuilder { } } +func (u *upstreamOIDCIdentityProviderBuilder) WithUsernameClaim(claim string) *upstreamOIDCIdentityProviderBuilder { + u.usernameClaim = claim + return u +} + func (u *upstreamOIDCIdentityProviderBuilder) WithoutUsernameClaim() *upstreamOIDCIdentityProviderBuilder { u.usernameClaim = "" return u @@ -576,8 +675,8 @@ func (u *upstreamOIDCIdentityProviderBuilder) WithoutUpstreamAuthcodeExchangeErr return u } -func (u *upstreamOIDCIdentityProviderBuilder) Build() testutil.TestUpstreamOIDCIdentityProvider { - return testutil.TestUpstreamOIDCIdentityProvider{ +func (u *upstreamOIDCIdentityProviderBuilder) Build() oidctestutil.TestUpstreamOIDCIdentityProvider { + return oidctestutil.TestUpstreamOIDCIdentityProvider{ Name: happyUpstreamIDPName, ClientID: "some-client-id", UsernameClaim: u.usernameClaim, @@ -592,12 +691,13 @@ func (u *upstreamOIDCIdentityProviderBuilder) Build() testutil.TestUpstreamOIDCI func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string) url.Values { copied := url.Values{} for key, value := range query { - if modification, ok := modifications[key]; ok { - if modification != "" { - copied[key] = []string{modification} - } + copied[key] = value + } + for key, value := range modifications { + if value == "" { + copied.Del(key) } else { - copied[key] = value + copied[key] = []string{value} } } return copied diff --git a/internal/testutil/oidc.go b/internal/oidc/oidctestutil/oidc.go similarity index 99% rename from internal/testutil/oidc.go rename to internal/oidc/oidctestutil/oidc.go index 7cbfcf812..ad4338db7 100644 --- a/internal/testutil/oidc.go +++ b/internal/oidc/oidctestutil/oidc.go @@ -1,7 +1,7 @@ // Copyright 2020 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package testutil +package oidctestutil import ( "context" diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index fce748bb7..86137abd7 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -12,8 +12,6 @@ import ( "strings" "testing" - "go.pinniped.dev/internal/testutil" - "github.com/sclevine/spec" "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" @@ -22,6 +20,7 @@ import ( "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/jwks" + "go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/provider" ) @@ -109,7 +108,7 @@ func TestManager(t *testing.T) { parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL) r.NoError(err) - idpListGetter := testutil.NewIDPListGetter(&testutil.TestUpstreamOIDCIdentityProvider{ + idpListGetter := oidctestutil.NewIDPListGetter(&oidctestutil.TestUpstreamOIDCIdentityProvider{ Name: "test-idp", ClientID: "test-client-id", AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, diff --git a/internal/oidcclient/login_test.go b/internal/oidcclient/login_test.go index 35574ba09..8cd4bd2a3 100644 --- a/internal/oidcclient/login_test.go +++ b/internal/oidcclient/login_test.go @@ -26,6 +26,7 @@ import ( "go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/pkce" "go.pinniped.dev/internal/oidcclient/state" + "go.pinniped.dev/internal/testutil" ) // mockSessionCache exists to avoid an import cycle if we generate mocks into another package. @@ -481,7 +482,7 @@ func TestLogin(t *testing.T) { require.NotNil(t, tok.AccessToken) require.Equal(t, want.Token, tok.AccessToken.Token) require.Equal(t, want.Type, tok.AccessToken.Type) - requireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second) + testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second) } else { assert.Nil(t, tok.AccessToken) } @@ -489,7 +490,7 @@ func TestLogin(t *testing.T) { if want := tt.wantToken.IDToken; want != nil { require.NotNil(t, tok.IDToken) require.Equal(t, want.Token, tok.IDToken.Token) - requireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second) + testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second) } else { assert.Nil(t, tok.IDToken) } @@ -682,16 +683,3 @@ type mockDiscovery struct{ provider *oidc.Provider } func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() } func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() } - -func requireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) { - require.InDeltaf(t, - float64(t1.UnixNano()), - float64(t2.UnixNano()), - float64(delta.Nanoseconds()), - "expected %s and %s to be < %s apart, but they are %s apart", - t1.Format(time.RFC3339Nano), - t2.Format(time.RFC3339Nano), - delta.String(), - t1.Sub(t2).String(), - ) -} diff --git a/internal/testutil/assertions.go b/internal/testutil/assertions.go new file mode 100644 index 000000000..772476028 --- /dev/null +++ b/internal/testutil/assertions.go @@ -0,0 +1,24 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testutil + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func RequireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) { + require.InDeltaf(t, + float64(t1.UnixNano()), + float64(t2.UnixNano()), + float64(delta.Nanoseconds()), + "expected %s and %s to be < %s apart, but they are %s apart", + t1.Format(time.RFC3339Nano), + t2.Format(time.RFC3339Nano), + delta.String(), + t1.Sub(t2).String(), + ) +} From f8d76066c5d2ebffefbd46c5b3fdda72ececda45 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Fri, 20 Nov 2020 08:38:23 -0500 Subject: [PATCH 19/57] callback_handler.go: assert nonce is stored correctly I think we want to do this here since we are storing all of the other ID token claims? Signed-off-by: Andrew Keesler --- internal/oidc/callback/callback_handler.go | 11 +++++++++-- internal/oidc/callback/callback_handler_test.go | 12 ++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index a9e144b23..e443b9d6e 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -97,7 +97,13 @@ func NewHandler( return err } - openIDSession := makeDownstreamSession(downstreamIssuer, downstreamAuthParams.Get("client_id"), username, groups) + openIDSession := makeDownstreamSession( + downstreamIssuer, + downstreamAuthParams.Get("client_id"), + downstreamAuthParams.Get("nonce"), + username, + groups, + ) authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession) if err != nil { plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName()) @@ -291,7 +297,7 @@ func getGroupsFromUpstreamIDToken( return groups, nil } -func makeDownstreamSession(issuer, clientID, username string, groups []string) *openid.DefaultSession { +func makeDownstreamSession(issuer, clientID, nonce, username string, groups []string) *openid.DefaultSession { now := time.Now() openIDSession := &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ @@ -302,6 +308,7 @@ func makeDownstreamSession(issuer, clientID, username string, groups []string) * IssuedAt: now, RequestedAt: now, AuthTime: now, + Nonce: nonce, }, } if groups != nil { diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 36f51b21d..bcfc3e8f7 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -49,6 +49,7 @@ const ( happyUpstreamAuthcode = "upstream-auth-code" downstreamRedirectURI = "http://127.0.0.1/callback" downstreamClientID = "pinniped-cli" + downstreamNonce = "some-nonce-value" timeComparisonFudgeFactor = time.Second * 15 ) @@ -62,7 +63,7 @@ var ( "scope": []string{strings.Join(happyDownstreamScopesRequested, " ")}, "client_id": []string{downstreamClientID}, "state": []string{happyDownstreamState}, - "nonce": []string{"some-nonce-value"}, + "nonce": []string{downstreamNonce}, "code_challenge": []string{"some-challenge"}, "code_challenge_method": []string{"S256"}, "redirect_uri": []string{downstreamRedirectURI}, @@ -119,6 +120,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamIDTokenSubject string wantDownstreamIDTokenGroups []string wantDownstreamRequestedScopes []string + wantDownstreamNonce string wantExchangeAndValidateTokensCall *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs }{ @@ -135,6 +137,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamIDTokenSubject: upstreamUsername, wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantDownstreamNonce: downstreamNonce, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { @@ -150,6 +153,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, wantDownstreamIDTokenGroups: nil, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantDownstreamNonce: downstreamNonce, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { @@ -165,6 +169,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamIDTokenSubject: upstreamSubject, wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantDownstreamNonce: downstreamNonce, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, @@ -287,6 +292,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamIDTokenSubject: upstreamUsername, wantDownstreamRequestedScopes: []string{"profile", "email"}, wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamNonce: downstreamNonce, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { @@ -507,6 +513,7 @@ func TestCallbackEndpoint(t *testing.T) { // Check the rest of the downstream ID token's claims. require.Equal(t, downstreamIssuer, actualClaims.Issuer) require.Equal(t, []string{downstreamClientID}, actualClaims.Audience) + require.Equal(t, test.wantDownstreamNonce, actualClaims.Nonce) testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*5), actualClaims.ExpiresAt, timeComparisonFudgeFactor) testutil.RequireTimeInDelta(t, time.Now(), actualClaims.IssuedAt, timeComparisonFudgeFactor) testutil.RequireTimeInDelta(t, time.Now(), actualClaims.RequestedAt, timeComparisonFudgeFactor) @@ -519,9 +526,6 @@ func TestCallbackEndpoint(t *testing.T) { require.Empty(t, actualClaims.AuthenticationContextClassReference) require.Empty(t, actualClaims.AuthenticationMethodsReference) - // TODO we should put the downstream request's nonce into the ID token, but maybe the token endpoint is responsible for that? - require.Empty(t, actualClaims.Nonce) - // TODO add thorough tests about what should be stored for PKCES and IDSessions } else { require.Empty(t, rsp.Header().Values("Location")) From 8f5d1709a1caefeac2dc3adee4e1b9b3e3ece1bd Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Fri, 20 Nov 2020 09:41:49 -0500 Subject: [PATCH 20/57] callback_handler.go: assert behavior about PKCE and IDSession storage Also aggresively refactor for readability: - Make helper validations functions for each type of storage - Try to label symbols based on their downstream/upstream use and group them accordingly Signed-off-by: Andrew Keesler --- internal/oidc/callback/callback_handler.go | 2 +- .../oidc/callback/callback_handler_test.go | 331 ++++++++++++------ 2 files changed, 227 insertions(+), 106 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index e443b9d6e..24780869c 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -304,11 +304,11 @@ func makeDownstreamSession(issuer, clientID, nonce, username string, groups []st Issuer: issuer, Subject: username, Audience: []string{clientID}, + Nonce: nonce, ExpiresAt: now.Add(downstreamIDTokenLifetime), IssuedAt: now, RequestedAt: now, AuthTime: now, - Nonce: nonce, }, } if groups != nil { diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index bcfc3e8f7..588eddf95 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -39,17 +39,20 @@ const ( upstreamUsernameClaim = "the-user-claim" upstreamGroupsClaim = "the-groups-claim" - happyDownstreamState = "some-downstream-state" - happyCSRF = "test-csrf" - happyPKCE = "test-pkce" - happyNonce = "test-nonce" - happyStateVersion = "1" - - downstreamIssuer = "https://my-downstream-issuer.com/path" happyUpstreamAuthcode = "upstream-auth-code" - downstreamRedirectURI = "http://127.0.0.1/callback" - downstreamClientID = "pinniped-cli" - downstreamNonce = "some-nonce-value" + + happyDownstreamState = "some-downstream-state" + happyDownstreamCSRF = "test-csrf" + happyDownstreamPKCE = "test-pkce" + happyDownstreamNonce = "test-nonce" + happyDownstreamStateVersion = "1" + + downstreamIssuer = "https://my-downstream-issuer.com/path" + downstreamRedirectURI = "http://127.0.0.1/callback" + downstreamClientID = "pinniped-cli" + downstreamNonce = "some-nonce-value" + downstreamPKCEChallenge = "some-challenge" + downstreamPKCEChallengeMethod = "S256" timeComparisonFudgeFactor = time.Second * 15 ) @@ -58,17 +61,17 @@ var ( upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} happyDownstreamScopesRequested = []string{"openid", "profile", "email"} - happyOriginalRequestParamsQuery = url.Values{ + happyDownstreamRequestParamsQuery = url.Values{ "response_type": []string{"code"}, "scope": []string{strings.Join(happyDownstreamScopesRequested, " ")}, "client_id": []string{downstreamClientID}, "state": []string{happyDownstreamState}, "nonce": []string{downstreamNonce}, - "code_challenge": []string{"some-challenge"}, - "code_challenge_method": []string{"S256"}, + "code_challenge": []string{downstreamPKCEChallenge}, + "code_challenge_method": []string{downstreamPKCEChallengeMethod}, "redirect_uri": []string{downstreamRedirectURI}, } - happyOriginalRequestParams = happyOriginalRequestParamsQuery.Encode() + happyDownstreamRequestParams = happyDownstreamRequestParamsQuery.Encode() ) func TestCallbackEndpoint(t *testing.T) { @@ -92,18 +95,18 @@ func TestCallbackEndpoint(t *testing.T) { happyState := happyUpstreamStateParam().Build(t, happyStateCodec) - encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyCSRF) + encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyDownstreamCSRF) require.NoError(t, err) happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue happyExchangeAndValidateTokensArgs := &oidctestutil.ExchangeAuthcodeAndValidateTokenArgs{ Authcode: happyUpstreamAuthcode, - PKCECodeVerifier: pkce.Code(happyPKCE), - ExpectedIDTokenNonce: nonce.Nonce(happyNonce), + PKCECodeVerifier: pkce.Code(happyDownstreamPKCE), + ExpectedIDTokenNonce: nonce.Nonce(happyDownstreamNonce), } // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it - happyRedirectLocationRegexp := downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState + happyDownstreamRedirectLocationRegexp := downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState tests := []struct { name string @@ -113,14 +116,16 @@ func TestCallbackEndpoint(t *testing.T) { path string csrfCookie string - wantStatus int - wantBody string - wantRedirectLocationRegexp string - wantGrantedOpenidScope bool - wantDownstreamIDTokenSubject string - wantDownstreamIDTokenGroups []string - wantDownstreamRequestedScopes []string - wantDownstreamNonce string + wantStatus int + wantBody string + wantRedirectLocationRegexp string + wantGrantedOpenidScope bool + wantDownstreamIDTokenSubject string + wantDownstreamIDTokenGroups []string + wantDownstreamRequestedScopes []string + wantDownstreamNonce string + wantDownstreamPKCEChallenge string + wantDownstreamPKCEChallengeMethod string wantExchangeAndValidateTokensCall *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs }{ @@ -131,13 +136,15 @@ func TestCallbackEndpoint(t *testing.T) { path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, - wantRedirectLocationRegexp: happyRedirectLocationRegexp, + wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantGrantedOpenidScope: true, wantBody: "", wantDownstreamIDTokenSubject: upstreamUsername, wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamNonce: downstreamNonce, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { @@ -147,13 +154,15 @@ func TestCallbackEndpoint(t *testing.T) { path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, - wantRedirectLocationRegexp: happyRedirectLocationRegexp, + wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantGrantedOpenidScope: true, wantBody: "", wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, wantDownstreamIDTokenGroups: nil, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamNonce: downstreamNonce, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { @@ -163,13 +172,15 @@ func TestCallbackEndpoint(t *testing.T) { path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, - wantRedirectLocationRegexp: happyRedirectLocationRegexp, + wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantGrantedOpenidScope: true, wantBody: "", wantDownstreamIDTokenSubject: upstreamSubject, wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamNonce: downstreamNonce, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, @@ -235,7 +246,7 @@ func TestCallbackEndpoint(t *testing.T) { method: http.MethodGet, path: newRequestPath().WithState( happyUpstreamStateParam(). - WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"prompt": "none login"}).Encode()). + WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"prompt": "none login"}).Encode()). Build(t, happyStateCodec), ).String(), csrfCookie: happyCSRFCookie, @@ -269,7 +280,7 @@ func TestCallbackEndpoint(t *testing.T) { method: http.MethodGet, path: newRequestPath().WithState( happyUpstreamStateParam(). - WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"client_id": ""}).Encode()). + WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"client_id": ""}).Encode()). Build(t, happyStateCodec), ).String(), csrfCookie: happyCSRFCookie, @@ -283,7 +294,7 @@ func TestCallbackEndpoint(t *testing.T) { path: newRequestPath(). WithState( happyUpstreamStateParam(). - WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"scope": "profile email"}).Encode()). + WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"scope": "profile email"}).Encode()). Build(t, happyStateCodec), ).String(), csrfCookie: happyCSRFCookie, @@ -293,6 +304,8 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamRequestedScopes: []string{"profile", "email"}, wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamNonce: downstreamNonce, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { @@ -455,80 +468,40 @@ func TestCallbackEndpoint(t *testing.T) { require.Lenf(t, submatches, 2, "no regexp match in actualLocation: %q", actualLocation) capturedAuthCode := submatches[1] - // One authcode should have been stored. - require.Len(t, oauthStore.AuthorizeCodes, 1) - // fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface authcodeDataAndSignature := strings.Split(capturedAuthCode, ".") require.Len(t, authcodeDataAndSignature, 2) - // Get the authcode session back from storage so we can require that it was stored correctly. - storedAuthorizeRequest, err := oauthStore.GetAuthorizeCodeSession(context.Background(), authcodeDataAndSignature[1], nil) - require.NoError(t, err) + storedRequestFromAuthcode, storedSessionFromAuthcode := validateAuthcodeStorage( + t, + oauthStore, + authcodeDataAndSignature[1], // Authcode store key is authcode signature + test.wantGrantedOpenidScope, + test.wantDownstreamIDTokenSubject, + test.wantDownstreamIDTokenGroups, + test.wantDownstreamRequestedScopes, + test.wantDownstreamNonce, + ) - // Check that storage returned the expected concrete data types. - storedRequest, ok := storedAuthorizeRequest.(*fosite.Request) - require.True(t, ok) - storedSession, ok := storedAuthorizeRequest.GetSession().(*openid.DefaultSession) - require.True(t, ok) + validatePKCEStorage( + t, + oauthStore, + authcodeDataAndSignature[1], // PKCE store key is authcode signature + storedRequestFromAuthcode, + storedSessionFromAuthcode, + test.wantDownstreamPKCEChallenge, + test.wantDownstreamPKCEChallengeMethod, + ) - // Check which scopes were granted. - if test.wantGrantedOpenidScope { - require.Contains(t, storedRequest.GetGrantedScopes(), "openid") - } else { - require.NotContains(t, storedRequest.GetGrantedScopes(), "openid") - } - - // Check all the other fields of the stored request. - require.NotEmpty(t, storedRequest.ID) - require.Equal(t, downstreamClientID, storedRequest.Client.GetID()) - require.ElementsMatch(t, test.wantDownstreamRequestedScopes, storedRequest.RequestedScope) - require.Nil(t, storedRequest.RequestedAudience) - require.Empty(t, storedRequest.GrantedAudience) - require.Equal(t, url.Values{"redirect_uri": []string{downstreamRedirectURI}}, storedRequest.Form) - testutil.RequireTimeInDelta(t, time.Now(), storedRequest.RequestedAt, timeComparisonFudgeFactor) - - // We're not using these fields yet, so confirm that we did not set them (for now). - require.Empty(t, storedSession.Subject) - require.Empty(t, storedSession.Username) - require.Empty(t, storedSession.Headers) - - // The authcode that we are issuing should be good for 15 minutes, which is default for fosite. - testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*15), storedSession.ExpiresAt[fosite.AuthorizeCode], timeComparisonFudgeFactor) - require.Len(t, storedSession.ExpiresAt, 1) - - // Now confirm the ID token claims. - actualClaims := storedSession.Claims - - // Check the user's identity, which are put into the downstream ID token's subject and groups claims. - require.Equal(t, test.wantDownstreamIDTokenSubject, actualClaims.Subject) - if test.wantDownstreamIDTokenGroups != nil { - require.Len(t, actualClaims.Extra, 1) - require.Equal(t, test.wantDownstreamIDTokenGroups, actualClaims.Extra["groups"]) - } else { - require.Empty(t, actualClaims.Extra) - require.NotContains(t, actualClaims.Extra, "groups") - } - - // Check the rest of the downstream ID token's claims. - require.Equal(t, downstreamIssuer, actualClaims.Issuer) - require.Equal(t, []string{downstreamClientID}, actualClaims.Audience) - require.Equal(t, test.wantDownstreamNonce, actualClaims.Nonce) - testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*5), actualClaims.ExpiresAt, timeComparisonFudgeFactor) - testutil.RequireTimeInDelta(t, time.Now(), actualClaims.IssuedAt, timeComparisonFudgeFactor) - testutil.RequireTimeInDelta(t, time.Now(), actualClaims.RequestedAt, timeComparisonFudgeFactor) - testutil.RequireTimeInDelta(t, time.Now(), actualClaims.AuthTime, timeComparisonFudgeFactor) - - // These are not needed yet. - require.Empty(t, actualClaims.JTI) - require.Empty(t, actualClaims.CodeHash) - require.Empty(t, actualClaims.AccessTokenHash) - require.Empty(t, actualClaims.AuthenticationContextClassReference) - require.Empty(t, actualClaims.AuthenticationMethodsReference) - - // TODO add thorough tests about what should be stored for PKCES and IDSessions - } else { - require.Empty(t, rsp.Header().Values("Location")) + validateIDSessionStorage( + t, + oauthStore, + capturedAuthCode, // IDSession store key is full authcode + storedRequestFromAuthcode, + storedSessionFromAuthcode, + test.wantGrantedOpenidScope, + test.wantDownstreamNonce, + ) } }) } @@ -590,11 +563,11 @@ type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat func happyUpstreamStateParam() *upstreamStateParamBuilder { return &upstreamStateParamBuilder{ - P: happyOriginalRequestParams, - N: happyNonce, - C: happyCSRF, - K: happyPKCE, - V: happyStateVersion, + P: happyDownstreamRequestParams, + N: happyDownstreamNonce, + C: happyDownstreamCSRF, + K: happyDownstreamPKCE, + V: happyDownstreamStateVersion, } } @@ -706,3 +679,151 @@ func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string } return copied } + +func validateAuthcodeStorage( + t *testing.T, + oauthStore *storage.MemoryStore, + storeKey string, + wantGrantedOpenidScope bool, + wantDownstreamIDTokenSubject string, + wantDownstreamIDTokenGroups []string, + wantDownstreamRequestedScopes []string, + wantDownstreamNonce string, +) (*fosite.Request, *openid.DefaultSession) { + t.Helper() + + // One authcode should have been stored. + require.Len(t, oauthStore.AuthorizeCodes, 1) + + // Get the authcode session back from storage so we can require that it was stored correctly. + storedAuthorizeRequestFromAuthcode, err := oauthStore.GetAuthorizeCodeSession(context.Background(), storeKey, nil) + require.NoError(t, err) + + // Check that storage returned the expected concrete data types. + storedRequestFromAuthcode, storedSessionFromAuthcode := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromAuthcode) + + // Check which scopes were granted. + if wantGrantedOpenidScope { + require.Contains(t, storedRequestFromAuthcode.GetGrantedScopes(), "openid") + } else { + require.NotContains(t, storedRequestFromAuthcode.GetGrantedScopes(), "openid") + } + + // Check all the other fields of the stored request. + require.NotEmpty(t, storedRequestFromAuthcode.ID) + require.Equal(t, downstreamClientID, storedRequestFromAuthcode.Client.GetID()) + require.ElementsMatch(t, wantDownstreamRequestedScopes, storedRequestFromAuthcode.RequestedScope) + require.Nil(t, storedRequestFromAuthcode.RequestedAudience) + require.Empty(t, storedRequestFromAuthcode.GrantedAudience) + require.Equal(t, url.Values{"redirect_uri": []string{downstreamRedirectURI}}, storedRequestFromAuthcode.Form) + testutil.RequireTimeInDelta(t, time.Now(), storedRequestFromAuthcode.RequestedAt, timeComparisonFudgeFactor) + + // We're not using these fields yet, so confirm that we did not set them (for now). + require.Empty(t, storedSessionFromAuthcode.Subject) + require.Empty(t, storedSessionFromAuthcode.Username) + require.Empty(t, storedSessionFromAuthcode.Headers) + + // The authcode that we are issuing should be good for 15 minutes, which is default for fosite. + testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*15), storedSessionFromAuthcode.ExpiresAt[fosite.AuthorizeCode], timeComparisonFudgeFactor) + require.Len(t, storedSessionFromAuthcode.ExpiresAt, 1) + + // Now confirm the ID token claims. + actualClaims := storedSessionFromAuthcode.Claims + + // Check the user's identity, which are put into the downstream ID token's subject and groups claims. + require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject) + if wantDownstreamIDTokenGroups != nil { + require.Len(t, actualClaims.Extra, 1) + require.Equal(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"]) + } else { + require.Empty(t, actualClaims.Extra) + require.NotContains(t, actualClaims.Extra, "groups") + } + + // Check the rest of the downstream ID token's claims. + require.Equal(t, downstreamIssuer, actualClaims.Issuer) + require.Equal(t, []string{downstreamClientID}, actualClaims.Audience) + require.Equal(t, wantDownstreamNonce, actualClaims.Nonce) + testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*5), actualClaims.ExpiresAt, timeComparisonFudgeFactor) + testutil.RequireTimeInDelta(t, time.Now(), actualClaims.IssuedAt, timeComparisonFudgeFactor) + testutil.RequireTimeInDelta(t, time.Now(), actualClaims.RequestedAt, timeComparisonFudgeFactor) + testutil.RequireTimeInDelta(t, time.Now(), actualClaims.AuthTime, timeComparisonFudgeFactor) + + // These are not needed yet. + require.Empty(t, actualClaims.JTI) + require.Empty(t, actualClaims.CodeHash) + require.Empty(t, actualClaims.AccessTokenHash) + require.Empty(t, actualClaims.AuthenticationContextClassReference) + require.Empty(t, actualClaims.AuthenticationMethodsReference) + + return storedRequestFromAuthcode, storedSessionFromAuthcode +} + +func validatePKCEStorage( + t *testing.T, + oauthStore *storage.MemoryStore, + storeKey string, + storedRequestFromAuthcode *fosite.Request, + storedSessionFromAuthcode *openid.DefaultSession, + wantDownstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod string, +) { + t.Helper() + + // One PKCE should have been stored. + require.Len(t, oauthStore.PKCES, 1) + storedAuthorizeRequestFromPKCE, err := oauthStore.GetPKCERequestSession(context.Background(), storeKey, nil) + require.NoError(t, err) + + // Check that storage returned the expected concrete data types. + storedRequestFromPKCE, storedSessionFromPKCE := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromPKCE) + + // The stored PKCE request should be the same as the stored authcode request. + require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromPKCE.ID) + require.Equal(t, storedSessionFromAuthcode, storedSessionFromPKCE) + + // The stored PKCE request should also contain the PKCE challenge that the downstream sent us. + require.Equal(t, wantDownstreamPKCEChallenge, storedRequestFromPKCE.Form.Get("code_challenge")) + require.Equal(t, wantDownstreamPKCEChallengeMethod, storedRequestFromPKCE.Form.Get("code_challenge_method")) +} + +func validateIDSessionStorage( + t *testing.T, + oauthStore *storage.MemoryStore, + storeKey string, + storedRequestFromAuthcode *fosite.Request, + storedSessionFromAuthcode *openid.DefaultSession, + wantGrantedOpenidScope bool, + wantDownstreamNonce string, +) { + t.Helper() + + // One IDSession should have been stored, if the downstream actually requested the "openid" scope.. + if wantGrantedOpenidScope { + require.Len(t, oauthStore.IDSessions, 1) + storedAuthorizeRequestFromIDSession, err := oauthStore.GetOpenIDConnectSession(context.Background(), storeKey, nil) + require.NoError(t, err) + + // Check that storage returned the expected concrete data types. + storedRequestFromIDSession, storedSessionFromIDSession := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromIDSession) + + // The stored IDSession request should be the same as the stored authcode request. + require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromIDSession.ID) + require.Equal(t, storedSessionFromAuthcode, storedSessionFromIDSession) + + // The stored IDSession request should also contain the nonce that the downstream sent us. + require.Equal(t, wantDownstreamNonce, storedRequestFromIDSession.Form.Get("nonce")) + } else { + require.Len(t, oauthStore.IDSessions, 0) + } +} + +func castStoredAuthorizeRequest(t *testing.T, storedAuthorizeRequest fosite.Requester) (*fosite.Request, *openid.DefaultSession) { + t.Helper() + + storedRequest, ok := storedAuthorizeRequest.(*fosite.Request) + require.Truef(t, ok, "could not cast %T to %T", storedAuthorizeRequest, &fosite.Request{}) + storedSession, ok := storedAuthorizeRequest.GetSession().(*openid.DefaultSession) + require.Truef(t, ok, "could not cast %T to %T", storedAuthorizeRequest.GetSession(), &openid.DefaultSession{}) + + return storedRequest, storedSession +} From 488d1b663a7a8de85c91312b85f704ca311fc745 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Fri, 20 Nov 2020 10:42:43 -0500 Subject: [PATCH 21/57] internal/oidc/provider/manager: route to callback endpoint Signed-off-by: Andrew Keesler --- internal/oidc/oidc.go | 1 + internal/oidc/provider/manager/manager.go | 4 +++ .../oidc/provider/manager/manager_test.go | 28 +++++++++++++++++-- 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 4da9cfcdf..8b73e6626 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -18,6 +18,7 @@ const ( WellKnownEndpointPath = "/.well-known/openid-configuration" AuthorizationEndpointPath = "/oauth2/authorize" TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential + CallbackEndpointPath = "/callback" JWKSEndpointPath = "/jwks.json" ) diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index f009693d0..deb1cfc4a 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -12,6 +12,7 @@ import ( "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/auth" + "go.pinniped.dev/internal/oidc/callback" "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/jwks" @@ -84,6 +85,9 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder) + callbackURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.CallbackEndpointPath + m.providerHandlers[callbackURL] = callback.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, encoder, encoder) + plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) } } diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index 86137abd7..fdea39d5f 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -42,6 +42,7 @@ func TestManager(t *testing.T) { issuer2DifferentCaseHostname = "https://exAmPlE.Com/some/path/more/deeply/nested/path" issuer2KeyID = "issuer2-key" upstreamIDPAuthorizationURL = "https://test-upstream.com/auth" + downstreamRedirectURL = "http://127.0.0.1:12345/callback" ) newGetRequest := func(url string) *http.Request { @@ -82,6 +83,18 @@ func TestManager(t *testing.T) { ) } + requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix string) { + recorder := httptest.NewRecorder() + + subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.CallbackEndpointPath+requestURLSuffix)) + + r.False(fallbackHandlerWasCalled) + + // Minimal check to ensure that the right endpoint was called - when we don't send a CSRF + // cookie to the callback endpoint, the callback endpoint responds with a 403. + r.Equal(http.StatusForbidden, recorder.Code) + } + requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) { recorder := httptest.NewRecorder() @@ -162,7 +175,6 @@ func TestManager(t *testing.T) { requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2KeyID) requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2KeyID) - authRedirectURI := "http://127.0.0.1/callback" authRequestParams := "?" + url.Values{ "response_type": []string{"code"}, "scope": []string{"openid profile email"}, @@ -171,7 +183,7 @@ func TestManager(t *testing.T) { "nonce": []string{"some-nonce-value"}, "code_challenge": []string{"some-challenge"}, "code_challenge_method": []string{"S256"}, - "redirect_uri": []string{authRedirectURI}, + "redirect_uri": []string{downstreamRedirectURL}, }.Encode() requireAuthorizationRequestToBeHandled(issuer1, authRequestParams, upstreamIDPAuthorizationURL) @@ -180,6 +192,18 @@ func TestManager(t *testing.T) { // Hostnames are case-insensitive, so test that we can handle that. requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) + + callbackRequestParams := "?" + url.Values{ + "code": []string{"some-code"}, + "state": []string{"some-state-value"}, + }.Encode() + + requireCallbackRequestToBeHandled(issuer1, callbackRequestParams) + requireCallbackRequestToBeHandled(issuer2, callbackRequestParams) + + // // Hostnames are case-insensitive, so test that we can handle that. + requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams) + requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams) } when("given some valid providers via SetProviders()", func() { From 541019eb9888bab1a332c7b5a40093517c502de7 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Fri, 20 Nov 2020 15:36:51 -0500 Subject: [PATCH 22/57] callback_handler.go: simplify stored ID token claims Fosite is gonna set these fields for us. Signed-off-by: Ryan Richard --- internal/oidc/auth/auth_handler_test.go | 2 +- internal/oidc/callback/callback_handler.go | 21 ++----------- .../oidc/callback/callback_handler_test.go | 30 +++++++++++-------- internal/oidc/oidc.go | 3 +- internal/oidc/provider/manager/manager.go | 4 +-- 5 files changed, 26 insertions(+), 34 deletions(-) diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 6ff4ce619..4003f9c27 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -125,7 +125,7 @@ func TestAuthorizationEndpoint(t *testing.T) { oauthStore := oidc.NullStorage{} hmacSecret := []byte("some secret - must have at least 32 bytes") require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") - oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) + oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret) happyCSRF := "test-csrf" happyPKCE := "test-pkce" diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 24780869c..5b725623b 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -36,13 +36,9 @@ const ( // downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token // information. downstreamGroupsClaim = "groups" - - // The lifetime of an issued downstream ID token. - downstreamIDTokenLifetime = time.Minute * 5 ) func NewHandler( - downstreamIssuer string, idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, stateDecoder, cookieDecoder oidc.Decoder, @@ -97,13 +93,7 @@ func NewHandler( return err } - openIDSession := makeDownstreamSession( - downstreamIssuer, - downstreamAuthParams.Get("client_id"), - downstreamAuthParams.Get("nonce"), - username, - groups, - ) + openIDSession := makeDownstreamSession(username, groups) authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession) if err != nil { plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName()) @@ -297,16 +287,11 @@ func getGroupsFromUpstreamIDToken( return groups, nil } -func makeDownstreamSession(issuer, clientID, nonce, username string, groups []string) *openid.DefaultSession { - now := time.Now() +func makeDownstreamSession(username string, groups []string) *openid.DefaultSession { + now := time.Now().UTC() openIDSession := &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ - Issuer: issuer, Subject: username, - Audience: []string{clientID}, - Nonce: nonce, - ExpiresAt: now.Add(downstreamIDTokenLifetime), - IssuedAt: now, RequestedAt: now, AuthTime: now, }, diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 588eddf95..5f1d16ba3 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -430,10 +430,10 @@ func TestCallbackEndpoint(t *testing.T) { } hmacSecret := []byte("some secret - must have at least 32 bytes") require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") - oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) + oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret) idpListGetter := oidctestutil.NewIDPListGetter(&test.idp) - subject := NewHandler(downstreamIssuer, idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) + subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) req := httptest.NewRequest(test.method, test.path, nil) if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) @@ -480,7 +480,6 @@ func TestCallbackEndpoint(t *testing.T) { test.wantDownstreamIDTokenSubject, test.wantDownstreamIDTokenGroups, test.wantDownstreamRequestedScopes, - test.wantDownstreamNonce, ) validatePKCEStorage( @@ -688,7 +687,6 @@ func validateAuthcodeStorage( wantDownstreamIDTokenSubject string, wantDownstreamIDTokenGroups []string, wantDownstreamRequestedScopes []string, - wantDownstreamNonce string, ) (*fosite.Request, *openid.DefaultSession) { t.Helper() @@ -740,14 +738,22 @@ func validateAuthcodeStorage( require.NotContains(t, actualClaims.Extra, "groups") } - // Check the rest of the downstream ID token's claims. - require.Equal(t, downstreamIssuer, actualClaims.Issuer) - require.Equal(t, []string{downstreamClientID}, actualClaims.Audience) - require.Equal(t, wantDownstreamNonce, actualClaims.Nonce) - testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*5), actualClaims.ExpiresAt, timeComparisonFudgeFactor) - testutil.RequireTimeInDelta(t, time.Now(), actualClaims.IssuedAt, timeComparisonFudgeFactor) - testutil.RequireTimeInDelta(t, time.Now(), actualClaims.RequestedAt, timeComparisonFudgeFactor) - testutil.RequireTimeInDelta(t, time.Now(), actualClaims.AuthTime, timeComparisonFudgeFactor) + // Check the rest of the downstream ID token's claims. Fosite wants us to set these (in UTC time). + testutil.RequireTimeInDelta(t, time.Now().UTC(), actualClaims.RequestedAt, timeComparisonFudgeFactor) + testutil.RequireTimeInDelta(t, time.Now().UTC(), actualClaims.AuthTime, timeComparisonFudgeFactor) + requestedAtZone, _ := actualClaims.RequestedAt.Zone() + require.Equal(t, "UTC", requestedAtZone) + authTimeZone, _ := actualClaims.AuthTime.Zone() + require.Equal(t, "UTC", authTimeZone) + + // Fosite will set these fields for us in the token endpoint based on the store session + // information. Therefore, we assert that they are empty because we want the library to do the + // lifting for us. + require.Empty(t, actualClaims.Issuer) + require.Nil(t, actualClaims.Audience) + require.Empty(t, actualClaims.Nonce) + require.Zero(t, actualClaims.ExpiresAt) + require.Zero(t, actualClaims.IssuedAt) // These are not needed yet. require.Empty(t, actualClaims.JTI) diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 8b73e6626..017b92499 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -85,9 +85,10 @@ func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient { } } -func FositeOauth2Helper(oauthStore interface{}, hmacSecretOfLengthAtLeast32 []byte) fosite.OAuth2Provider { +func FositeOauth2Helper(oauthStore interface{}, issuer string, hmacSecretOfLengthAtLeast32 []byte) fosite.OAuth2Provider { oauthConfig := &compose.Config{ EnforcePKCEForPublicClients: true, + IDTokenIssuer: issuer, } return compose.Compose( diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index deb1cfc4a..7c6403b97 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -71,7 +71,7 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { // Use NullStorage for the authorize endpoint because we do not actually want to store anything until // the upstream callback endpoint is called later. - oauthHelper := oidc.FositeOauth2Helper(oidc.NullStorage{}, []byte("some secret - must have at least 32 bytes")) // TODO replace this secret + oauthHelper := oidc.FositeOauth2Helper(oidc.NullStorage{}, incomingProvider.Issuer(), []byte("some secret - must have at least 32 bytes")) // TODO replace this secret // TODO use different codecs for the state and the cookie, because: // 1. we would like to state to have an embedded expiration date while the cookie does not need that @@ -86,7 +86,7 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder) callbackURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.CallbackEndpointPath - m.providerHandlers[callbackURL] = callback.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, encoder, encoder) + m.providerHandlers[callbackURL] = callback.NewHandler(m.idpListGetter, oauthHelper, encoder, encoder) plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) } From 72321fc106e5673b3b52496b5747005fa97dbb09 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Fri, 20 Nov 2020 16:14:45 -0500 Subject: [PATCH 23/57] Use /callback (without IDP name) path for callback endpoint (part 1) This is much nicer UX for an administrator installing a UpstreamOIDCProvider CRD. They don't have to guess as hard at what the callback endpoint path should be for their UpstreamOIDCProvider. Signed-off-by: Andrew Keesler --- internal/oidc/auth/auth_handler.go | 13 +++++++++++-- internal/oidc/auth/auth_handler_test.go | 23 ++++++++++++++--------- internal/oidc/oidc.go | 1 + internal/oidc/oidctestutil/oidc.go | 1 + 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index 08c42ada8..8a566620d 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -92,11 +92,18 @@ func NewHandler( Endpoint: oauth2.Endpoint{ AuthURL: upstreamIDP.GetAuthorizationURL().String(), }, - RedirectURL: fmt.Sprintf("%s/callback/%s", downstreamIssuer, upstreamIDP.GetName()), + RedirectURL: fmt.Sprintf("%s/callback", downstreamIssuer), Scopes: upstreamIDP.GetScopes(), } - encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, upstreamStateEncoder) + encodedStateParamValue, err := upstreamStateParam( + authorizeRequester, + upstreamIDP.GetName(), + nonceValue, + csrfValue, + pkceValue, + upstreamStateEncoder, + ) if err != nil { plog.Error("authorize upstream state param error", err) return err @@ -188,6 +195,7 @@ func generateValues( func upstreamStateParam( authorizeRequester fosite.AuthorizeRequester, + upstreamName string, nonceValue nonce.Nonce, csrfValue csrftoken.CSRFToken, pkceValue pkce.Code, @@ -195,6 +203,7 @@ func upstreamStateParam( ) (string, error) { stateParamData := oidc.UpstreamStateParamData{ AuthParams: authorizeRequester.GetRequestForm().Encode(), + UpstreamName: upstreamName, Nonce: nonceValue, CSRFToken: csrfValue, PKCECode: pkceValue, diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 4003f9c27..381ff052c 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -204,14 +204,19 @@ func TestAuthorizationEndpoint(t *testing.T) { return pathWithQuery("/some/path", modifiedHappyGetRequestQueryMap(queryOverrides)) } - expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride string) string { + expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride, upstreamNameOverride string) string { csrf := happyCSRF if csrfValueOverride != "" { csrf = csrfValueOverride } + upstreamName := upstreamOIDCIdentityProvider.Name + if upstreamNameOverride != "" { + upstreamName = upstreamNameOverride + } encoded, err := happyStateEncoder.Encode("s", oidctestutil.ExpectedUpstreamStateParamFormat{ P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)), + U: upstreamName, N: happyNonce, C: csrf, K: happyPKCE, @@ -232,7 +237,7 @@ func TestAuthorizationEndpoint(t *testing.T) { "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", - "redirect_uri": downstreamIssuer + "/callback/some-idp", + "redirect_uri": downstreamIssuer + "/callback", }) } @@ -281,7 +286,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantStatus: http.StatusFound, wantContentType: "text/html; charset=utf-8", wantCSRFValueInCookieHeader: happyCSRF, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -299,7 +304,7 @@ func TestAuthorizationEndpoint(t *testing.T) { csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue, wantStatus: http.StatusFound, wantContentType: "text/html; charset=utf-8", - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue)), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue, "")), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -320,7 +325,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantContentType: "", wantBodyString: "", wantCSRFValueInCookieHeader: happyCSRF, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), wantUpstreamStateParamInLocationHeader: true, }, { @@ -341,7 +346,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantCSRFValueInCookieHeader: happyCSRF, wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{ "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client - }, "")), + }, "", "")), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -538,7 +543,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantContentType: "text/html; charset=utf-8", wantCSRFValueInCookieHeader: happyCSRF, wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam( - map[string]string{"prompt": "none login", "scope": "email"}, "", + map[string]string{"prompt": "none login", "scope": "email"}, "", "", )), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, @@ -787,11 +792,11 @@ func TestAuthorizationEndpoint(t *testing.T) { "access_type": "offline", "scope": "other-scope1 other-scope2", "client_id": "some-other-client-id", - "state": expectedUpstreamStateParam(nil, ""), + "state": expectedUpstreamStateParam(nil, "", newProviderSettings.Name), "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", - "redirect_uri": downstreamIssuer + "/callback/some-other-idp", + "redirect_uri": downstreamIssuer + "/callback", }, ) test.wantBodyString = fmt.Sprintf(`Found.%s`, diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 017b92499..ee69ee13d 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -66,6 +66,7 @@ type Codec interface { // the state param. type UpstreamStateParamData struct { AuthParams string `json:"p"` + UpstreamName string `json:"u"` Nonce nonce.Nonce `json:"n"` CSRFToken csrftoken.CSRFToken `json:"c"` PKCECode pkce.Code `json:"k"` diff --git a/internal/oidc/oidctestutil/oidc.go b/internal/oidc/oidctestutil/oidc.go index ad4338db7..fc8c30924 100644 --- a/internal/oidc/oidctestutil/oidc.go +++ b/internal/oidc/oidctestutil/oidc.go @@ -112,6 +112,7 @@ func NewIDPListGetter(upstreamOIDCIdentityProviders ...*TestUpstreamOIDCIdentity // assertions about the redirect URL in this test. type ExpectedUpstreamStateParamFormat struct { P string `json:"p"` + U string `json:"u"` N string `json:"n"` C string `json:"c"` K string `json:"k"` From b21f0035d7fd710539788f01b2727014941d7784 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Fri, 20 Nov 2020 13:33:08 -0800 Subject: [PATCH 24/57] callback_handler.go: Get upstream name from state instead of path Also use ConstantTimeCompare() to compare CSRF tokens to prevent leaking any information in how quickly we reject bad tokens. Signed-off-by: Ryan Richard --- internal/oidc/callback/callback_handler.go | 11 +++++------ internal/oidc/callback/callback_handler_test.go | 17 +++++------------ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 5b725623b..f237726b6 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -5,10 +5,10 @@ package callback import ( + "crypto/subtle" "fmt" "net/http" "net/url" - "path" "time" "github.com/ory/fosite" @@ -49,7 +49,7 @@ func NewHandler( return err } - upstreamIDPConfig := findUpstreamIDPConfig(r, idpListGetter) + upstreamIDPConfig := findUpstreamIDPConfig(state.UpstreamName, idpListGetter) if upstreamIDPConfig == nil { plog.Warning("upstream provider not found") return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") @@ -137,7 +137,7 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) return nil, err } - if state.CSRFToken != csrfValue { + if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 { plog.InfoErr("CSRF value does not match", err) return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) } @@ -145,10 +145,9 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) return state, nil } -func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI { - _, lastPathComponent := path.Split(r.URL.Path) +func findUpstreamIDPConfig(upstreamName string, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI { for _, p := range idpListGetter.GetIDPList() { - if p.GetName() == lastPathComponent { + if p.GetName() == upstreamName { return p } } diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 5f1d16ba3..c51ad86a0 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -6,7 +6,6 @@ package callback import ( "context" "errors" - "fmt" "net/http" "net/http/httptest" "net/url" @@ -507,25 +506,18 @@ func TestCallbackEndpoint(t *testing.T) { } type requestPath struct { - upstreamIDPName, code, state *string + code, state *string } func newRequestPath() *requestPath { - n := happyUpstreamIDPName c := happyUpstreamAuthcode s := "4321" return &requestPath{ - upstreamIDPName: &n, - code: &c, - state: &s, + code: &c, + state: &s, } } -func (r *requestPath) WithUpstreamIDPName(name string) *requestPath { - r.upstreamIDPName = &name - return r -} - func (r *requestPath) WithCode(code string) *requestPath { r.code = &code return r @@ -547,7 +539,7 @@ func (r *requestPath) WithoutState() *requestPath { } func (r *requestPath) String() string { - path := fmt.Sprintf("/downstream-provider-name/callback/%s?", *r.upstreamIDPName) + path := "/downstream-provider-name/callback?" params := url.Values{} if r.code != nil { params.Add("code", *r.code) @@ -562,6 +554,7 @@ type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat func happyUpstreamStateParam() *upstreamStateParamBuilder { return &upstreamStateParamBuilder{ + U: happyUpstreamIDPName, P: happyDownstreamRequestParams, N: happyDownstreamNonce, C: happyDownstreamCSRF, From c4ff1ca30448553a4c8c716fb2595048629ca4bc Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Fri, 20 Nov 2020 13:56:35 -0800 Subject: [PATCH 25/57] auth_handler.go: Ignore invalid CSRF cookies rather than return error Generate a new cookie for the user and move on as if they had not sent a bad cookie. Hopefully this will make the user experience better if, for example, the server rotated cookie signing keys and then a user submitted a very old cookie. Signed-off-by: Andrew Keesler --- internal/oidc/auth/auth_handler.go | 17 ++++---- internal/oidc/auth/auth_handler_test.go | 54 ++++++++++++++++--------- 2 files changed, 42 insertions(+), 29 deletions(-) diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index 8a566620d..5634fc01c 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -41,11 +41,7 @@ func NewHandler( return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET or POST)", r.Method) } - csrfFromCookie, err := readCSRFCookie(r, cookieCodec) - if err != nil { - plog.InfoErr("error reading CSRF cookie", err) - return err - } + csrfFromCookie := readCSRFCookie(r, cookieCodec) authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), r) if err != nil { @@ -133,20 +129,23 @@ func NewHandler( }) } -func readCSRFCookie(r *http.Request, codec oidc.Codec) (csrftoken.CSRFToken, error) { +func readCSRFCookie(r *http.Request, codec oidc.Codec) csrftoken.CSRFToken { receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName) if err != nil { // Error means that the cookie was not found - return "", nil + return "" } var csrfFromCookie csrftoken.CSRFToken err = codec.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) if err != nil { - return "", httperr.Wrap(http.StatusUnprocessableEntity, "error reading CSRF cookie", err) + // We can ignore any errors and just make a new cookie. Hopefully this will + // make the user experience better if, for example, the server rotated + // cookie signing keys and then a user submitted a very old cookie. + return "" } - return csrfFromCookie, nil + return csrfFromCookie } func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 381ff052c..519f5b7ca 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -328,6 +328,26 @@ func TestAuthorizationEndpoint(t *testing.T) { wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), wantUpstreamStateParamInLocationHeader: true, }, + { + name: "error while decoding CSRF cookie just generates a new cookie and succeeds as usual", + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + generateCSRF: happyCSRFGenerator, + generatePKCE: happyPKCEGenerator, + generateNonce: happyNonceGenerator, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, + method: http.MethodGet, + path: happyGetRequestPath, + csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", + wantStatus: http.StatusFound, + wantContentType: "text/html; charset=utf-8", + // Generated a new CSRF cookie and set it in the response. + wantCSRFValueInCookieHeader: happyCSRF, + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), + wantUpstreamStateParamInLocationHeader: true, + wantBodyStringWithLocationInHref: true, + }, { name: "happy path when downstream redirect uri matches what is configured for client except for the port number", issuer: downstreamIssuer, @@ -639,22 +659,6 @@ func TestAuthorizationEndpoint(t *testing.T) { wantContentType: "text/plain; charset=utf-8", wantBodyString: "Internal Server Error: error generating PKCE param\n", }, - { - name: "error while decoding CSRF cookie", - issuer: downstreamIssuer, - idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), - generateCSRF: happyCSRFGenerator, - generatePKCE: happyPKCEGenerator, - generateNonce: happyNonceGenerator, - stateEncoder: happyStateEncoder, - cookieEncoder: happyCookieEncoder, - method: http.MethodGet, - path: happyGetRequestPath, - csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", - wantStatus: http.StatusUnprocessableEntity, - wantContentType: "text/plain; charset=utf-8", - wantBodyString: "Unprocessable Entity: error reading CSRF cookie\n", - }, { name: "no upstream providers are configured", issuer: downstreamIssuer, @@ -864,10 +868,20 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignore require.NoError(t, err) expectedLocationURL, err := url.Parse(expectedURL) require.NoError(t, err) - require.Equal(t, expectedLocationURL.Scheme, actualLocationURL.Scheme) - require.Equal(t, expectedLocationURL.User, actualLocationURL.User) - require.Equal(t, expectedLocationURL.Host, actualLocationURL.Host) - require.Equal(t, expectedLocationURL.Path, actualLocationURL.Path) + require.Equal(t, expectedLocationURL.Scheme, actualLocationURL.Scheme, + "schemes were not equal: expected %s but got %s", expectedURL, actualURL, + ) + require.Equal(t, expectedLocationURL.User, actualLocationURL.User, + "users were not equal: expected %s but got %s", expectedURL, actualURL, + ) + + require.Equal(t, expectedLocationURL.Host, actualLocationURL.Host, + "hosts were not equal: expected %s but got %s", expectedURL, actualURL, + ) + + require.Equal(t, expectedLocationURL.Path, actualLocationURL.Path, + "paths were not equal: expected %s but got %s", expectedURL, actualURL, + ) expectedLocationQuery := expectedLocationURL.Query() actualLocationQuery := actualLocationURL.Query() From 58a3e35c511856aed5478125d48091a340c7ba72 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Mon, 30 Nov 2020 11:07:25 -0500 Subject: [PATCH 26/57] Revert "test/integration: skip TestSupervisorLogin until new callback logic is on main" This reverts commit eae6d355f8ef35ea446839af76bed7a96b072fbb. We have added the new callback path logic (see b21f003), so we can stop skipping this test. --- test/integration/supervisor_login_test.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index ca5c2787e..0ce937fbb 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -28,8 +28,6 @@ import ( ) func TestSupervisorLogin(t *testing.T) { - t.Skip("waiting on new callback path logic to get merged in from the callback endpoint work") - env := library.IntegrationEnv(t) client := library.NewSupervisorClientset(t) @@ -117,7 +115,6 @@ func TestSupervisorLogin(t *testing.T) { } } -//nolint:unused func getDownstreamIssuerPathFromUpstreamRedirectURI(t *testing.T, upstreamRedirectURI string) string { // We need to construct the downstream issuer path from the upstream redirect URI since the two // are related, and the upstream redirect URI is supplied via a static test environment @@ -145,7 +142,6 @@ func getDownstreamIssuerPathFromUpstreamRedirectURI(t *testing.T, upstreamRedire return redirectURIPathWithoutLastSegment } -//nolint:unused func makeDownstreamAuthURL(t *testing.T, scheme, addr, path string) string { t.Helper() downstreamOAuth2Config := oauth2.Config{ @@ -167,7 +163,6 @@ func makeDownstreamAuthURL(t *testing.T, scheme, addr, path string) string { ) } -//nolint:unused func generateAuthRequestParams(t *testing.T) (state.State, nonce.Nonce, pkce.Code) { t.Helper() state, err := state.Generate() @@ -179,7 +174,6 @@ func generateAuthRequestParams(t *testing.T) (state.State, nonce.Nonce, pkce.Cod return state, nonce, pkce } -//nolint:unused func requireValidRedirectLocation( ctx context.Context, t *testing.T, From d64acbb5a9e20a0100f4f1c3153285f956bb4be8 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Mon, 30 Nov 2020 14:54:11 -0600 Subject: [PATCH 27/57] Add upstreamoidc.ProviderConfig type implementing provider.UpstreamOIDCIdentityProviderI. Signed-off-by: Matt Moyer --- .../upstreamwatcher/upstreamwatcher.go | 25 +- .../upstreamwatcher/upstreamwatcher_test.go | 71 ++++-- .../provider/dynamic_upstream_idp_provider.go | 50 +--- internal/upstreamoidc/upstreamoidc.go | 106 +++++++++ internal/upstreamoidc/upstreamoidc_test.go | 223 ++++++++++++++++++ 5 files changed, 391 insertions(+), 84 deletions(-) create mode 100644 internal/upstreamoidc/upstreamoidc.go create mode 100644 internal/upstreamoidc/upstreamoidc_test.go diff --git a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go index 9c8952e9b..4955646b1 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go @@ -17,6 +17,7 @@ import ( "github.com/coreos/go-oidc" "github.com/go-logr/logr" + "golang.org/x/oauth2" "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" @@ -30,6 +31,7 @@ import ( pinnipedcontroller "go.pinniped.dev/internal/controller" "go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/upstreamoidc" ) const ( @@ -150,10 +152,14 @@ func (c *controller) Sync(ctx controllerlib.Context) error { // validateUpstream validates the provided v1alpha1.UpstreamOIDCProvider and returns the validated configuration as a // provider.UpstreamOIDCIdentityProvider. As a side effect, it also updates the status of the v1alpha1.UpstreamOIDCProvider. -func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alpha1.UpstreamOIDCProvider) *provider.UpstreamOIDCIdentityProvider { - result := provider.UpstreamOIDCIdentityProvider{ - Name: upstream.Name, - Scopes: computeScopes(upstream.Spec.AuthorizationConfig.AdditionalScopes), +func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alpha1.UpstreamOIDCProvider) *upstreamoidc.ProviderConfig { + result := upstreamoidc.ProviderConfig{ + Name: upstream.Name, + Config: &oauth2.Config{ + Scopes: computeScopes(upstream.Spec.AuthorizationConfig.AdditionalScopes), + }, + UsernameClaim: upstream.Spec.Claims.Username, + GroupsClaim: upstream.Spec.Claims.Groups, } conditions := []*v1alpha1.Condition{ c.validateSecret(upstream, &result), @@ -180,7 +186,7 @@ func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alp } // validateSecret validates the .spec.client.secretName field and returns the appropriate ClientCredentialsValid condition. -func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, result *provider.UpstreamOIDCIdentityProvider) *v1alpha1.Condition { +func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition { secretName := upstream.Spec.Client.SecretName // Fetch the Secret from informer cache. @@ -217,7 +223,7 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res } // If everything is valid, update the result and set the condition to true. - result.ClientID = string(clientID) + result.Config.ClientID = string(clientID) return &v1alpha1.Condition{ Type: typeClientCredsValid, Status: v1alpha1.ConditionTrue, @@ -227,7 +233,7 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res } // validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition. -func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *provider.UpstreamOIDCIdentityProvider) *v1alpha1.Condition { +func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition { // Get the provider (from cache if possible). discoveredProvider := c.validatorCache.getProvider(&upstream.Spec) @@ -258,8 +264,6 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst c.validatorCache.putProvider(&upstream.Spec, discoveredProvider) } - // TODO also parse the token endpoint from the discovery info and put it onto the `result` - // Parse out and validate the discovered authorize endpoint. authURL, err := url.Parse(discoveredProvider.Endpoint().AuthURL) if err != nil { @@ -280,7 +284,8 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst } // If everything is valid, update the result and set the condition to true. - result.AuthorizationURL = *authURL + result.Config.Endpoint = discoveredProvider.Endpoint() + result.Provider = discoveredProvider return &v1alpha1.Condition{ Type: typeOIDCDiscoverySucceeded, Status: v1alpha1.ConditionTrue, diff --git a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go index 4891ae8a3..3ecfa91ac 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go @@ -24,9 +24,11 @@ import ( pinnipedfake "go.pinniped.dev/generated/1.19/client/supervisor/clientset/versioned/fake" pinnipedinformers "go.pinniped.dev/generated/1.19/client/supervisor/informers/externalversions" "go.pinniped.dev/internal/controllerlib" + "go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/testlogger" + "go.pinniped.dev/internal/upstreamoidc" ) func TestController(t *testing.T) { @@ -49,6 +51,8 @@ func TestController(t *testing.T) { testClientID = "test-oidc-client-id" testClientSecret = "test-oidc-client-secret" testValidSecretData = map[string][]byte{"clientID": []byte(testClientID), "clientSecret": []byte(testClientSecret)} + testGroupsClaim = "test-groups-claim" + testUsernameClaim = "test-username-claim" ) tests := []struct { name string @@ -56,7 +60,7 @@ func TestController(t *testing.T) { inputSecrets []runtime.Object wantErr string wantLogs []string - wantResultingCache []provider.UpstreamOIDCIdentityProvider + wantResultingCache []provider.UpstreamOIDCIdentityProviderI wantResultingUpstreams []v1alpha1.UpstreamOIDCProvider }{ { @@ -80,7 +84,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="secret \"test-client-secret\" not found" "name"="test-name" "namespace"="test-namespace" "reason"="SecretNotFound" "type"="ClientCredentialsValid"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -126,7 +130,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="referenced Secret \"test-client-secret\" has wrong type \"some-other-type\" (should be \"secrets.pinniped.dev/oidc-client\")" "name"="test-name" "namespace"="test-namespace" "reason"="SecretWrongType" "type"="ClientCredentialsValid"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -171,7 +175,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="referenced Secret \"test-client-secret\" is missing required keys [\"clientID\" \"clientSecret\"]" "name"="test-name" "namespace"="test-namespace" "reason"="SecretMissingKeys" "type"="ClientCredentialsValid"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -219,7 +223,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="spec.certificateAuthorityData is invalid: illegal base64 data at input byte 7" "reason"="InvalidTLSConfig" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="spec.certificateAuthorityData is invalid: illegal base64 data at input byte 7" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidTLSConfig" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -267,7 +271,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="spec.certificateAuthorityData is invalid: no certificates found" "reason"="InvalidTLSConfig" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="spec.certificateAuthorityData is invalid: no certificates found" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidTLSConfig" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -312,7 +316,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="failed to perform OIDC discovery against \"invalid-url\"" "reason"="Unreachable" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="failed to perform OIDC discovery against \"invalid-url\"" "name"="test-name" "namespace"="test-namespace" "reason"="Unreachable" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -358,7 +362,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="failed to parse authorization endpoint URL: parse \"%\": invalid URL escape \"%\"" "reason"="InvalidResponse" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="failed to parse authorization endpoint URL: parse \"%\": invalid URL escape \"%\"" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidResponse" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -404,7 +408,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="authorization endpoint URL scheme must be \"https\", not \"http\"" "reason"="InvalidResponse" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="authorization endpoint URL scheme must be \"https\", not \"http\"" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidResponse" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -437,6 +441,7 @@ func TestController(t *testing.T) { TLS: &v1alpha1.TLSSpec{CertificateAuthorityData: testIssuerCABase64}, Client: v1alpha1.OIDCClient{SecretName: testSecretName}, AuthorizationConfig: v1alpha1.OIDCAuthorizationConfig{AdditionalScopes: append(testAdditionalScopes, "xyz", "openid")}, + Claims: v1alpha1.OIDCClaims{Groups: testGroupsClaim, Username: testUsernameClaim}, }, Status: v1alpha1.UpstreamOIDCProviderStatus{ Phase: "Error", @@ -455,12 +460,16 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="loaded client credentials" "reason"="Success" "status"="True" "type"="ClientCredentialsValid"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{{ - Name: testName, - ClientID: testClientID, - AuthorizationURL: *testIssuerAuthorizeURL, - Scopes: append(testExpectedScopes, "xyz"), - }}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{ + &oidctestutil.TestUpstreamOIDCIdentityProvider{ + Name: testName, + ClientID: testClientID, + AuthorizationURL: *testIssuerAuthorizeURL, + Scopes: append(testExpectedScopes, "xyz"), + UsernameClaim: testUsernameClaim, + GroupsClaim: testGroupsClaim, + }, + }, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -481,6 +490,7 @@ func TestController(t *testing.T) { TLS: &v1alpha1.TLSSpec{CertificateAuthorityData: testIssuerCABase64}, Client: v1alpha1.OIDCClient{SecretName: testSecretName}, AuthorizationConfig: v1alpha1.OIDCAuthorizationConfig{AdditionalScopes: testAdditionalScopes}, + Claims: v1alpha1.OIDCClaims{Groups: testGroupsClaim, Username: testUsernameClaim}, }, Status: v1alpha1.UpstreamOIDCProviderStatus{ Phase: "Ready", @@ -499,12 +509,16 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="loaded client credentials" "reason"="Success" "status"="True" "type"="ClientCredentialsValid"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{{ - Name: testName, - ClientID: testClientID, - AuthorizationURL: *testIssuerAuthorizeURL, - Scopes: testExpectedScopes, - }}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{ + &oidctestutil.TestUpstreamOIDCIdentityProvider{ + Name: testName, + ClientID: testClientID, + AuthorizationURL: *testIssuerAuthorizeURL, + Scopes: testExpectedScopes, + UsernameClaim: testUsernameClaim, + GroupsClaim: testGroupsClaim, + }, + }, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -527,9 +541,9 @@ func TestController(t *testing.T) { kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0) testLog := testlogger.New(t) cache := provider.NewDynamicUpstreamIDPProvider() - initialProviderList := make([]provider.UpstreamOIDCIdentityProviderI, 1) - initialProviderList[0] = &provider.UpstreamOIDCIdentityProvider{Name: "initial-entry"} - cache.SetIDPList(initialProviderList) + cache.SetIDPList([]provider.UpstreamOIDCIdentityProviderI{ + &upstreamoidc.ProviderConfig{Name: "initial-entry"}, + }) controller := New( cache, @@ -557,8 +571,13 @@ func TestController(t *testing.T) { actualIDPList := cache.GetIDPList() require.Equal(t, len(tt.wantResultingCache), len(actualIDPList)) for i := range actualIDPList { - actualIDP := actualIDPList[i].(*provider.UpstreamOIDCIdentityProvider) - require.Equal(t, tt.wantResultingCache[i], *actualIDP) + actualIDP := actualIDPList[i].(*upstreamoidc.ProviderConfig) + require.Equal(t, tt.wantResultingCache[i].GetName(), actualIDP.GetName()) + require.Equal(t, tt.wantResultingCache[i].GetClientID(), actualIDP.GetClientID()) + require.Equal(t, tt.wantResultingCache[i].GetAuthorizationURL().String(), actualIDP.GetAuthorizationURL().String()) + require.Equal(t, tt.wantResultingCache[i].GetUsernameClaim(), actualIDP.GetUsernameClaim()) + require.Equal(t, tt.wantResultingCache[i].GetGroupsClaim(), actualIDP.GetGroupsClaim()) + require.ElementsMatch(t, tt.wantResultingCache[i].GetScopes(), actualIDP.GetScopes()) } actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().UpstreamOIDCProviders(testNamespace).List(ctx, metav1.ListOptions{}) diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index 98c86bdb7..50ba17ca0 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -33,12 +33,8 @@ type UpstreamOIDCIdentityProviderI interface { // ID Token groups claim name. May return empty string, in which case we won't try to read groups from the upstream provider. GetGroupsClaim() string - AuthcodeExchanger -} - -// Performs upstream OIDC authorization code exchange and token validation. -// Returns the validated raw tokens as well as the parsed claims of the ID token. -type AuthcodeExchanger interface { + // Performs upstream OIDC authorization code exchange and token validation. + // Returns the validated raw tokens as well as the parsed claims of the ID token. ExchangeAuthcodeAndValidateTokens( ctx context.Context, authcode string, @@ -47,48 +43,6 @@ type AuthcodeExchanger interface { ) (tokens oidcclient.Token, parsedIDTokenClaims map[string]interface{}, err error) } -type UpstreamOIDCIdentityProvider struct { - Name string - ClientID string - AuthorizationURL url.URL - UsernameClaim string - GroupsClaim string - Scopes []string -} - -func (u *UpstreamOIDCIdentityProvider) GetName() string { - return u.Name -} - -func (u *UpstreamOIDCIdentityProvider) GetClientID() string { - return u.ClientID -} - -func (u *UpstreamOIDCIdentityProvider) GetAuthorizationURL() *url.URL { - return &u.AuthorizationURL -} - -func (u *UpstreamOIDCIdentityProvider) GetScopes() []string { - return u.Scopes -} - -func (u *UpstreamOIDCIdentityProvider) GetUsernameClaim() string { - return u.UsernameClaim -} - -func (u *UpstreamOIDCIdentityProvider) GetGroupsClaim() string { - return u.GroupsClaim -} - -func (u *UpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( - ctx context.Context, - authcode string, - pkceCodeVerifier pkce.Code, - expectedIDTokenNonce nonce.Nonce, -) (oidcclient.Token, map[string]interface{}, error) { - panic("TODO implement me") // TODO -} - type DynamicUpstreamIDPProvider interface { SetIDPList(oidcIDPs []UpstreamOIDCIdentityProviderI) GetIDPList() []UpstreamOIDCIdentityProviderI diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go new file mode 100644 index 000000000..13a800b03 --- /dev/null +++ b/internal/upstreamoidc/upstreamoidc.go @@ -0,0 +1,106 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package upstreamoidc implements an abstraction of upstream OIDC provider interactions. +package upstreamoidc + +import ( + "context" + "net/http" + "net/url" + + "github.com/coreos/go-oidc" + "golang.org/x/oauth2" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/pkce" +) + +// ProviderConfig holds the active configuration of an upstream OIDC provider. +type ProviderConfig struct { + Name string + UsernameClaim string + GroupsClaim string + Config *oauth2.Config + Provider interface { + Verifier(*oidc.Config) *oidc.IDTokenVerifier + } +} + +// *ProviderConfig should implement provider.UpstreamOIDCIdentityProviderI. +var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil) + +func (p *ProviderConfig) GetName() string { + return p.Name +} + +func (p *ProviderConfig) GetClientID() string { + return p.Config.ClientID +} + +func (p *ProviderConfig) GetAuthorizationURL() *url.URL { + result, _ := url.Parse(p.Config.Endpoint.AuthURL) + return result +} + +func (p *ProviderConfig) GetScopes() []string { + return p.Config.Scopes +} + +func (p *ProviderConfig) GetUsernameClaim() string { + return p.UsernameClaim +} + +func (p *ProviderConfig) GetGroupsClaim() string { + return p.GroupsClaim +} + +func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { + tok, err := p.Config.Exchange(ctx, authcode, pkceCodeVerifier.Verifier()) + if err != nil { + return oidcclient.Token{}, nil, err + } + + idTok, hasIDTok := tok.Extra("id_token").(string) + if !hasIDTok { + return oidcclient.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token") + } + validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(ctx, idTok) + if err != nil { + return oidcclient.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + } + if validated.AccessTokenHash != "" { + if err := validated.VerifyAccessToken(tok.AccessToken); err != nil { + return oidcclient.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + } + } + if expectedIDTokenNonce != "" { + if err := expectedIDTokenNonce.Validate(validated); err != nil { + return oidcclient.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) + } + } + + var validatedClaims map[string]interface{} + if err := validated.Claims(&validatedClaims); err != nil { + return oidcclient.Token{}, nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal claims", err) + } + + return oidcclient.Token{ + AccessToken: &oidcclient.AccessToken{ + Token: tok.AccessToken, + Type: tok.TokenType, + Expiry: metav1.NewTime(tok.Expiry), + }, + RefreshToken: &oidcclient.RefreshToken{ + Token: tok.RefreshToken, + }, + IDToken: &oidcclient.IDToken{ + Token: idTok, + Expiry: metav1.NewTime(validated.Expiry), + }, + }, validatedClaims, nil +} diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go new file mode 100644 index 000000000..d3cf77ac8 --- /dev/null +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -0,0 +1,223 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package upstreamoidc + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coreos/go-oidc" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + "gopkg.in/square/go-jose.v2" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "go.pinniped.dev/internal/mocks/mockkeyset" + "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/nonce" +) + +func TestProviderConfig(t *testing.T) { + t.Run("getters get", func(t *testing.T) { + p := ProviderConfig{ + Name: "test-name", + UsernameClaim: "test-username-claim", + GroupsClaim: "test-groups-claim", + Config: &oauth2.Config{ + ClientID: "test-client-id", + Endpoint: oauth2.Endpoint{AuthURL: "https://example.com"}, + Scopes: []string{"scope1", "scope2"}, + }, + } + require.Equal(t, "test-name", p.GetName()) + require.Equal(t, "test-client-id", p.GetClientID()) + require.Equal(t, "https://example.com", p.GetAuthorizationURL().String()) + require.ElementsMatch(t, []string{"scope1", "scope2"}, p.GetScopes()) + require.Equal(t, "test-username-claim", p.GetUsernameClaim()) + require.Equal(t, "test-groups-claim", p.GetGroupsClaim()) + }) + + const ( + // Test JWTs generated with https://smallstep.com/docs/cli/crypto/jwt/: + + // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" + invalidAccessTokenHashIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg" //nolint: gosec + + // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" + invalidNonceIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ" //nolint: gosec + + // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"foo": "bar", "bat": "baz"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" + validIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImJhdCI6ImJheiIsImZvbyI6ImJhciIsImlhdCI6MTYwNjc2ODU5MywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDY3Njg1OTMsInN1YiI6InRlc3QtdXNlciJ9.DuqVZ7pGhHqKz7gNr4j2W1s1N8YrSltktH4wW19L4oD1OE2-O72jAnNj5xdjilsa8l7h9ox-5sMF0Tkh3BdRlHQK9dEtNm9tW-JreUnWJ3LCqUs-LZp4NG7edvq2sH_1Bn7O2_NQV51s8Pl04F60CndjQ4NM-6WkqDQTKyY6vJXU7idvM-6TM2HJZK-Na88cOJ9KIK37tL5DhcbsHVF47Dq8uPZ0KbjNQjJLAIi_1GeQBgc6yJhDUwRY4Xu6S0dtTHA6xTI8oSXoamt4bkViEHfJBp97LZQiNz8mku5pVc0aNwP1p4hMHxRHhLXrJjbh-Hx4YFjxtOnIq9t1mHlD4A" //nolint: gosec + ) + + tests := []struct { + name string + authCode string + expectNonce nonce.Nonce + returnIDTok string + wantErr string + wantToken oidcclient.Token + wantClaims map[string]interface{} + }{ + { + name: "exchange fails with network error", + authCode: "invalid-auth-code", + wantErr: "oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n", + }, + { + name: "missing ID token", + authCode: "valid", + wantErr: "received response missing ID token", + }, + { + name: "invalid ID token", + authCode: "valid", + returnIDTok: "invalid-jwt", + wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", + }, + { + name: "invalid access token hash", + authCode: "valid", + returnIDTok: invalidAccessTokenHashIDToken, + wantErr: "received invalid ID token: access token hash does not match value in ID token", + }, + { + name: "invalid nonce", + authCode: "valid", + expectNonce: "test-nonce", + returnIDTok: invalidNonceIDToken, + wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`, + }, + { + name: "invalid nonce but not checked", + authCode: "valid", + expectNonce: "", + returnIDTok: invalidNonceIDToken, + wantToken: oidcclient.Token{ + AccessToken: &oidcclient.AccessToken{ + Token: "test-access-token", + Expiry: metav1.Time{}, + }, + RefreshToken: &oidcclient.RefreshToken{ + Token: "test-refresh-token", + }, + IDToken: &oidcclient.IDToken{ + Token: invalidNonceIDToken, + Expiry: metav1.Time{}, + }, + }, + }, + { + name: "valid", + authCode: "valid", + returnIDTok: validIDToken, + wantToken: oidcclient.Token{ + AccessToken: &oidcclient.AccessToken{ + Token: "test-access-token", + Expiry: metav1.Time{}, + }, + RefreshToken: &oidcclient.RefreshToken{ + Token: "test-refresh-token", + }, + IDToken: &oidcclient.IDToken{ + Token: validIDToken, + Expiry: metav1.Time{}, + }, + }, + wantClaims: map[string]interface{}{ + "foo": "bar", + "bat": "baz", + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.NoError(t, r.ParseForm()) + require.Equal(t, "test-client-id", r.Form.Get("client_id")) + require.Equal(t, "test-pkce", r.Form.Get("code_verifier")) + require.Equal(t, "authorization_code", r.Form.Get("grant_type")) + require.NotEmpty(t, r.Form.Get("code")) + if r.Form.Get("code") != "valid" { + http.Error(w, "invalid authorization code", http.StatusForbidden) + return + } + var response struct { + oauth2.Token + IDToken string `json:"id_token,omitempty"` + } + response.AccessToken = "test-access-token" + response.RefreshToken = "test-refresh-token" + response.Expiry = time.Now().Add(time.Hour) + response.IDToken = tt.returnIDTok + w.Header().Set("content-type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(&response)) + })) + t.Cleanup(tokenServer.Close) + + p := ProviderConfig{ + Name: "test-name", + UsernameClaim: "test-username-claim", + GroupsClaim: "test-groups-claim", + Config: &oauth2.Config{ + ClientID: "test-client-id", + Endpoint: oauth2.Endpoint{ + AuthURL: "https://example.com", + TokenURL: tokenServer.URL, + AuthStyle: oauth2.AuthStyleInParams, + }, + Scopes: []string{"scope1", "scope2"}, + }, + Provider: &mockProvider{}, + } + + ctx := context.Background() + + tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce) + if tt.wantErr != "" { + require.EqualError(t, err, tt.wantErr) + require.Equal(t, oidcclient.Token{}, tok) + require.Nil(t, claims) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantToken, tok) + + for k, v := range tt.wantClaims { + require.Equal(t, v, claims[k]) + } + }) + } +} + +// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else. +func mockVerifier() *oidc.IDTokenVerifier { + mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil)) + mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()). + AnyTimes(). + DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) { + jws, err := jose.ParseSigned(jwt) + if err != nil { + return nil, err + } + return jws.UnsafePayloadWithoutVerification(), nil + }) + + return oidc.NewVerifier("", mockKeySet, &oidc.Config{ + SkipIssuerCheck: true, + SkipExpiryCheck: true, + SkipClientIDCheck: true, + }) +} + +type mockProvider struct{} + +func (m *mockProvider) Verifier(_ *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() } From d32583dd7f8e866983c2ca23bc989e6afdec3a9d Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Mon, 30 Nov 2020 17:02:03 -0600 Subject: [PATCH 28/57] Move OIDC Token structs into a new `oidctypes` package. Signed-off-by: Matt Moyer --- cmd/pinniped/cmd/login_oidc.go | 3 +- cmd/pinniped/cmd/login_oidc_test.go | 7 +- .../oidc/callback/callback_handler_test.go | 6 +- internal/oidc/oidctestutil/oidc.go | 6 +- .../provider/dynamic_upstream_idp_provider.go | 4 +- internal/upstreamoidc/upstreamoidc.go | 24 +++--- internal/upstreamoidc/upstreamoidc_test.go | 22 ++--- pkg/oidcclient/filesession/cachefile.go | 3 +- pkg/oidcclient/filesession/cachefile_test.go | 49 +++++------ pkg/oidcclient/filesession/filesession.go | 7 +- .../filesession/filesession_test.go | 85 ++++++++++--------- pkg/oidcclient/login.go | 34 +++++--- pkg/oidcclient/login_test.go | 43 +++++----- .../{types.go => oidctypes/oidctypes.go} | 24 ++---- 14 files changed, 162 insertions(+), 155 deletions(-) rename pkg/oidcclient/{types.go => oidctypes/oidctypes.go} (69%) diff --git a/cmd/pinniped/cmd/login_oidc.go b/cmd/pinniped/cmd/login_oidc.go index 1677c00f2..c8d006624 100644 --- a/cmd/pinniped/cmd/login_oidc.go +++ b/cmd/pinniped/cmd/login_oidc.go @@ -20,6 +20,7 @@ import ( "go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient/filesession" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) //nolint: gochecknoinits @@ -27,7 +28,7 @@ func init() { loginCmd.AddCommand(oidcLoginCommand(oidcclient.Login)) } -func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oidcclient.Option) (*oidcclient.Token, error)) *cobra.Command { +func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oidcclient.Option) (*oidctypes.Token, error)) *cobra.Command { var ( cmd = cobra.Command{ Args: cobra.NoArgs, diff --git a/cmd/pinniped/cmd/login_oidc_test.go b/cmd/pinniped/cmd/login_oidc_test.go index 3a61934d6..37cfac4e2 100644 --- a/cmd/pinniped/cmd/login_oidc_test.go +++ b/cmd/pinniped/cmd/login_oidc_test.go @@ -13,6 +13,7 @@ import ( "go.pinniped.dev/internal/here" "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) func TestLoginOIDCCommand(t *testing.T) { @@ -92,12 +93,12 @@ func TestLoginOIDCCommand(t *testing.T) { gotClientID string gotOptions []oidcclient.Option ) - cmd := oidcLoginCommand(func(issuer string, clientID string, opts ...oidcclient.Option) (*oidcclient.Token, error) { + cmd := oidcLoginCommand(func(issuer string, clientID string, opts ...oidcclient.Option) (*oidctypes.Token, error) { gotIssuer = issuer gotClientID = clientID gotOptions = opts - return &oidcclient.Token{ - IDToken: &oidcclient.IDToken{ + return &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(time1), }, diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 69072067c..23656da2b 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -23,8 +23,8 @@ import ( "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/testutil" - "go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" ) @@ -651,8 +651,8 @@ func (u *upstreamOIDCIdentityProviderBuilder) Build() oidctestutil.TestUpstreamO UsernameClaim: u.usernameClaim, GroupsClaim: u.groupsClaim, Scopes: []string{"scope1", "scope2"}, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { - return oidcclient.Token{}, u.idToken, u.authcodeExchangeErr + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { + return oidctypes.Token{}, u.idToken, u.authcodeExchangeErr }, } } diff --git a/internal/oidc/oidctestutil/oidc.go b/internal/oidc/oidctestutil/oidc.go index a9b6acbd2..43a7147fd 100644 --- a/internal/oidc/oidctestutil/oidc.go +++ b/internal/oidc/oidctestutil/oidc.go @@ -8,8 +8,8 @@ import ( "net/url" "go.pinniped.dev/internal/oidc/provider" - "go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" ) @@ -36,7 +36,7 @@ type TestUpstreamOIDCIdentityProvider struct { authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, - ) (oidcclient.Token, map[string]interface{}, error) + ) (oidctypes.Token, map[string]interface{}, error) exchangeAuthcodeAndValidateTokensCallCount int exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs @@ -71,7 +71,7 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, -) (oidcclient.Token, map[string]interface{}, error) { +) (oidctypes.Token, map[string]interface{}, error) { if u.exchangeAuthcodeAndValidateTokensArgs == nil { u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) } diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index 50ba17ca0..0c08708c4 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -8,8 +8,8 @@ import ( "net/url" "sync" - "go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" ) @@ -40,7 +40,7 @@ type UpstreamOIDCIdentityProviderI interface { authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, - ) (tokens oidcclient.Token, parsedIDTokenClaims map[string]interface{}, err error) + ) (tokens oidctypes.Token, parsedIDTokenClaims map[string]interface{}, err error) } type DynamicUpstreamIDPProvider interface { diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 13a800b03..b44f02bc0 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -15,8 +15,8 @@ import ( "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidc/provider" - "go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" ) @@ -59,46 +59,46 @@ func (p *ProviderConfig) GetGroupsClaim() string { return p.GroupsClaim } -func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { +func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { tok, err := p.Config.Exchange(ctx, authcode, pkceCodeVerifier.Verifier()) if err != nil { - return oidcclient.Token{}, nil, err + return oidctypes.Token{}, nil, err } idTok, hasIDTok := tok.Extra("id_token").(string) if !hasIDTok { - return oidcclient.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token") + return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token") } validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(ctx, idTok) if err != nil { - return oidcclient.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) } if validated.AccessTokenHash != "" { if err := validated.VerifyAccessToken(tok.AccessToken); err != nil { - return oidcclient.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) } } if expectedIDTokenNonce != "" { if err := expectedIDTokenNonce.Validate(validated); err != nil { - return oidcclient.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) + return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) } } var validatedClaims map[string]interface{} if err := validated.Claims(&validatedClaims); err != nil { - return oidcclient.Token{}, nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal claims", err) + return oidctypes.Token{}, nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal claims", err) } - return oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + return oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: tok.AccessToken, Type: tok.TokenType, Expiry: metav1.NewTime(tok.Expiry), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: tok.RefreshToken, }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: idTok, Expiry: metav1.NewTime(validated.Expiry), }, diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go index d3cf77ac8..3f3eed2e1 100644 --- a/internal/upstreamoidc/upstreamoidc_test.go +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -19,8 +19,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/internal/mocks/mockkeyset" - "go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) func TestProviderConfig(t *testing.T) { @@ -62,7 +62,7 @@ func TestProviderConfig(t *testing.T) { expectNonce nonce.Nonce returnIDTok string wantErr string - wantToken oidcclient.Token + wantToken oidctypes.Token wantClaims map[string]interface{} }{ { @@ -99,15 +99,15 @@ func TestProviderConfig(t *testing.T) { authCode: "valid", expectNonce: "", returnIDTok: invalidNonceIDToken, - wantToken: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + wantToken: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Expiry: metav1.Time{}, }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: invalidNonceIDToken, Expiry: metav1.Time{}, }, @@ -117,15 +117,15 @@ func TestProviderConfig(t *testing.T) { name: "valid", authCode: "valid", returnIDTok: validIDToken, - wantToken: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + wantToken: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Expiry: metav1.Time{}, }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: validIDToken, Expiry: metav1.Time{}, }, @@ -184,7 +184,7 @@ func TestProviderConfig(t *testing.T) { tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce) if tt.wantErr != "" { require.EqualError(t, err, tt.wantErr) - require.Equal(t, oidcclient.Token{}, tok) + require.Equal(t, oidctypes.Token{}, tok) require.Nil(t, claims) return } diff --git a/pkg/oidcclient/filesession/cachefile.go b/pkg/oidcclient/filesession/cachefile.go index 3629ca5f6..9ea46bc02 100644 --- a/pkg/oidcclient/filesession/cachefile.go +++ b/pkg/oidcclient/filesession/cachefile.go @@ -17,6 +17,7 @@ import ( "sigs.k8s.io/yaml" "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) var ( @@ -48,7 +49,7 @@ type ( Key oidcclient.SessionCacheKey `json:"key"` CreationTimestamp metav1.Time `json:"creationTimestamp"` LastUsedTimestamp metav1.Time `json:"lastUsedTimestamp"` - Tokens oidcclient.Token `json:"tokens"` + Tokens oidctypes.Token `json:"tokens"` } ) diff --git a/pkg/oidcclient/filesession/cachefile_test.go b/pkg/oidcclient/filesession/cachefile_test.go index 0ddcdf9bc..a881d7d4f 100644 --- a/pkg/oidcclient/filesession/cachefile_test.go +++ b/pkg/oidcclient/filesession/cachefile_test.go @@ -12,6 +12,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) // validSession should be the same data as `testdata/valid.yaml`. @@ -27,17 +28,17 @@ var validSession = sessionCache{ }, CreationTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 42, 7, 0, time.UTC).Local()), LastUsedTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 45, 31, 0, time.UTC).Local()), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Type: "Bearer", Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 46, 30, 0, time.UTC).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 42, 07, 0, time.UTC).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, }, @@ -139,8 +140,8 @@ func TestNormalized(t *testing.T) { // ID token is empty, but not nil. { LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - IDToken: &oidcclient.IDToken{ + Tokens: oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "", Expiry: metav1.NewTime(now.Add(1 * time.Minute)), }, @@ -149,8 +150,8 @@ func TestNormalized(t *testing.T) { // ID token is expired. { LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - IDToken: &oidcclient.IDToken{ + Tokens: oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(now.Add(-1 * time.Minute)), }, @@ -159,8 +160,8 @@ func TestNormalized(t *testing.T) { // Access token is empty, but not nil. { LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "", Expiry: metav1.NewTime(now.Add(1 * time.Minute)), }, @@ -169,8 +170,8 @@ func TestNormalized(t *testing.T) { // Access token is expired. { LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Expiry: metav1.NewTime(now.Add(-1 * time.Minute)), }, @@ -179,8 +180,8 @@ func TestNormalized(t *testing.T) { // Refresh token is empty, but not nil. { LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "", }, }, @@ -188,8 +189,8 @@ func TestNormalized(t *testing.T) { // Session has a refresh token but it hasn't been used in >90 days. { LastUsedTimestamp: metav1.NewTime(now.AddDate(-1, 0, 0)), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, }, @@ -198,8 +199,8 @@ func TestNormalized(t *testing.T) { { CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token2", }, }, @@ -207,8 +208,8 @@ func TestNormalized(t *testing.T) { { CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token1", }, }, @@ -222,8 +223,8 @@ func TestNormalized(t *testing.T) { { CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token1", }, }, @@ -231,8 +232,8 @@ func TestNormalized(t *testing.T) { { CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token2", }, }, diff --git a/pkg/oidcclient/filesession/filesession.go b/pkg/oidcclient/filesession/filesession.go index 47e0f7614..151fde719 100644 --- a/pkg/oidcclient/filesession/filesession.go +++ b/pkg/oidcclient/filesession/filesession.go @@ -16,6 +16,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) const ( @@ -65,14 +66,14 @@ type Cache struct { } // GetToken looks up the cached data for the given parameters. It may return nil if no valid matching session is cached. -func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidcclient.Token { +func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidctypes.Token { // If the cache file does not exist, exit immediately with no error log if _, err := os.Stat(c.path); errors.Is(err, os.ErrNotExist) { return nil } // Read the cache and lookup the matching entry. If one exists, update its last used timestamp and return it. - var result *oidcclient.Token + var result *oidctypes.Token c.withCache(func(cache *sessionCache) { if entry := cache.lookup(key); entry != nil { result = &entry.Tokens @@ -84,7 +85,7 @@ func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidcclient.Token { // PutToken stores the provided token into the session cache under the given parameters. It does not return an error // but may silently fail to update the session cache. -func (c *Cache) PutToken(key oidcclient.SessionCacheKey, token *oidcclient.Token) { +func (c *Cache) PutToken(key oidcclient.SessionCacheKey, token *oidctypes.Token) { // Create the cache directory if it does not exist. if err := os.MkdirAll(filepath.Dir(c.path), 0700); err != nil && !errors.Is(err, os.ErrExist) { c.errReporter(fmt.Errorf("could not create session cache directory: %w", err)) diff --git a/pkg/oidcclient/filesession/filesession_test.go b/pkg/oidcclient/filesession/filesession_test.go index 1b28e1848..d172af80e 100644 --- a/pkg/oidcclient/filesession/filesession_test.go +++ b/pkg/oidcclient/filesession/filesession_test.go @@ -16,6 +16,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) func TestNew(t *testing.T) { @@ -37,7 +38,7 @@ func TestGetToken(t *testing.T) { trylockFunc func(*testing.T) error unlockFunc func(*testing.T) error key oidcclient.SessionCacheKey - want *oidcclient.Token + want *oidctypes.Token wantErrors []string wantTestFile func(t *testing.T, tmp string) }{ @@ -98,17 +99,17 @@ func TestGetToken(t *testing.T) { }, CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, }, @@ -136,17 +137,17 @@ func TestGetToken(t *testing.T) { }, CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, }, @@ -160,17 +161,17 @@ func TestGetToken(t *testing.T) { RedirectURI: "http://localhost:0/callback", }, wantErrors: []string{}, - want: &oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + want: &oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, }, @@ -218,7 +219,7 @@ func TestPutToken(t *testing.T) { name string makeTestFile func(t *testing.T, tmp string) key oidcclient.SessionCacheKey - token *oidcclient.Token + token *oidctypes.Token wantErrors []string wantTestFile func(t *testing.T, tmp string) }{ @@ -244,17 +245,17 @@ func TestPutToken(t *testing.T) { }, CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "old-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "old-id-token", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "old-refresh-token", }, }, @@ -268,17 +269,17 @@ func TestPutToken(t *testing.T) { Scopes: []string{"email", "offline_access", "openid", "profile"}, RedirectURI: "http://localhost:0/callback", }, - token: &oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + token: &oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "new-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "new-id-token", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "new-refresh-token", }, }, @@ -287,17 +288,17 @@ func TestPutToken(t *testing.T) { require.NoError(t, err) require.Len(t, cache.Sessions, 1) require.Less(t, time.Since(cache.Sessions[0].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds()) - require.Equal(t, oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + require.Equal(t, oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "new-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "new-id-token", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "new-refresh-token", }, }, cache.Sessions[0].Tokens) @@ -316,17 +317,17 @@ func TestPutToken(t *testing.T) { }, CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "old-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "old-id-token", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "old-refresh-token", }, }, @@ -340,17 +341,17 @@ func TestPutToken(t *testing.T) { Scopes: []string{"email", "offline_access", "openid", "profile"}, RedirectURI: "http://localhost:0/callback", }, - token: &oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + token: &oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "new-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "new-id-token", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "new-refresh-token", }, }, @@ -359,17 +360,17 @@ func TestPutToken(t *testing.T) { require.NoError(t, err) require.Len(t, cache.Sessions, 2) require.Less(t, time.Since(cache.Sessions[1].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds()) - require.Equal(t, oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + require.Equal(t, oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "new-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "new-id-token", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "new-refresh-token", }, }, cache.Sessions[1].Tokens) @@ -388,17 +389,17 @@ func TestPutToken(t *testing.T) { Scopes: []string{"email", "offline_access", "openid", "profile"}, RedirectURI: "http://localhost:0/callback", }, - token: &oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + token: &oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "new-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "new-id-token", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "new-refresh-token", }, }, diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 0898f9449..09d6949f7 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -21,6 +21,7 @@ import ( "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/state" ) @@ -68,7 +69,7 @@ type handlerState struct { } type callbackResult struct { - token *Token + token *oidctypes.Token err error } @@ -116,6 +117,19 @@ func WithBrowserOpen(openURL func(url string) error) Option { } } +// SessionCacheKey contains the data used to select a valid session cache entry. +type SessionCacheKey struct { + Issuer string `json:"issuer"` + ClientID string `json:"clientID"` + Scopes []string `json:"scopes"` + RedirectURI string `json:"redirect_uri"` +} + +type SessionCache interface { + GetToken(SessionCacheKey) *oidctypes.Token + PutToken(SessionCacheKey, *oidctypes.Token) +} + // WithSessionCache sets the session cache backend for storing and retrieving previously-issued ID tokens and refresh tokens. func WithSessionCache(cache SessionCache) Option { return func(h *handlerState) error { @@ -135,8 +149,8 @@ func WithClient(httpClient *http.Client) Option { // nopCache is a SessionCache that doesn't actually do anything. type nopCache struct{} -func (*nopCache) GetToken(SessionCacheKey) *Token { return nil } -func (*nopCache) PutToken(SessionCacheKey, *Token) {} +func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil } +func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {} type discoveryI interface { Endpoint() oauth2.Endpoint @@ -144,7 +158,7 @@ type discoveryI interface { } // Login performs an OAuth2/OIDC authorization code login using a localhost listener. -func Login(issuer string, clientID string, opts ...Option) (*Token, error) { +func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) { h := handlerState{ issuer: issuer, clientID: clientID, @@ -274,7 +288,7 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) { } } -func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshToken) (*Token, error) { +func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) { ctx, cancel := context.WithTimeout(ctx, refreshTimeout) defer cancel() refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token}) @@ -331,7 +345,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req return nil } -func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*Token, error) { +func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*oidctypes.Token, error) { idTok, hasIDTok := tok.Extra("id_token").(string) if !hasIDTok { return nil, httperr.New(http.StatusBadRequest, "received response missing ID token") @@ -350,16 +364,16 @@ func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, che return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) } } - return &Token{ - AccessToken: &AccessToken{ + return &oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: tok.AccessToken, Type: tok.TokenType, Expiry: metav1.NewTime(tok.Expiry), }, - RefreshToken: &RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: tok.RefreshToken, }, - IDToken: &IDToken{ + IDToken: &oidctypes.IDToken{ Token: idTok, Expiry: metav1.NewTime(validated.Expiry), }, diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 2b13752b5..5bff0142b 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -25,6 +25,7 @@ import ( "go.pinniped.dev/internal/mocks/mockkeyset" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/state" ) @@ -32,19 +33,19 @@ import ( // mockSessionCache exists to avoid an import cycle if we generate mocks into another package. type mockSessionCache struct { t *testing.T - getReturnsToken *Token + getReturnsToken *oidctypes.Token sawGetKeys []SessionCacheKey sawPutKeys []SessionCacheKey - sawPutTokens []*Token + sawPutTokens []*oidctypes.Token } -func (m *mockSessionCache) GetToken(key SessionCacheKey) *Token { +func (m *mockSessionCache) GetToken(key SessionCacheKey) *oidctypes.Token { m.t.Logf("saw mock session cache GetToken() with client ID %s", key.ClientID) m.sawGetKeys = append(m.sawGetKeys, key) return m.getReturnsToken } -func (m *mockSessionCache) PutToken(key SessionCacheKey, token *Token) { +func (m *mockSessionCache) PutToken(key SessionCacheKey, token *oidctypes.Token) { m.t.Logf("saw mock session cache PutToken() with client ID %s and ID token %s", key.ClientID, token.IDToken.Token) m.sawPutKeys = append(m.sawPutKeys, key) m.sawPutTokens = append(m.sawPutTokens, token) @@ -55,15 +56,15 @@ func TestLogin(t *testing.T) { time1Unix := int64(2075807775) require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix()) - testToken := Token{ - AccessToken: &AccessToken{ + testToken := oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute)), }, - RefreshToken: &RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, - IDToken: &IDToken{ + IDToken: &oidctypes.IDToken{ // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above): // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775 Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA", @@ -145,7 +146,7 @@ func TestLogin(t *testing.T) { issuer string clientID string wantErr string - wantToken *Token + wantToken *oidctypes.Token }{ { name: "option error", @@ -192,8 +193,8 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - cache := &mockSessionCache{t: t, getReturnsToken: &Token{ - IDToken: &IDToken{ + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(time.Now()), // less than Now() + minIDTokenValidity }, @@ -247,12 +248,12 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - cache := &mockSessionCache{t: t, getReturnsToken: &Token{ - IDToken: &IDToken{ + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "expired-test-id-token", Expiry: metav1.Now(), // less than Now() + minIDTokenValidity }, - RefreshToken: &RefreshToken{Token: "test-refresh-token"}, + RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"}, }} t.Cleanup(func() { cacheKey := SessionCacheKey{ @@ -284,12 +285,12 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - cache := &mockSessionCache{t: t, getReturnsToken: &Token{ - IDToken: &IDToken{ + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "expired-test-id-token", Expiry: metav1.Now(), // less than Now() + minIDTokenValidity }, - RefreshToken: &RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"}, + RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"}, }} t.Cleanup(func() { require.Empty(t, cache.sawPutKeys) @@ -314,12 +315,12 @@ func TestLogin(t *testing.T) { clientID: "not-the-test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - cache := &mockSessionCache{t: t, getReturnsToken: &Token{ - IDToken: &IDToken{ + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "expired-test-id-token", Expiry: metav1.Now(), // less than Now() + minIDTokenValidity }, - RefreshToken: &RefreshToken{Token: "test-refresh-token"}, + RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"}, }} t.Cleanup(func() { require.Empty(t, cache.sawPutKeys) @@ -414,7 +415,7 @@ func TestLogin(t *testing.T) { t.Cleanup(func() { require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys) require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys) - require.Equal(t, []*Token{&testToken}, cache.sawPutTokens) + require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens) }) require.NoError(t, WithSessionCache(cache)(h)) require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h)) diff --git a/pkg/oidcclient/types.go b/pkg/oidcclient/oidctypes/oidctypes.go similarity index 69% rename from pkg/oidcclient/types.go rename to pkg/oidcclient/oidctypes/oidctypes.go index 7fbf3a3f0..94f5dcc93 100644 --- a/pkg/oidcclient/types.go +++ b/pkg/oidcclient/oidctypes/oidctypes.go @@ -1,11 +1,10 @@ // Copyright 2020 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package oidcclient +// Package oidctypes provides core data types for OIDC token structures. +package oidctypes -import ( - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" -) +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" // AccessToken is an OAuth2 access token. type AccessToken struct { @@ -16,7 +15,7 @@ type AccessToken struct { Type string `json:"type,omitempty"` // Expiry is the optional expiration time of the access token. - Expiry metav1.Time `json:"expiryTimestamp,omitempty"` + Expiry v1.Time `json:"expiryTimestamp,omitempty"` } // RefreshToken is an OAuth2 refresh token. @@ -31,7 +30,7 @@ type IDToken struct { Token string `json:"token"` // Expiry is the optional expiration time of the ID token. - Expiry metav1.Time `json:"expiryTimestamp,omitempty"` + Expiry v1.Time `json:"expiryTimestamp,omitempty"` } // Token contains the elements of an OIDC session. @@ -47,16 +46,3 @@ type Token struct { // IDToken is an OpenID Connect ID token. IDToken *IDToken `json:"id,omitempty"` } - -// SessionCacheKey contains the data used to select a valid session cache entry. -type SessionCacheKey struct { - Issuer string `json:"issuer"` - ClientID string `json:"clientID"` - Scopes []string `json:"scopes"` - RedirectURI string `json:"redirect_uri"` -} - -type SessionCache interface { - GetToken(SessionCacheKey) *Token - PutToken(SessionCacheKey, *Token) -} From 25ee99f93a662f87e38f8e8e13813ea606cae6dc Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Mon, 30 Nov 2020 17:08:27 -0600 Subject: [PATCH 29/57] Add ValidateToken method to UpstreamOIDCIdentityProviderI interface. Signed-off-by: Matt Moyer --- internal/oidc/oidctestutil/oidc.go | 6 ++++++ internal/oidc/provider/dynamic_upstream_idp_provider.go | 4 ++++ internal/upstreamoidc/upstreamoidc.go | 4 ++++ 3 files changed, 14 insertions(+) diff --git a/internal/oidc/oidctestutil/oidc.go b/internal/oidc/oidctestutil/oidc.go index 43a7147fd..eafd567f5 100644 --- a/internal/oidc/oidctestutil/oidc.go +++ b/internal/oidc/oidctestutil/oidc.go @@ -7,6 +7,8 @@ import ( "context" "net/url" + "golang.org/x/oauth2" + "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" @@ -96,6 +98,10 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs return u.exchangeAuthcodeAndValidateTokensArgs[call] } +func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { + panic("implement me") +} + func NewIDPListGetter(upstreamOIDCIdentityProviders ...*TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { idpProvider := provider.NewDynamicUpstreamIDPProvider() upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders)) diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index 0c08708c4..8ef1e5dbb 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -8,6 +8,8 @@ import ( "net/url" "sync" + "golang.org/x/oauth2" + "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" @@ -41,6 +43,8 @@ type UpstreamOIDCIdentityProviderI interface { pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, ) (tokens oidctypes.Token, parsedIDTokenClaims map[string]interface{}, err error) + + ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) } type DynamicUpstreamIDPProvider interface { diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index b44f02bc0..72de2e33b 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -65,6 +65,10 @@ func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, return oidctypes.Token{}, nil, err } + return p.ValidateToken(ctx, tok, expectedIDTokenNonce) +} + +func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { idTok, hasIDTok := tok.Extra("id_token").(string) if !hasIDTok { return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token") From 4b60c922ef0ffd6184912a251a4bd726e42cf933 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Mon, 30 Nov 2020 17:09:01 -0600 Subject: [PATCH 30/57] Add generated mock of UpstreamOIDCIdentityProviderI. Signed-off-by: Matt Moyer --- .../generate.go | 6 + .../mockupstreamoidcidentityprovider.go | 159 ++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100644 internal/mocks/mockupstreamoidcidentityprovider/generate.go create mode 100644 internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go diff --git a/internal/mocks/mockupstreamoidcidentityprovider/generate.go b/internal/mocks/mockupstreamoidcidentityprovider/generate.go new file mode 100644 index 000000000..cb9c46df5 --- /dev/null +++ b/internal/mocks/mockupstreamoidcidentityprovider/generate.go @@ -0,0 +1,6 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package mockupstreamoidcidentityprovider + +//go:generate go run -v github.com/golang/mock/mockgen -destination=mockupstreamoidcidentityprovider.go -package=mockupstreamoidcidentityprovider -copyright_file=../../../hack/header.txt go.pinniped.dev/internal/oidc/provider UpstreamOIDCIdentityProviderI diff --git a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go new file mode 100644 index 000000000..e3887b827 --- /dev/null +++ b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go @@ -0,0 +1,159 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: go.pinniped.dev/internal/oidc/provider (interfaces: UpstreamOIDCIdentityProviderI) + +// Package mockupstreamoidcidentityprovider is a generated GoMock package. +package mockupstreamoidcidentityprovider + +import ( + context "context" + gomock "github.com/golang/mock/gomock" + nonce "go.pinniped.dev/pkg/oidcclient/nonce" + oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes" + pkce "go.pinniped.dev/pkg/oidcclient/pkce" + oauth2 "golang.org/x/oauth2" + url "net/url" + reflect "reflect" +) + +// MockUpstreamOIDCIdentityProviderI is a mock of UpstreamOIDCIdentityProviderI interface +type MockUpstreamOIDCIdentityProviderI struct { + ctrl *gomock.Controller + recorder *MockUpstreamOIDCIdentityProviderIMockRecorder +} + +// MockUpstreamOIDCIdentityProviderIMockRecorder is the mock recorder for MockUpstreamOIDCIdentityProviderI +type MockUpstreamOIDCIdentityProviderIMockRecorder struct { + mock *MockUpstreamOIDCIdentityProviderI +} + +// NewMockUpstreamOIDCIdentityProviderI creates a new mock instance +func NewMockUpstreamOIDCIdentityProviderI(ctrl *gomock.Controller) *MockUpstreamOIDCIdentityProviderI { + mock := &MockUpstreamOIDCIdentityProviderI{ctrl: ctrl} + mock.recorder = &MockUpstreamOIDCIdentityProviderIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockUpstreamOIDCIdentityProviderI) EXPECT() *MockUpstreamOIDCIdentityProviderIMockRecorder { + return m.recorder +} + +// ExchangeAuthcodeAndValidateTokens mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(oidctypes.Token) + ret1, _ := ret[1].(map[string]interface{}) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ExchangeAuthcodeAndValidateTokens indicates an expected call of ExchangeAuthcodeAndValidateTokens +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ExchangeAuthcodeAndValidateTokens(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeAuthcodeAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ExchangeAuthcodeAndValidateTokens), arg0, arg1, arg2, arg3) +} + +// GetAuthorizationURL mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetAuthorizationURL() *url.URL { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAuthorizationURL") + ret0, _ := ret[0].(*url.URL) + return ret0 +} + +// GetAuthorizationURL indicates an expected call of GetAuthorizationURL +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetAuthorizationURL() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizationURL", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetAuthorizationURL)) +} + +// GetClientID mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetClientID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClientID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetClientID indicates an expected call of GetClientID +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetClientID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientID", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetClientID)) +} + +// GetGroupsClaim mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetGroupsClaim() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGroupsClaim") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetGroupsClaim indicates an expected call of GetGroupsClaim +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetGroupsClaim() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupsClaim", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetGroupsClaim)) +} + +// GetName mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetName") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetName indicates an expected call of GetName +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetName", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetName)) +} + +// GetScopes mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetScopes() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetScopes") + ret0, _ := ret[0].([]string) + return ret0 +} + +// GetScopes indicates an expected call of GetScopes +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetScopes", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetScopes)) +} + +// GetUsernameClaim mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetUsernameClaim() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUsernameClaim") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetUsernameClaim indicates an expected call of GetUsernameClaim +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetUsernameClaim() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsernameClaim", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetUsernameClaim)) +} + +// ValidateToken mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateToken", arg0, arg1, arg2) + ret0, _ := ret[0].(oidctypes.Token) + ret1, _ := ret[1].(map[string]interface{}) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ValidateToken indicates an expected call of ValidateToken +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ValidateToken(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ValidateToken), arg0, arg1, arg2) +} From b272b3f33149892df2570e1b1f9a863cf15be3fe Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Mon, 30 Nov 2020 17:14:57 -0600 Subject: [PATCH 31/57] Refactor oidcclient.Login to use new upstreamoidc package. Signed-off-by: Matt Moyer --- internal/upstreamoidc/upstreamoidc.go | 7 +- pkg/oidcclient/login.go | 83 +++-------- pkg/oidcclient/login_test.go | 205 ++++++++++---------------- 3 files changed, 98 insertions(+), 197 deletions(-) diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 72de2e33b..2957e9e30 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -20,6 +20,10 @@ import ( "go.pinniped.dev/pkg/oidcclient/pkce" ) +func New(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + return &ProviderConfig{Config: config, Provider: provider} +} + // ProviderConfig holds the active configuration of an upstream OIDC provider. type ProviderConfig struct { Name string @@ -31,9 +35,6 @@ type ProviderConfig struct { } } -// *ProviderConfig should implement provider.UpstreamOIDCIdentityProviderI. -var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil) - func (p *ProviderConfig) GetName() string { return p.Name } diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 09d6949f7..2e286efa3 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -16,10 +16,11 @@ import ( "github.com/coreos/go-oidc" "github.com/pkg/browser" "golang.org/x/oauth2" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/securityheader" + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/upstreamoidc" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" @@ -52,18 +53,18 @@ type handlerState struct { callbackPath string // Generated parameters of a login flow. - idTokenVerifier *oidc.IDTokenVerifier - oauth2Config *oauth2.Config - state state.State - nonce nonce.Nonce - pkce pkce.Code + provider *oidc.Provider + oauth2Config *oauth2.Config + state state.State + nonce nonce.Nonce + pkce pkce.Code // External calls for things. generateState func() (state.State, error) generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) openURL func(string) error - oidcDiscover func(context.Context, string) (discoveryI, error) + getProvider func(*oauth2.Config, *oidc.Provider) provider.UpstreamOIDCIdentityProviderI callbacks chan callbackResult } @@ -152,11 +153,6 @@ type nopCache struct{} func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil } func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {} -type discoveryI interface { - Endpoint() oauth2.Endpoint - Verifier(*oidc.Config) *oidc.IDTokenVerifier -} - // Login performs an OAuth2/OIDC authorization code login using a localhost listener. func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) { h := handlerState{ @@ -175,9 +171,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er generateNonce: nonce.Generate, generatePKCE: pkce.Generate, openURL: browser.OpenURL, - oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) { - return oidc.NewProvider(ctx, iss) - }, + getProvider: upstreamoidc.New, } for _, opt := range opts { if err := opt(&h); err != nil { @@ -222,16 +216,15 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er } // Perform OIDC discovery. - discovered, err := h.oidcDiscover(h.ctx, h.issuer) + h.provider, err = oidc.NewProvider(h.ctx, h.issuer) if err != nil { return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err) } - h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID}) // Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint. h.oauth2Config = &oauth2.Config{ ClientID: h.clientID, - Endpoint: discovered.Endpoint(), + Endpoint: h.provider.Endpoint(), Scopes: h.scopes, } @@ -301,7 +294,11 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // some providers do not include one, so we skip the nonce validation here (but not other validations). - return h.validateToken(ctx, refreshed, false) + token, _, err := h.getProvider(h.oauth2Config, h.provider).ValidateToken(ctx, refreshed, "") + if err != nil { + return nil, err + } + return &token, nil } func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { @@ -328,58 +325,18 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam) } - // Exchange the authorization code for access, ID, and refresh tokens. - oauth2Tok, err := h.oauth2Config.Exchange(r.Context(), params.Get("code"), h.pkce.Verifier()) + // Exchange the authorization code for access, ID, and refresh tokens and perform required + // validations on the returned ID token. + token, _, err := h.getProvider(h.oauth2Config, h.provider).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce) if err != nil { return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) } - // Perform required validations on the returned ID token. - token, err := h.validateToken(r.Context(), oauth2Tok, true) - if err != nil { - return err - } - - h.callbacks <- callbackResult{token: token} + h.callbacks <- callbackResult{token: &token} _, _ = w.Write([]byte("you have been logged in and may now close this tab")) return nil } -func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*oidctypes.Token, error) { - idTok, hasIDTok := tok.Extra("id_token").(string) - if !hasIDTok { - return nil, httperr.New(http.StatusBadRequest, "received response missing ID token") - } - validated, err := h.idTokenVerifier.Verify(ctx, idTok) - if err != nil { - return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) - } - if validated.AccessTokenHash != "" { - if err := validated.VerifyAccessToken(tok.AccessToken); err != nil { - return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) - } - } - if checkNonce { - if err := h.nonce.Validate(validated); err != nil { - return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) - } - } - return &oidctypes.Token{ - AccessToken: &oidctypes.AccessToken{ - Token: tok.AccessToken, - Type: tok.TokenType, - Expiry: metav1.NewTime(tok.Expiry), - }, - RefreshToken: &oidctypes.RefreshToken{ - Token: tok.RefreshToken, - }, - IDToken: &oidctypes.IDToken{ - Token: idTok, - Expiry: metav1.NewTime(validated.Expiry), - }, - }, nil -} - func (h *handlerState) serve(listener net.Listener) func() { mux := http.NewServeMux() mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 5bff0142b..280dfd0ad 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -18,11 +18,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/internal/httputil/httperr" - "go.pinniped.dev/internal/mocks/mockkeyset" + "go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider" + "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" @@ -57,19 +57,9 @@ func TestLogin(t *testing.T) { require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix()) testToken := oidctypes.Token{ - AccessToken: &oidctypes.AccessToken{ - Token: "test-access-token", - Expiry: metav1.NewTime(time1.Add(1 * time.Minute)), - }, - RefreshToken: &oidctypes.RefreshToken{ - Token: "test-refresh-token", - }, - IDToken: &oidctypes.IDToken{ - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above): - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775 - Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA", - Expiry: metav1.NewTime(time1.Add(2 * time.Minute)), - }, + AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))}, + RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"}, + IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))}, } // Start a test server that returns 500 errors @@ -78,7 +68,7 @@ func TestLogin(t *testing.T) { })) t.Cleanup(errorServer.Close) - // Start a test server that returns a real keyset and answers refresh requests. + // Start a test server that returns a real discovery document and answers refresh requests. providerMux := http.NewServeMux() successServer := httptest.NewServer(providerMux) t.Cleanup(successServer.Close) @@ -248,6 +238,14 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). + Return(testToken, nil, nil) + return mock + } + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ IDToken: &oidctypes.IDToken{ Token: "expired-test-id-token", @@ -268,12 +266,6 @@ func TestLogin(t *testing.T) { require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token) }) h.cache = cache - - h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { - provider, err := oidc.NewProvider(ctx, iss) - require.NoError(t, err) - return &mockDiscovery{provider: provider}, nil - } return nil } }, @@ -285,6 +277,14 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). + Return(oidctypes.Token{}, nil, fmt.Errorf("some validation error")) + return mock + } + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ IDToken: &oidctypes.IDToken{ Token: "expired-test-id-token", @@ -298,16 +298,10 @@ func TestLogin(t *testing.T) { }) h.cache = cache - h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { - provider, err := oidc.NewProvider(ctx, iss) - require.NoError(t, err) - return &mockDiscovery{provider: provider}, nil - } - return nil } }, - wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", + wantErr: "some validation error", }, { name: "session cache hit but refresh fails", @@ -328,12 +322,6 @@ func TestLogin(t *testing.T) { }) h.cache = cache - h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { - provider, err := oidc.NewProvider(ctx, iss) - require.NoError(t, err) - return &mockDiscovery{provider: provider}, nil - } - h.listenAddr = "invalid-listen-address" return nil @@ -504,7 +492,7 @@ func TestHandleAuthCodeCallback(t *testing.T) { name string method string query string - returnIDTok string + opt func(t *testing.T) Option wantErr string wantHTTPStatus int }{ @@ -530,94 +518,49 @@ func TestHandleAuthCodeCallback(t *testing.T) { { name: "invalid code", query: "state=test-state&code=invalid", - wantErr: "could not complete code exchange: oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n", + wantErr: "could not complete code exchange: some exchange error", wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "missing ID token", - query: "state=test-state&code=valid", - returnIDTok: "", - wantErr: "received response missing ID token", - wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "invalid ID token", - query: "state=test-state&code=valid", - returnIDTok: "invalid-jwt", - wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", - wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "invalid access token hash", - query: "state=test-state&code=valid", - - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" - returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg", - - wantErr: "received invalid ID token: access token hash does not match value in ID token", - wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "invalid nonce", - query: "state=test-state&code=valid", - - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" - returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ", - - wantHTTPStatus: http.StatusBadRequest, - wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`, + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). + Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error")) + return mock + } + return nil + } + }, }, { name: "valid", query: "state=test-state&code=valid", - - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "test-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" - returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjUzMTU2NywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDI1MzE1NjcsIm5vbmNlIjoidGVzdC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.LbOA31iwJZBM4ayY5Oud-HArLXbmtAIhZv_LazDqbzA2Iw87RxoBemfiPUJeAesdnO1LKSjBwbltZwtjvbLWHp1R5tqrSMr_hl2OyZv1cpEX-9QaTcQILJ5qR00riRLz34ZCQFyF-FfQpP1r4dNqFrxHuiBwKuPE7zogc83ZYJgAQM5Fao9rIRY9JStL_3pURa9JnnSHFlkLvFYv3TKEUyvnW4pWvYZcsGI7mys43vuSjpG7ZSrW3vCxovuIpXYqAhamZL_XexWUsXvi3ej9HNlhnhOFhN4fuPSc0PWDWaN0CLWmoo8gvOdQWo5A4GD4bNGBzjYOd-pYqsDfseRt1Q", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). + Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil) + return mock + } + return nil + } + }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - require.NoError(t, r.ParseForm()) - require.Equal(t, "test-client-id", r.Form.Get("client_id")) - require.Equal(t, "test-pkce", r.Form.Get("code_verifier")) - require.Equal(t, "authorization_code", r.Form.Get("grant_type")) - require.NotEmpty(t, r.Form.Get("code")) - if r.Form.Get("code") != "valid" { - http.Error(w, "invalid authorization code", http.StatusForbidden) - return - } - var response struct { - oauth2.Token - IDToken string `json:"id_token,omitempty"` - } - response.AccessToken = "test-access-token" - response.Expiry = time.Now().Add(time.Hour) - response.IDToken = tt.returnIDTok - w.Header().Set("content-type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(&response)) - })) - t.Cleanup(tokenServer.Close) - h := &handlerState{ callbacks: make(chan callbackResult, 1), state: state.State("test-state"), pkce: pkce.Code("test-pkce"), nonce: nonce.Nonce("test-nonce"), - oauth2Config: &oauth2.Config{ - ClientID: "test-client-id", - RedirectURL: "http://localhost:12345/callback", - Endpoint: oauth2.Endpoint{ - TokenURL: tokenServer.URL, - AuthStyle: oauth2.AuthStyleInParams, - }, - }, - idTokenVerifier: mockVerifier(), + } + if tt.opt != nil { + require.NoError(t, tt.opt(t)(h)) } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -653,34 +596,34 @@ func TestHandleAuthCodeCallback(t *testing.T) { } require.NoError(t, result.err) require.NotNil(t, result.token) - require.Equal(t, result.token.IDToken.Token, tt.returnIDTok) + require.Equal(t, result.token.IDToken.Token, "test-id-token") } }) } } -// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else. -func mockVerifier() *oidc.IDTokenVerifier { - mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil)) - mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()). - AnyTimes(). - DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) { - jws, err := jose.ParseSigned(jwt) - if err != nil { - return nil, err - } - return jws.UnsafePayloadWithoutVerification(), nil - }) - - return oidc.NewVerifier("", mockKeySet, &oidc.Config{ - SkipIssuerCheck: true, - SkipExpiryCheck: true, - SkipClientIDCheck: true, - }) +func mockUpstream(t *testing.T) *mockupstreamoidcidentityprovider.MockUpstreamOIDCIdentityProviderI { + t.Helper() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + return mockupstreamoidcidentityprovider.NewMockUpstreamOIDCIdentityProviderI(ctrl) } -type mockDiscovery struct{ provider *oidc.Provider } +// hasAccessTokenMatcher is a gomock.Matcher that expects an *oauth2.Token with a particular access token. +type hasAccessTokenMatcher struct{ expected string } -func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() } +func (m hasAccessTokenMatcher) Matches(arg interface{}) bool { + return arg.(*oauth2.Token).AccessToken == m.expected +} -func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() } +func (m hasAccessTokenMatcher) Got(got interface{}) string { + return got.(*oauth2.Token).AccessToken +} + +func (m hasAccessTokenMatcher) String() string { + return m.expected +} + +func HasAccessToken(expected string) gomock.Matcher { + return hasAccessTokenMatcher{expected: expected} +} From c8eaa3f383ff7fca98772f2f35cdd4ebc345ccdf Mon Sep 17 00:00:00 2001 From: Margo Crawford Date: Tue, 1 Dec 2020 11:01:23 -0800 Subject: [PATCH 32/57] WIP towards using k8s fosite storage in the supervisor's callback endpoint - Note that this WIP commit includes a failing unit test, which will be addressed in the next commit Signed-off-by: Ryan Richard --- internal/crud/crud_test.go | 6 +- .../authorizationcode/authorizationcode.go | 223 +++++++----------- .../authorizationcode_test.go | 115 ++++----- internal/fositestorage/pkce/pkce.go | 115 +++++++++ internal/fositestorage/pkce/pkce_test.go | 100 ++++++++ .../oidc/callback/callback_handler_test.go | 105 +++++---- internal/oidc/kube_storage.go | 110 +++++++++ internal/oidc/nullstorage.go | 30 +-- test/integration/storage_test.go | 4 +- 9 files changed, 548 insertions(+), 260 deletions(-) rename internal/{fosite => fositestorage}/authorizationcode/authorizationcode.go (57%) rename internal/{fosite => fositestorage}/authorizationcode/authorizationcode_test.go (63%) create mode 100644 internal/fositestorage/pkce/pkce.go create mode 100644 internal/fositestorage/pkce/pkce_test.go create mode 100644 internal/oidc/kube_storage.go diff --git a/internal/crud/crud_test.go b/internal/crud/crud_test.go index 93ee98182..cb0fc1476 100644 --- a/internal/crud/crud_test.go +++ b/internal/crud/crud_test.go @@ -62,17 +62,17 @@ func TestStorage(t *testing.T) { }{ { name: "get non-existent", - resource: "authorization-codes", + resource: "authcode", mocks: nil, run: func(t *testing.T, storage Storage) error { _, err := storage.Get(ctx, "not-exists", nil) return err }, wantActions: []coretesting.Action{ - coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-t2fx46yyvs3a"), + coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-t2fx46yyvs3a"), }, wantSecrets: nil, - wantErr: `failed to get authorization-codes for signature not-exists: secrets "pinniped-storage-authorization-codes-t2fx46yyvs3a" not found`, + wantErr: `failed to get authcode for signature not-exists: secrets "pinniped-storage-authcode-t2fx46yyvs3a" not found`, }, { name: "delete non-existent", diff --git a/internal/fosite/authorizationcode/authorizationcode.go b/internal/fositestorage/authorizationcode/authorizationcode.go similarity index 57% rename from internal/fosite/authorizationcode/authorizationcode.go rename to internal/fositestorage/authorizationcode/authorizationcode.go index 917c5cc0f..e41059b35 100644 --- a/internal/fosite/authorizationcode/authorizationcode.go +++ b/internal/fositestorage/authorizationcode/authorizationcode.go @@ -19,7 +19,7 @@ import ( ) const ( - ErrInvalidAuthorizeRequestType = constable.Error("authorization request must be of type fosite.AuthorizeRequest") + ErrInvalidAuthorizeRequestType = constable.Error("authorization request must be of type fosite.Request") ErrInvalidAuthorizeRequestData = constable.Error("authorization request data must not be nil") ErrInvalidAuthorizeRequestVersion = constable.Error("authorization request data has wrong version") @@ -33,26 +33,25 @@ type authorizeCodeStorage struct { } type AuthorizeCodeSession struct { - Active bool `json:"active"` - Request *fosite.AuthorizeRequest `json:"request"` - Version string `json:"version"` + Active bool `json:"active"` + Request *fosite.Request `json:"request"` + Version string `json:"version"` } func New(secrets corev1client.SecretInterface) oauth2.AuthorizeCodeStorage { - return &authorizeCodeStorage{storage: crud.New("authorization-codes", secrets)} + return &authorizeCodeStorage{storage: crud.New("authcode", secrets)} } func (a *authorizeCodeStorage) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) error { - // this conversion assumes that we do not wrap the default type in any way + // This conversion assumes that we do not wrap the default type in any way // i.e. we use the default fosite.OAuth2Provider.NewAuthorizeRequest implementation // note that because this type is serialized and stored in Kube, we cannot easily change the implementation later - // TODO hydra uses the fosite.Request struct and ignores the extra fields in fosite.AuthorizeRequest request, err := validateAndExtractAuthorizeRequest(requester) if err != nil { return err } - // TODO hydra stores specific fields from the requester + // Note, in case it is helpful, that Hydra stores specific fields from the requester: // request ID // requestedAt // OAuth client ID @@ -70,12 +69,11 @@ func (a *authorizeCodeStorage) CreateAuthorizeCodeSession(ctx context.Context, s } func (a *authorizeCodeStorage) GetAuthorizeCodeSession(ctx context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { - // TODO hydra uses the incoming fosite.Session to provide the type needed to json.Unmarshal their session bytes - - // TODO hydra gets the client from its DB as a concrete type via client ID, - // the hydra memory client just validates that the client ID exists - - // TODO hydra uses the sha512.Sum384 hash of signature when using JWT as access token to reduce length + // Note, in case it is helpful, that Hydra: + // - uses the incoming fosite.Session to provide the type needed to json.Unmarshal their session bytes + // - gets the client from its DB as a concrete type via client ID, the hydra memory client just validates that the + // client ID exists + // - hydra uses the sha512.Sum384 hash of signature when using JWT as access token to reduce length session, _, err := a.getSession(ctx, signature) @@ -88,8 +86,6 @@ func (a *authorizeCodeStorage) GetAuthorizeCodeSession(ctx context.Context, sign } func (a *authorizeCodeStorage) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) error { - // TODO write garbage collector for these codes - session, rv, err := a.getSession(ctx, signature) if err != nil { return err @@ -137,17 +133,15 @@ func (a *authorizeCodeStorage) getSession(ctx context.Context, signature string) func NewValidEmptyAuthorizeCodeSession() *AuthorizeCodeSession { return &AuthorizeCodeSession{ - Request: &fosite.AuthorizeRequest{ - Request: fosite.Request{ - Client: &fosite.DefaultOpenIDConnectClient{}, - Session: &openid.DefaultSession{}, - }, + Request: &fosite.Request{ + Client: &fosite.DefaultOpenIDConnectClient{}, + Session: &openid.DefaultSession{}, }, } } -func validateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.AuthorizeRequest, error) { - request, ok1 := requester.(*fosite.AuthorizeRequest) +func validateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.Request, error) { + request, ok1 := requester.(*fosite.Request) if !ok1 { return nil, ErrInvalidAuthorizeRequestType } @@ -189,59 +183,37 @@ func (e *errSerializationFailureWithCause) Error() string { const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{ "active": true, "request": { - "responseTypes": [ - "¥Îʒ襧.ɕ7崛瀇莒AȒ[ɠ牐7#$ɭ", - ".5ȿELj9ûF済(D疻翋膗", - "螤Yɫüeɯ紤邥翔勋\\RBʒ;-" - ], - "redirectUri": { - "Scheme": "ħesƻU赒M喦_ģ", - "Opaque": "Ġ/_章Ņ缘T蝟NJ儱礹燃ɢ", - "User": {}, - "Host": "ȳ4螘Wo", - "Path": "}i{", - "RawPath": "5Dža丝eF0eė鱊hǒx蔼Q", - "ForceQuery": true, - "RawQuery": "熤1bbWV", - "Fragment": "ȋc剠鏯ɽÿ¸", - "RawFragment": "qƤ" - }, - "state": "@n,x竘Şǥ嗾稀'ã击漰怼禝穞梠Ǫs", - "handledResponseTypes": [ - "m\"e尚鬞ƻɼ抹d誉y鿜Ķ" - ], - "id": "ō澩ć|3U2Ǜl霨ǦǵpƉ", - "requestedAt": "1989-11-05T22:02:31.105295894Z", + "id": "嫎l蟲aƖ啘艿", + "requestedAt": "2082-11-10T18:36:11.627253638Z", "client": { - "id": "[:c顎疻紵D", - "client_secret": "mQ==", + "id": "!ſɄĈp[述齛ʘUȻ.5ȿE", + "client_secret": "UQ==", "redirect_uris": [ - "恣S@T嵇LJV,Æ櫔袆鋹奘菲", - "ãƻʚ肈ą8O+a駣Ʉɼk瘸'鴵y" + "ǣ珑 ʑ飶畛Ȳ螤Yɫüeɯ紤邥翔勋\\", + "Bʒ;", + "鿃攴Ųęʍ鎾ʦ©cÏN,Ġ/_" ], "grant_types": [ - ".湆ê\"唐", - "曎餄FxD溪躲珫ÈşɜȨû臓嬣\"ǃŤz" + "憉sHĒ尥窘挼Ŀʼn" ], "response_types": [ - "Ņʘʟ車sʊ儓JǐŪɺǣy|耑ʄ" + "4", + "ʄÔ@}i{絧遗Ū^ȝĸ谋Vʋ鱴閇T" ], "scopes": [ - "Ą", - "萙Į(潶饏熞ĝƌĆ1", - "əȤ4Į筦p煖鵄$睱奐耡q" + "R鴝順諲ŮŚ节ȭŀȋc剠鏯ɽÿ¸" ], "audience": [ - "Ʃǣ鿫/Ò敫ƤV" + "Ƥ" ], "public": true, - "jwks_uri": "ȩđ[嬧鱒Ȁ彆媚杨嶒ĤG", + "jwks_uri": "BA瘪囷ɫCʄɢ雐譄uée'", "jwks": { "keys": [ { "kty": "OKP", "crv": "Ed25519", - "x": "JmA-6KpjzqKu0lq9OiB6ORL4s2UzBFPsE1hm6vESeXM", + "x": "nK9xgX_iN7u3u_i8YOO7ZRT_WK028Vd_nhtsUu7Eo6E", "x5u": { "Scheme": "", "Opaque": "", @@ -258,24 +230,7 @@ const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{ { "kty": "OKP", "crv": "Ed25519", - "x": "LbRC1_3HEe5o7Japk9jFp3_7Ou7Gi2gpqrVrIi0eLDQ", - "x5u": { - "Scheme": "", - "Opaque": "", - "User": null, - "Host": "", - "Path": "", - "RawPath": "", - "ForceQuery": false, - "RawQuery": "", - "Fragment": "", - "RawFragment": "" - } - }, - { - "kty": "OKP", - "crv": "Ed25519", - "x": "Ovk4DF8Yn3mkULuTqnlGJxFnKGu9EL6Xcf2Nql9lK3c", + "x": "UbbswQgzWhfGCRlwQmMp6fw_HoIoqkIaKT-2XN2fuYU", "x5u": { "Scheme": "", "Opaque": "", @@ -291,91 +246,95 @@ const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{ } ] }, - "token_endpoint_auth_method": "\u0026(K鵢Kj ŏ9Q韉Ķ%嶑輫ǘ(", + "token_endpoint_auth_method": "ŚǗƳȕ暭Q0ņP羾,塐", "request_uris": [ - ":", - "6ě#嫀^xz Ū胧r" + "lj翻LH^俤µDzɹ@©|\u003eɃ", + "[:c顎疻紵D" ], - "request_object_signing_alg": "^¡!犃ĹĐJí¿ō擫ų懫砰¿", - "token_endpoint_auth_signing_alg": "ƈŮå" + "request_object_signing_alg": "m1Ì恣S@T嵇LJV,Æ櫔袆鋹奘", + "token_endpoint_auth_signing_alg": "Fãƻʚ肈ą8O+a駣" }, "scopes": [ - "阃.Ù頀ʌGa皶竇瞍涘¹", - "ȽŮ切衖庀ŰŒ矠", - "楓)馻řĝǕ菸Tĕ1伞柲\u003c\"ʗȆ\\雤" + "ɼk瘸'鴵yſǮٱ\u003eFA曎餄FxD溪", + "綻N镪p赌h%桙dĽ" ], "grantedScopes": [ - "ơ鮫R嫁ɍUƞ9+u!Ȱ", - "}Ă岜" + "癗E]Ņʘʟ車s" ], "form": { - "旸Ť/Õ薝隧;綡,鼞纂=": [ - "[滮]憀", - "3\u003eÙœ蓄UK嗤眇疟Țƒ1v¸KĶ" + "蹬器ķ8ŷ萒寎廭#疶昄Ą-Ƃƞ轵": [ + "熞ĝƌĆ1ȇyǴ濎=Tʉȼʁŀ\u003c", + "耡q戨稞R÷mȵg釽[ƞ@", + "đ[嬧鱒Ȁ彆媚杨嶒ĤGÀ吧Lŷ" + ], + "餟": [ + "蒍z\u0026(K鵢Kj ŏ9Q韉Ķ%", + "輫ǘ(¨Ƞ亱6ě#嫀^xz ", + "@耢ɝ^¡!犃ĹĐJí¿ō擫" ] }, "session": { "Claims": { - "JTI": "};Ų斻遟a衪荖舃", - "Issuer": "芠顋敀拲h蝺$!", - "Subject": "}j%(=ſ氆]垲莲顇", + "JTI": "懫砰¿C筽娴ƓaPu镈賆ŗɰ", + "Issuer": "皶竇瞍涘¹焕iǢǽɽĺŧ", + "Subject": "矠M6ɡǜg炾ʙ$%o6肿Ȫ", "Audience": [ - "彑V\\廳蟕Țǡ蔯ʠ浵Ī龉磈螖畭5", - "渇Ȯʕc" + "ƌÙ鯆GQơ鮫R嫁ɍUƞ9+u!Ȱ踾$" ], - "Nonce": "Ǖ=rlƆ褡{ǏS", - "ExpiresAt": "1975-11-17T14:21:34.205609651Z", - "IssuedAt": "2104-07-03T15:40:03.66710966Z", - "RequestedAt": "2031-05-18T05:14:19.449350555Z", - "AuthTime": "2018-01-27T07:55:06.056862114Z", - "AccessTokenHash": "鹰肁躧", - "AuthenticationContextClassReference": "}Ɇ", - "AuthenticationMethodsReference": "DQh:uȣ", - "CodeHash": "ɘȏıȒ諃龟", + "Nonce": "us旸Ť/Õ薝隧;綡,鼞", + "ExpiresAt": "2065-11-30T13:47:03.613000626Z", + "IssuedAt": "1976-02-22T09:57:20.479850437Z", + "RequestedAt": "2016-04-13T04:18:53.648949323Z", + "AuthTime": "2098-07-12T04:38:54.034043015Z", + "AccessTokenHash": "滮]", + "AuthenticationContextClassReference": "°3\u003eÙ", + "AuthenticationMethodsReference": "k?µ鱔ǤÂ", + "CodeHash": "Țƒ1v¸KĶ跭};", "Extra": { - "a": { - "^i臏f恡ƨ彮": { - "DĘ敨ýÏʥZq7烱藌\\": null, - "V": { - "őŧQĝ微'X焌襱ǭɕņ殥!_n": false - } - }, - "Ż猁": [ - 1706822246 - ] + "=ſ氆": { + "Ƿī,廖ʡ彑V\\廳蟕Ț": [ + 843216989 + ], + "蔯ʠ浵Ī": { + "H\"nǕ=rlƆ褡{ǏSȳŅ": { + "Žg": false + }, + "枱鰧ɛ鸁A渇": null + } }, - "Ò椪)ɫqň2搞Ŀ高摠鲒鿮禗O": 1233332227 + "斻遟a衪荖舃9闄岈锘肺ńʥƕU}j%": 2520197933 } }, "Headers": { "Extra": { - "?戋璖$9\u0026": { - "µcɕ餦ÑEǰ哤癨浦浏1R": [ - 3761201123 - ], - "頓ć§蚲6rǦ\u003cqċ": { - "Łʀ§ȏœɽDz斡冭ȸěaʜD捛?½ʀ+": null, - "ɒúIJ誠ƉyÖ.峷1藍殙菥趏": { - "jHȬȆ#)\u003cX": true + "熒ɘȏıȒ諃龟ŴŠ'耐Ƭ扵ƹ玄ɕwL": { + "ýÏʥZq7烱藌\\捀¿őŧQ": { + "微'X焌襱ǭɕņ殥!_": null, + "荇届UȚ?戋璖$9\u00269舋": { + "ɕ餦ÑEǰ哤癨浦浏1Rk頓ć§蚲6": true } - } + }, + "鲒鿮禗O暒aJP鐜?ĮV嫎h譭ȉ]DĘ": [ + 954647573 + ] }, - "U": 1354158262 + "皩Ƭ}Ɇ.雬Ɨ´唁": 1572524915 } }, "ExpiresAt": { - "\"嘬ȹĹaó剺撱Ȱ": "1985-09-09T04:35:40.533197189Z", - "ʆ\u003e": "1998-08-07T05:37:11.759718906Z", - "柏ʒ鴙*鸆偡Ȓ肯Ûx": "2036-12-19T06:36:14.414805124Z" + "\u003cqċ譈8ŪɎP绿MÅ": "2031-10-18T22:07:34.950803105Z", + "ȸěaʜD捛?½ʀ+Ċ偢镳ʬÍɷȓ\u003c": "2049-05-13T15:27:20.968432454Z" }, - "Username": "qmʎaðƠ绗ʢ緦Hū", - "Subject": "屾Ê窢ɋ鄊qɠ谫ǯǵƕ牀1鞊\\ȹ)" + "Username": "1藍殙菥趏酱Nʎ\u0026^横懋ƶ峦Fïȫƅw", + "Subject": "檾ĩĆ爨4犹|v炩f柏ʒ鴙*鸆偡" }, "requestedAudience": [ - "鉍商OɄƣ圔,xĪɏV鵅砍" + "肯Ûx穞Ƀ", + "ź蕴3ǐ薝Ƅ腲=ʐ诂鱰屾Ê窢ɋ鄊qɠ谫" ], "grantedAudience": [ - "C笜嚯\u003cǐšɚĀĥʋ6鉅\\þc涎漄Ɨ腼" + "ǵƕ牀1鞊\\ȹ)}鉍商OɄƣ圔,xĪ", + "悾xn冏裻摼0Ʈ蚵Ȼ塕»£#稏扟X" ] }, "version": "1" diff --git a/internal/fosite/authorizationcode/authorizationcode_test.go b/internal/fositestorage/authorizationcode/authorizationcode_test.go similarity index 63% rename from internal/fosite/authorizationcode/authorizationcode_test.go rename to internal/fositestorage/authorizationcode/authorizationcode_test.go index d97f2b0c0..09e0e3742 100644 --- a/internal/fosite/authorizationcode/authorizationcode_test.go +++ b/internal/fositestorage/authorizationcode/authorizationcode_test.go @@ -55,56 +55,39 @@ func TestAuthorizeCodeStorage(t *testing.T) { name: "create, get, invalidate standard flow", mocks: nil, run: func(t *testing.T, storage oauth2.AuthorizeCodeStorage) error { - request := &fosite.AuthorizeRequest{ - ResponseTypes: fosite.Arguments{"not-code"}, - RedirectURI: &url.URL{ - Scheme: "", - Opaque: "weee", - User: &url.Userinfo{}, - Host: "", - Path: "/callback", - RawPath: "", - ForceQuery: false, - RawQuery: "", - Fragment: "", - RawFragment: "", - }, - State: "stated", - HandledResponseTypes: fosite.Arguments{"not-type"}, - Request: fosite.Request{ - ID: "abcd-1", - RequestedAt: time.Time{}, - Client: &fosite.DefaultOpenIDConnectClient{ - DefaultClient: &fosite.DefaultClient{ - ID: "pinny", - Secret: nil, - RedirectURIs: nil, - GrantTypes: nil, - ResponseTypes: nil, - Scopes: nil, - Audience: nil, - Public: true, - }, - JSONWebKeysURI: "where", - JSONWebKeys: nil, - TokenEndpointAuthMethod: "something", - RequestURIs: nil, - RequestObjectSigningAlgorithm: "", - TokenEndpointAuthSigningAlgorithm: "", + request := &fosite.Request{ + ID: "abcd-1", + RequestedAt: time.Time{}, + Client: &fosite.DefaultOpenIDConnectClient{ + DefaultClient: &fosite.DefaultClient{ + ID: "pinny", + Secret: nil, + RedirectURIs: nil, + GrantTypes: nil, + ResponseTypes: nil, + Scopes: nil, + Audience: nil, + Public: true, }, - RequestedScope: nil, - GrantedScope: nil, - Form: url.Values{"key": []string{"val"}}, - Session: &openid.DefaultSession{ - Claims: nil, - Headers: nil, - ExpiresAt: nil, - Username: "snorlax", - Subject: "panda", - }, - RequestedAudience: nil, - GrantedAudience: nil, + JSONWebKeysURI: "where", + JSONWebKeys: nil, + TokenEndpointAuthMethod: "something", + RequestURIs: nil, + RequestObjectSigningAlgorithm: "", + TokenEndpointAuthSigningAlgorithm: "", }, + RequestedScope: nil, + GrantedScope: nil, + Form: url.Values{"key": []string{"val"}}, + Session: &openid.DefaultSession{ + Claims: nil, + Headers: nil, + ExpiresAt: nil, + Username: "snorlax", + Subject: "panda", + }, + RequestedAudience: nil, + GrantedAudience: nil, } err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request) require.NoError(t, err) @@ -118,50 +101,50 @@ func TestAuthorizeCodeStorage(t *testing.T) { wantActions: []coretesting.Action{ coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ - Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4", + Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", ResourceVersion: "", Labels: map[string]string{ - "storage.pinniped.dev": "authorization-codes", + "storage.pinniped.dev": "authcode", }, }, Data: map[string][]byte{ - "pinniped-storage-data": []byte(`{"active":true,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), + "pinniped-storage-data": []byte(`{"active":true,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), "pinniped-storage-version": []byte("1"), }, - Type: "storage.pinniped.dev/authorization-codes", + Type: "storage.pinniped.dev/authcode", }), - coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4"), - coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4"), + coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), + coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), coretesting.NewUpdateAction(secretsGVR, namespace, &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ - Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4", + Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", ResourceVersion: "", Labels: map[string]string{ - "storage.pinniped.dev": "authorization-codes", + "storage.pinniped.dev": "authcode", }, }, Data: map[string][]byte{ - "pinniped-storage-data": []byte(`{"active":false,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), + "pinniped-storage-data": []byte(`{"active":false,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), "pinniped-storage-version": []byte("1"), }, - Type: "storage.pinniped.dev/authorization-codes", + Type: "storage.pinniped.dev/authcode", }), }, wantSecrets: []corev1.Secret{ { ObjectMeta: metav1.ObjectMeta{ - Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4", + Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", Namespace: namespace, ResourceVersion: "", Labels: map[string]string{ - "storage.pinniped.dev": "authorization-codes", + "storage.pinniped.dev": "authcode", }, }, Data: map[string][]byte{ - "pinniped-storage-data": []byte(`{"active":false,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), + "pinniped-storage-data": []byte(`{"active":false,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), "pinniped-storage-version": []byte("1"), }, - Type: "storage.pinniped.dev/authorization-codes", + Type: "storage.pinniped.dev/authcode", }, }, wantErr: "", @@ -210,8 +193,8 @@ func TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession(t *testing.T) { require.Equal(t, validSession.Request, extractedRequest) // checked above - defaultClient := validSession.Request.Request.Client.(*fosite.DefaultOpenIDConnectClient) - defaultSession := validSession.Request.Request.Session.(*openid.DefaultSession) + defaultClient := validSession.Request.Client.(*fosite.DefaultOpenIDConnectClient) + defaultSession := validSession.Request.Session.(*openid.DefaultSession) // makes it easier to use a raw string replacer := strings.NewReplacer("`", "a") @@ -225,10 +208,10 @@ func TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession(t *testing.T) { } } - // deterministic fuzzing of fosite.AuthorizeRequest + // deterministic fuzzing of fosite.Request f := fuzz.New().RandSource(rand.NewSource(1)).NilChance(0).NumElements(1, 3).Funcs( // these functions guarantee that these are the only interface types we need to fill out - // if fosite.AuthorizeRequest changes to add more, the fuzzer will panic + // if fosite.Request changes to add more, the fuzzer will panic func(fc *fosite.Client, c fuzz.Continue) { c.Fuzz(defaultClient) *fc = defaultClient diff --git a/internal/fositestorage/pkce/pkce.go b/internal/fositestorage/pkce/pkce.go new file mode 100644 index 000000000..153d93a2c --- /dev/null +++ b/internal/fositestorage/pkce/pkce.go @@ -0,0 +1,115 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package pkce + +import ( + "context" + "fmt" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "github.com/ory/fosite/handler/pkce" + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" + + "go.pinniped.dev/internal/constable" + "go.pinniped.dev/internal/crud" +) + +const ( + ErrInvalidPKCERequestType = constable.Error("requester must be of type fosite.Request") + + pkceStorageVersion = "1" +) + +var _ pkce.PKCERequestStorage = &pkceStorage{} + +type pkceStorage struct { + storage crud.Storage +} + +type session struct { + Request *fosite.Request `json:"request"` + Version string `json:"version"` +} + +func New(secrets corev1client.SecretInterface) pkce.PKCERequestStorage { + return &pkceStorage{storage: crud.New("pkce", secrets)} +} + +// TODO test what happens when we pass nil as the requester. +func (a *pkceStorage) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error { + request, err := validateAndExtractAuthorizeRequest(requester) + if err != nil { + return err + } + + _, err = a.storage.Create(ctx, signature, &session{Request: request, Version: pkceStorageVersion}) + return err +} + +func (a *pkceStorage) GetPKCERequestSession(ctx context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { + session, _, err := a.getSession(ctx, signature) + + if err != nil { + return nil, err + } + + return session.Request, err +} + +func (a *pkceStorage) DeletePKCERequestSession(ctx context.Context, signature string) error { + return a.storage.Delete(ctx, signature) +} + +func (a *pkceStorage) getSession(ctx context.Context, signature string) (*session, string, error) { + session := newValidEmptyPKCESession() + rv, err := a.storage.Get(ctx, signature, session) + + // TODO we do want this + // if errors.IsNotFound(err) { + // return nil, "", fosite.ErrNotFound.WithCause(err).WithDebug(err.Error()) + // } + + if err != nil { + return nil, "", fmt.Errorf("failed to get authorization code session for %s: %w", signature, err) + } + + // TODO we probably want this + // if version := session.Version; version != pkceStorageVersion { + // return nil, "", fmt.Errorf("%w: authorization code session for %s has version %s instead of %s", + // ErrInvalidAuthorizeRequestVersion, signature, version, pkceStorageVersion) + // } + + // TODO maybe we want this. it would only apply when a human has edited the secret. + // if session.Request == nil { + // return nil, "", fmt.Errorf("malformed authorization code session for %s: %w", signature, ErrInvalidAuthorizeRequestData) + // } + + return session, rv, nil +} + +func newValidEmptyPKCESession() *session { + return &session{ + Request: &fosite.Request{ + Client: &fosite.DefaultOpenIDConnectClient{}, + Session: &openid.DefaultSession{}, + }, + } +} + +func validateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.Request, error) { + request, ok1 := requester.(*fosite.Request) + if !ok1 { + return nil, ErrInvalidPKCERequestType + } + _, ok2 := request.Client.(*fosite.DefaultOpenIDConnectClient) + _, ok3 := request.Session.(*openid.DefaultSession) + + valid := ok2 && ok3 + if !valid { + return nil, ErrInvalidPKCERequestType + } + + return request, nil +} diff --git a/internal/fositestorage/pkce/pkce_test.go b/internal/fositestorage/pkce/pkce_test.go new file mode 100644 index 000000000..82124a9d8 --- /dev/null +++ b/internal/fositestorage/pkce/pkce_test.go @@ -0,0 +1,100 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package pkce + +import ( + "context" + "net/url" + "testing" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/kubernetes/fake" + coretesting "k8s.io/client-go/testing" +) + +func TestPKCEStorage(t *testing.T) { + ctx := context.Background() + secretsGVR := schema.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "secrets", + } + + const namespace = "test-ns" + + wantActions := []coretesting.Action{ + coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "pkce", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/pkce", + }), + coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-pkce-pwu5zs7lekbhnln2w4"), + coretesting.NewDeleteAction(secretsGVR, namespace, "pinniped-storage-pkce-pwu5zs7lekbhnln2w4"), + } + + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + ID: "abcd-1", + RequestedAt: time.Time{}, + Client: &fosite.DefaultOpenIDConnectClient{ + DefaultClient: &fosite.DefaultClient{ + ID: "pinny", + Secret: nil, + RedirectURIs: nil, + GrantTypes: nil, + ResponseTypes: nil, + Scopes: nil, + Audience: nil, + Public: true, + }, + JSONWebKeysURI: "where", + JSONWebKeys: nil, + TokenEndpointAuthMethod: "something", + RequestURIs: nil, + RequestObjectSigningAlgorithm: "", + TokenEndpointAuthSigningAlgorithm: "", + }, + RequestedScope: nil, + GrantedScope: nil, + Form: url.Values{"key": []string{"val"}}, + Session: &openid.DefaultSession{ + Claims: nil, + Headers: nil, + ExpiresAt: nil, + Username: "snorlax", + Subject: "panda", + }, + RequestedAudience: nil, + GrantedAudience: nil, + } + err := storage.CreatePKCERequestSession(ctx, "fancy-signature", request) + require.NoError(t, err) + + newRequest, err := storage.GetPKCERequestSession(ctx, "fancy-signature", nil) + require.NoError(t, err) + require.Equal(t, request, newRequest) + + err = storage.DeletePKCERequestSession(ctx, "fancy-signature") + require.NoError(t, err) + + require.Equal(t, wantActions, client.Actions()) +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 23656da2b..6c12394ac 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -17,8 +17,11 @@ import ( "github.com/gorilla/securecookie" "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" - "github.com/ory/fosite/storage" "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/kubernetes/fake" + kubetesting "k8s.io/client-go/testing" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/oidctestutil" @@ -419,14 +422,12 @@ func TestCallbackEndpoint(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - // Configure fosite the same way that the production code would, except use in-memory storage. + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets("some-namespace") + + // Configure fosite the same way that the production code would. // Inject this into our test subject at the last second so we get a fresh storage for every test. - oauthStore := &storage.MemoryStore{ - Clients: map[string]fosite.Client{oidc.PinnipedCLIOIDCClient().ID: oidc.PinnipedCLIOIDCClient()}, - AuthorizeCodes: map[string]storage.StoreAuthorizeCode{}, - PKCES: map[string]fosite.Requester{}, - IDSessions: map[string]fosite.Requester{}, - } + oauthStore := oidc.NewKubeStorage(secrets) hmacSecret := []byte("some secret - must have at least 32 bytes") require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret) @@ -471,6 +472,21 @@ func TestCallbackEndpoint(t *testing.T) { authcodeDataAndSignature := strings.Split(capturedAuthCode, ".") require.Len(t, authcodeDataAndSignature, 2) + // Several Secrets should have been created + expectedNumberOfCreatedSecrets := 2 + if test.wantGrantedOpenidScope { + expectedNumberOfCreatedSecrets++ + } + require.Len(t, client.Actions(), expectedNumberOfCreatedSecrets) + + // One authcode should have been stored. + actualAction := client.Actions()[0].(kubetesting.CreateActionImpl) + require.Equal(t, "create", actualAction.GetVerb()) + require.Equal(t, schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}, actualAction.GetResource()) + actualSecret := actualAction.GetObject().(*corev1.Secret) + require.True(t, strings.HasPrefix(actualSecret.Name, "pinniped-storage-authcode-")) + require.Empty(t, actualSecret.Namespace) // because the secrets client is already scoped to a namespace + storedRequestFromAuthcode, storedSessionFromAuthcode := validateAuthcodeStorage( t, oauthStore, @@ -481,6 +497,14 @@ func TestCallbackEndpoint(t *testing.T) { test.wantDownstreamRequestedScopes, ) + // One PKCE should have been stored. + actualAction = client.Actions()[1].(kubetesting.CreateActionImpl) + require.Equal(t, "create", actualAction.GetVerb()) + require.Equal(t, schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}, actualAction.GetResource()) + actualSecret = actualAction.GetObject().(*corev1.Secret) + require.True(t, strings.HasPrefix(actualSecret.Name, "pinniped-storage-pkce-")) + require.Empty(t, actualSecret.Namespace) // because the secrets client is already scoped to a namespace + validatePKCEStorage( t, oauthStore, @@ -491,15 +515,24 @@ func TestCallbackEndpoint(t *testing.T) { test.wantDownstreamPKCEChallengeMethod, ) - validateIDSessionStorage( - t, - oauthStore, - capturedAuthCode, // IDSession store key is full authcode - storedRequestFromAuthcode, - storedSessionFromAuthcode, - test.wantGrantedOpenidScope, - test.wantDownstreamNonce, - ) + // One IDSession should have been stored, if the downstream actually requested the "openid" scope + if test.wantGrantedOpenidScope { + actualAction = client.Actions()[2].(kubetesting.CreateActionImpl) + require.Equal(t, "create", actualAction.GetVerb()) + require.Equal(t, schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}, actualAction.GetResource()) + actualSecret = actualAction.GetObject().(*corev1.Secret) + require.True(t, strings.HasPrefix(actualSecret.Name, "pinniped-storage-idsession-")) + require.Empty(t, actualSecret.Namespace) // because the secrets client is already scoped to a namespace + + validateIDSessionStorage( + t, + oauthStore, + capturedAuthCode, // IDSession store key is full authcode + storedRequestFromAuthcode, + storedSessionFromAuthcode, + test.wantDownstreamNonce, + ) + } } }) } @@ -674,7 +707,7 @@ func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string func validateAuthcodeStorage( t *testing.T, - oauthStore *storage.MemoryStore, + oauthStore *oidc.KubeStorage, storeKey string, wantGrantedOpenidScope bool, wantDownstreamIDTokenSubject string, @@ -683,9 +716,6 @@ func validateAuthcodeStorage( ) (*fosite.Request, *openid.DefaultSession) { t.Helper() - // One authcode should have been stored. - require.Len(t, oauthStore.AuthorizeCodes, 1) - // Get the authcode session back from storage so we can require that it was stored correctly. storedAuthorizeRequestFromAuthcode, err := oauthStore.GetAuthorizeCodeSession(context.Background(), storeKey, nil) require.NoError(t, err) @@ -725,7 +755,7 @@ func validateAuthcodeStorage( require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject) if wantDownstreamIDTokenGroups != nil { require.Len(t, actualClaims.Extra, 1) - require.Equal(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"]) + require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"]) } else { require.Empty(t, actualClaims.Extra) require.NotContains(t, actualClaims.Extra, "groups") @@ -760,7 +790,7 @@ func validateAuthcodeStorage( func validatePKCEStorage( t *testing.T, - oauthStore *storage.MemoryStore, + oauthStore *oidc.KubeStorage, storeKey string, storedRequestFromAuthcode *fosite.Request, storedSessionFromAuthcode *openid.DefaultSession, @@ -768,8 +798,6 @@ func validatePKCEStorage( ) { t.Helper() - // One PKCE should have been stored. - require.Len(t, oauthStore.PKCES, 1) storedAuthorizeRequestFromPKCE, err := oauthStore.GetPKCERequestSession(context.Background(), storeKey, nil) require.NoError(t, err) @@ -787,33 +815,26 @@ func validatePKCEStorage( func validateIDSessionStorage( t *testing.T, - oauthStore *storage.MemoryStore, + oauthStore *oidc.KubeStorage, storeKey string, storedRequestFromAuthcode *fosite.Request, storedSessionFromAuthcode *openid.DefaultSession, - wantGrantedOpenidScope bool, wantDownstreamNonce string, ) { t.Helper() - // One IDSession should have been stored, if the downstream actually requested the "openid" scope.. - if wantGrantedOpenidScope { - require.Len(t, oauthStore.IDSessions, 1) - storedAuthorizeRequestFromIDSession, err := oauthStore.GetOpenIDConnectSession(context.Background(), storeKey, nil) - require.NoError(t, err) + storedAuthorizeRequestFromIDSession, err := oauthStore.GetOpenIDConnectSession(context.Background(), storeKey, nil) + require.NoError(t, err) - // Check that storage returned the expected concrete data types. - storedRequestFromIDSession, storedSessionFromIDSession := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromIDSession) + // Check that storage returned the expected concrete data types. + storedRequestFromIDSession, storedSessionFromIDSession := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromIDSession) - // The stored IDSession request should be the same as the stored authcode request. - require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromIDSession.ID) - require.Equal(t, storedSessionFromAuthcode, storedSessionFromIDSession) + // The stored IDSession request should be the same as the stored authcode request. + require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromIDSession.ID) + require.Equal(t, storedSessionFromAuthcode, storedSessionFromIDSession) - // The stored IDSession request should also contain the nonce that the downstream sent us. - require.Equal(t, wantDownstreamNonce, storedRequestFromIDSession.Form.Get("nonce")) - } else { - require.Len(t, oauthStore.IDSessions, 0) - } + // The stored IDSession request should also contain the nonce that the downstream sent us. + require.Equal(t, wantDownstreamNonce, storedRequestFromIDSession.Form.Get("nonce")) } func castStoredAuthorizeRequest(t *testing.T, storedAuthorizeRequest fosite.Requester) (*fosite.Request, *openid.DefaultSession) { diff --git a/internal/oidc/kube_storage.go b/internal/oidc/kube_storage.go new file mode 100644 index 000000000..664e382d1 --- /dev/null +++ b/internal/oidc/kube_storage.go @@ -0,0 +1,110 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package oidc + +import ( + "context" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/oauth2" + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" + + "go.pinniped.dev/internal/constable" + "go.pinniped.dev/internal/fositestorage/authorizationcode" +) + +const errKubeStorageNotImplemented = constable.Error("KubeStorage does not implement this method. It should not have been called.") + +type KubeStorage struct { + authorizationCodeStorage oauth2.AuthorizeCodeStorage +} + +func NewKubeStorage(secrets corev1client.SecretInterface) *KubeStorage { + return &KubeStorage{authorizationCodeStorage: authorizationcode.New(secrets)} +} + +func (KubeStorage) RevokeRefreshToken(_ context.Context, _ string) error { + return errKubeStorageNotImplemented +} + +func (KubeStorage) RevokeAccessToken(_ context.Context, _ string) error { + return errKubeStorageNotImplemented +} + +func (KubeStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) { + return nil +} + +func (KubeStorage) GetRefreshTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { + return nil, errKubeStorageNotImplemented +} + +func (KubeStorage) DeleteRefreshTokenSession(_ context.Context, _ string) (err error) { + return errKubeStorageNotImplemented +} + +func (KubeStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) { + return nil +} + +func (KubeStorage) GetAccessTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { + return nil, errKubeStorageNotImplemented +} + +func (KubeStorage) DeleteAccessTokenSession(_ context.Context, _ string) (err error) { + return errKubeStorageNotImplemented +} + +func (KubeStorage) CreateOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) error { + return nil +} + +func (KubeStorage) GetOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) (fosite.Requester, error) { + return nil, errKubeStorageNotImplemented +} + +func (KubeStorage) DeleteOpenIDConnectSession(_ context.Context, _ string) error { + return errKubeStorageNotImplemented +} + +func (KubeStorage) GetPKCERequestSession(_ context.Context, _ string, _ fosite.Session) (fosite.Requester, error) { + return nil, errKubeStorageNotImplemented +} + +func (KubeStorage) CreatePKCERequestSession(_ context.Context, _ string, _ fosite.Requester) error { + return nil +} + +func (KubeStorage) DeletePKCERequestSession(_ context.Context, _ string) error { + return errKubeStorageNotImplemented +} + +func (k KubeStorage) CreateAuthorizeCodeSession(ctx context.Context, signature string, r fosite.Requester) (err error) { + return k.authorizationCodeStorage.CreateAuthorizeCodeSession(ctx, signature, r) +} + +func (k KubeStorage) GetAuthorizeCodeSession(ctx context.Context, signature string, s fosite.Session) (request fosite.Requester, err error) { + return k.authorizationCodeStorage.GetAuthorizeCodeSession(ctx, signature, s) +} + +func (k KubeStorage) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) (err error) { + return k.authorizationCodeStorage.InvalidateAuthorizeCodeSession(ctx, signature) +} + +func (KubeStorage) GetClient(_ context.Context, id string) (fosite.Client, error) { + client := PinnipedCLIOIDCClient() + if client.ID == id { + return client, nil + } + return nil, fosite.ErrNotFound +} + +func (KubeStorage) ClientAssertionJWTValid(_ context.Context, _ string) error { + return errKubeStorageNotImplemented +} + +func (KubeStorage) SetClientAssertionJWT(_ context.Context, _ string, _ time.Time) error { + return errKubeStorageNotImplemented +} diff --git a/internal/oidc/nullstorage.go b/internal/oidc/nullstorage.go index 3767f8898..3dcd7a069 100644 --- a/internal/oidc/nullstorage.go +++ b/internal/oidc/nullstorage.go @@ -12,16 +12,16 @@ import ( "go.pinniped.dev/internal/constable" ) -const errNotImplemented = constable.Error("NullStorage does not implement this method. It should not have been called.") +const errNullStorageNotImplemented = constable.Error("NullStorage does not implement this method. It should not have been called.") type NullStorage struct{} func (NullStorage) RevokeRefreshToken(_ context.Context, _ string) error { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) RevokeAccessToken(_ context.Context, _ string) error { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) { @@ -29,11 +29,11 @@ func (NullStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosi } func (NullStorage) GetRefreshTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { - return nil, errNotImplemented + return nil, errNullStorageNotImplemented } func (NullStorage) DeleteRefreshTokenSession(_ context.Context, _ string) (err error) { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) { @@ -41,11 +41,11 @@ func (NullStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosit } func (NullStorage) GetAccessTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { - return nil, errNotImplemented + return nil, errNullStorageNotImplemented } func (NullStorage) DeleteAccessTokenSession(_ context.Context, _ string) (err error) { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) CreateOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) error { @@ -53,15 +53,15 @@ func (NullStorage) CreateOpenIDConnectSession(_ context.Context, _ string, _ fos } func (NullStorage) GetOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) (fosite.Requester, error) { - return nil, errNotImplemented + return nil, errNullStorageNotImplemented } func (NullStorage) DeleteOpenIDConnectSession(_ context.Context, _ string) error { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) GetPKCERequestSession(_ context.Context, _ string, _ fosite.Session) (fosite.Requester, error) { - return nil, errNotImplemented + return nil, errNullStorageNotImplemented } func (NullStorage) CreatePKCERequestSession(_ context.Context, _ string, _ fosite.Requester) error { @@ -69,7 +69,7 @@ func (NullStorage) CreatePKCERequestSession(_ context.Context, _ string, _ fosit } func (NullStorage) DeletePKCERequestSession(_ context.Context, _ string) error { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) CreateAuthorizeCodeSession(_ context.Context, _ string, _ fosite.Requester) (err error) { @@ -77,11 +77,11 @@ func (NullStorage) CreateAuthorizeCodeSession(_ context.Context, _ string, _ fos } func (NullStorage) GetAuthorizeCodeSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { - return nil, errNotImplemented + return nil, errNullStorageNotImplemented } func (NullStorage) InvalidateAuthorizeCodeSession(_ context.Context, _ string) (err error) { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) GetClient(_ context.Context, id string) (fosite.Client, error) { @@ -93,9 +93,9 @@ func (NullStorage) GetClient(_ context.Context, id string) (fosite.Client, error } func (NullStorage) ClientAssertionJWTValid(_ context.Context, _ string) error { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) SetClientAssertionJWT(_ context.Context, _ string, _ time.Time) error { - return errNotImplemented + return errNullStorageNotImplemented } diff --git a/test/integration/storage_test.go b/test/integration/storage_test.go index e2f3bdf2b..501099fe3 100644 --- a/test/integration/storage_test.go +++ b/test/integration/storage_test.go @@ -17,7 +17,7 @@ import ( "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "go.pinniped.dev/internal/fosite/authorizationcode" + "go.pinniped.dev/internal/fositestorage/authorizationcode" "go.pinniped.dev/test/library" ) @@ -29,7 +29,7 @@ func TestAuthorizeCodeStorage(t *testing.T) { // randomly generated HMAC authorization code (see below) code = "TQ72B8YjdEOZyxridYbTLE-pzoK4hpdkZxym5j4EmSc.TKRTgQG41IBQ16FDKTthRdhXfLlNaErcMd9Fy47uXAw" // name of the secret that will be created in Kube - name = "pinniped-storage-authorization-codes-jssfhaibxdkiaugxufbsso3bixmfo7fzjvuevxbr35c4xdxolqga" + name = "pinniped-storage-authcode-jssfhaibxdkiaugxufbsso3bixmfo7fzjvuevxbr35c4xdxolqga" ) hmac := compose.NewOAuth2HMACStrategy(&compose.Config{}, []byte("super-secret-32-byte-for-testing"), nil) From f38c150f6a64a1e6d4597ae4b475c69d08cc70b8 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Tue, 1 Dec 2020 14:53:22 -0800 Subject: [PATCH 33/57] Finished tests for pkce storage and added it to kubestorage - Also fixed some lint errors with v1.33.0 of the linter Signed-off-by: Margo Crawford --- go.mod | 1 + internal/certauthority/certauthority.go | 2 +- .../authorizationcode/authorizationcode.go | 20 +--- .../authorizationcode_test.go | 4 +- internal/fositestorage/fositestorage.go | 34 ++++++ internal/fositestorage/pkce/pkce.go | 53 +++------ internal/fositestorage/pkce/pkce_test.go | 103 +++++++++++++++++- internal/oidc/kube_storage.go | 22 ++-- internal/oidc/oidc.go | 2 +- pkg/oidcclient/login.go | 9 +- test/library/iotest.go | 1 - 11 files changed, 181 insertions(+), 70 deletions(-) create mode 100644 internal/fositestorage/fositestorage.go diff --git a/go.mod b/go.mod index 4277bd9bf..84237fd78 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/gorilla/securecookie v1.1.1 github.com/ory/fosite v0.35.1 github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 + github.com/pkg/errors v0.9.1 github.com/sclevine/agouti v3.0.0+incompatible github.com/sclevine/spec v1.4.0 github.com/spf13/cobra v1.0.0 diff --git a/internal/certauthority/certauthority.go b/internal/certauthority/certauthority.go index 13636db4c..6d3cff84c 100644 --- a/internal/certauthority/certauthority.go +++ b/internal/certauthority/certauthority.go @@ -194,7 +194,7 @@ func (c *CA) Issue(subject pkix.Name, dnsNames []string, ips []net.IP, ttl time. } // IssuePEM issues a new server certificate for the given identity and duration, returning it as a pair of -// PEM-formatted byte slices for the certificate and private key. +// PEM-formatted byte slices for the certificate and private key. func (c *CA) IssuePEM(subject pkix.Name, dnsNames []string, ttl time.Duration) ([]byte, []byte, error) { return toPEM(c.Issue(subject, dnsNames, nil, ttl)) } diff --git a/internal/fositestorage/authorizationcode/authorizationcode.go b/internal/fositestorage/authorizationcode/authorizationcode.go index e41059b35..0c522985b 100644 --- a/internal/fositestorage/authorizationcode/authorizationcode.go +++ b/internal/fositestorage/authorizationcode/authorizationcode.go @@ -16,10 +16,10 @@ import ( "go.pinniped.dev/internal/constable" "go.pinniped.dev/internal/crud" + "go.pinniped.dev/internal/fositestorage" ) const ( - ErrInvalidAuthorizeRequestType = constable.Error("authorization request must be of type fosite.Request") ErrInvalidAuthorizeRequestData = constable.Error("authorization request data must not be nil") ErrInvalidAuthorizeRequestVersion = constable.Error("authorization request data has wrong version") @@ -46,7 +46,7 @@ func (a *authorizeCodeStorage) CreateAuthorizeCodeSession(ctx context.Context, s // This conversion assumes that we do not wrap the default type in any way // i.e. we use the default fosite.OAuth2Provider.NewAuthorizeRequest implementation // note that because this type is serialized and stored in Kube, we cannot easily change the implementation later - request, err := validateAndExtractAuthorizeRequest(requester) + request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester) if err != nil { return err } @@ -140,22 +140,6 @@ func NewValidEmptyAuthorizeCodeSession() *AuthorizeCodeSession { } } -func validateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.Request, error) { - request, ok1 := requester.(*fosite.Request) - if !ok1 { - return nil, ErrInvalidAuthorizeRequestType - } - _, ok2 := request.Client.(*fosite.DefaultOpenIDConnectClient) - _, ok3 := request.Session.(*openid.DefaultSession) - - valid := ok2 && ok3 - if !valid { - return nil, ErrInvalidAuthorizeRequestType - } - - return request, nil -} - var _ interface { Is(error) bool Unwrap() error diff --git a/internal/fositestorage/authorizationcode/authorizationcode_test.go b/internal/fositestorage/authorizationcode/authorizationcode_test.go index 09e0e3742..9d03995a5 100644 --- a/internal/fositestorage/authorizationcode/authorizationcode_test.go +++ b/internal/fositestorage/authorizationcode/authorizationcode_test.go @@ -25,6 +25,8 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/kubernetes/fake" coretesting "k8s.io/client-go/testing" + + "go.pinniped.dev/internal/fositestorage" ) func TestAuthorizeCodeStorage(t *testing.T) { @@ -188,7 +190,7 @@ func TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession(t *testing.T) { validSession := NewValidEmptyAuthorizeCodeSession() // sanity check our valid session - extractedRequest, err := validateAndExtractAuthorizeRequest(validSession.Request) + extractedRequest, err := fositestorage.ValidateAndExtractAuthorizeRequest(validSession.Request) require.NoError(t, err) require.Equal(t, validSession.Request, extractedRequest) diff --git a/internal/fositestorage/fositestorage.go b/internal/fositestorage/fositestorage.go new file mode 100644 index 000000000..d23c9f6a6 --- /dev/null +++ b/internal/fositestorage/fositestorage.go @@ -0,0 +1,34 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package fositestorage + +import ( + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + + "go.pinniped.dev/internal/constable" +) + +const ( + ErrInvalidRequestType = constable.Error("requester must be of type fosite.Request") + ErrInvalidClientType = constable.Error("requester's client must be of type fosite.DefaultOpenIDConnectClient") + ErrInvalidSessionType = constable.Error("requester's session must be of type openid.DefaultSession") +) + +func ValidateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.Request, error) { + request, ok1 := requester.(*fosite.Request) + if !ok1 { + return nil, ErrInvalidRequestType + } + _, ok2 := request.Client.(*fosite.DefaultOpenIDConnectClient) + if !ok2 { + return nil, ErrInvalidClientType + } + _, ok3 := request.Session.(*openid.DefaultSession) + if !ok3 { + return nil, ErrInvalidSessionType + } + + return request, nil +} diff --git a/internal/fositestorage/pkce/pkce.go b/internal/fositestorage/pkce/pkce.go index 153d93a2c..9e8ef3d58 100644 --- a/internal/fositestorage/pkce/pkce.go +++ b/internal/fositestorage/pkce/pkce.go @@ -10,14 +10,17 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/handler/pkce" + "k8s.io/apimachinery/pkg/api/errors" corev1client "k8s.io/client-go/kubernetes/typed/core/v1" "go.pinniped.dev/internal/constable" "go.pinniped.dev/internal/crud" + "go.pinniped.dev/internal/fositestorage" ) const ( - ErrInvalidPKCERequestType = constable.Error("requester must be of type fosite.Request") + ErrInvalidPKCERequestVersion = constable.Error("pkce request data has wrong version") + ErrInvalidPKCERequestData = constable.Error("pkce request data must be present") pkceStorageVersion = "1" ) @@ -37,9 +40,8 @@ func New(secrets corev1client.SecretInterface) pkce.PKCERequestStorage { return &pkceStorage{storage: crud.New("pkce", secrets)} } -// TODO test what happens when we pass nil as the requester. func (a *pkceStorage) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error { - request, err := validateAndExtractAuthorizeRequest(requester) + request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester) if err != nil { return err } @@ -66,25 +68,22 @@ func (a *pkceStorage) getSession(ctx context.Context, signature string) (*sessio session := newValidEmptyPKCESession() rv, err := a.storage.Get(ctx, signature, session) - // TODO we do want this - // if errors.IsNotFound(err) { - // return nil, "", fosite.ErrNotFound.WithCause(err).WithDebug(err.Error()) - // } - - if err != nil { - return nil, "", fmt.Errorf("failed to get authorization code session for %s: %w", signature, err) + if errors.IsNotFound(err) { + return nil, "", fosite.ErrNotFound.WithCause(err).WithDebug(err.Error()) } - // TODO we probably want this - // if version := session.Version; version != pkceStorageVersion { - // return nil, "", fmt.Errorf("%w: authorization code session for %s has version %s instead of %s", - // ErrInvalidAuthorizeRequestVersion, signature, version, pkceStorageVersion) - // } + if err != nil { + return nil, "", fmt.Errorf("failed to get pkce session for %s: %w", signature, err) + } - // TODO maybe we want this. it would only apply when a human has edited the secret. - // if session.Request == nil { - // return nil, "", fmt.Errorf("malformed authorization code session for %s: %w", signature, ErrInvalidAuthorizeRequestData) - // } + if version := session.Version; version != pkceStorageVersion { + return nil, "", fmt.Errorf("%w: pkce session for %s has version %s instead of %s", + ErrInvalidPKCERequestVersion, signature, version, pkceStorageVersion) + } + + if session.Request.ID == "" { + return nil, "", fmt.Errorf("malformed pkce session for %s: %w", signature, ErrInvalidPKCERequestData) + } return session, rv, nil } @@ -97,19 +96,3 @@ func newValidEmptyPKCESession() *session { }, } } - -func validateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.Request, error) { - request, ok1 := requester.(*fosite.Request) - if !ok1 { - return nil, ErrInvalidPKCERequestType - } - _, ok2 := request.Client.(*fosite.DefaultOpenIDConnectClient) - _, ok3 := request.Session.(*openid.DefaultSession) - - valid := ok2 && ok3 - if !valid { - return nil, ErrInvalidPKCERequestType - } - - return request, nil -} diff --git a/internal/fositestorage/pkce/pkce_test.go b/internal/fositestorage/pkce/pkce_test.go index 82124a9d8..80b2d9ddf 100644 --- a/internal/fositestorage/pkce/pkce_test.go +++ b/internal/fositestorage/pkce/pkce_test.go @@ -11,6 +11,7 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" + "github.com/pkg/errors" "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -19,6 +20,8 @@ import ( coretesting "k8s.io/client-go/testing" ) +const namespace = "test-ns" + func TestPKCEStorage(t *testing.T) { ctx := context.Background() secretsGVR := schema.GroupVersionResource{ @@ -27,8 +30,6 @@ func TestPKCEStorage(t *testing.T) { Resource: "secrets", } - const namespace = "test-ns" - wantActions := []coretesting.Action{ coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ @@ -98,3 +99,101 @@ func TestPKCEStorage(t *testing.T) { require.Equal(t, wantActions, client.Actions()) } + +func TestGetNotFound(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + _, notFoundErr := storage.GetPKCERequestSession(ctx, "non-existent-signature", nil) + require.EqualError(t, notFoundErr, "not_found") + require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound)) +} + +func TestWrongVersion(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "pkce", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/pkce", + } + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetPKCERequestSession(ctx, "fancy-signature", nil) + + require.EqualError(t, err, "pkce request data has wrong version: pkce session for fancy-signature has version not-the-right-version instead of 1") +} + +func TestNilSessionRequest(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "pkce", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value","version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/pkce", + } + + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetPKCERequestSession(ctx, "fancy-signature", nil) + require.EqualError(t, err, "malformed pkce session for fancy-signature: pkce request data must be present") +} + +func TestCreateWithNilRequester(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + err := storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", nil) + require.EqualError(t, err, "requester must be of type fosite.Request") +} + +func TestCreateWithWrongRequesterDataTypes(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + Session: nil, + Client: &fosite.DefaultOpenIDConnectClient{}, + } + err := storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", request) + require.EqualError(t, err, "requester's session must be of type openid.DefaultSession") + + request = &fosite.Request{ + Session: &openid.DefaultSession{}, + Client: nil, + } + err = storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", request) + require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient") +} diff --git a/internal/oidc/kube_storage.go b/internal/oidc/kube_storage.go index 664e382d1..388cd7719 100644 --- a/internal/oidc/kube_storage.go +++ b/internal/oidc/kube_storage.go @@ -7,6 +7,10 @@ import ( "context" "time" + fositepkce "github.com/ory/fosite/handler/pkce" + + "go.pinniped.dev/internal/fositestorage/pkce" + "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" corev1client "k8s.io/client-go/kubernetes/typed/core/v1" @@ -19,10 +23,14 @@ const errKubeStorageNotImplemented = constable.Error("KubeStorage does not imple type KubeStorage struct { authorizationCodeStorage oauth2.AuthorizeCodeStorage + pkceStorage fositepkce.PKCERequestStorage } func NewKubeStorage(secrets corev1client.SecretInterface) *KubeStorage { - return &KubeStorage{authorizationCodeStorage: authorizationcode.New(secrets)} + return &KubeStorage{ + authorizationCodeStorage: authorizationcode.New(secrets), + pkceStorage: pkce.New(secrets), + } } func (KubeStorage) RevokeRefreshToken(_ context.Context, _ string) error { @@ -69,16 +77,16 @@ func (KubeStorage) DeleteOpenIDConnectSession(_ context.Context, _ string) error return errKubeStorageNotImplemented } -func (KubeStorage) GetPKCERequestSession(_ context.Context, _ string, _ fosite.Session) (fosite.Requester, error) { - return nil, errKubeStorageNotImplemented +func (k KubeStorage) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { + return k.pkceStorage.GetPKCERequestSession(ctx, signature, session) } -func (KubeStorage) CreatePKCERequestSession(_ context.Context, _ string, _ fosite.Requester) error { - return nil +func (k KubeStorage) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error { + return k.pkceStorage.CreatePKCERequestSession(ctx, signature, requester) } -func (KubeStorage) DeletePKCERequestSession(_ context.Context, _ string) error { - return errKubeStorageNotImplemented +func (k KubeStorage) DeletePKCERequestSession(ctx context.Context, signature string) error { + return k.pkceStorage.DeletePKCERequestSession(ctx, signature) } func (k KubeStorage) CreateAuthorizeCodeSession(ctx context.Context, signature string, r fosite.Requester) (err error) { diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 1241444f6..b129714b1 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -35,7 +35,7 @@ const ( UpstreamStateParamEncodingName = "s" // CSRFCookieName is the name of the browser cookie which shall hold our CSRF value. - // The `__Host` prefix has a special meaning. See + // The `__Host` prefix has a special meaning. See: // https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes. CSRFCookieName = "__Host-pinniped-csrf" diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 2e286efa3..fbbe23a91 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -89,10 +89,11 @@ func WithContext(ctx context.Context) Option { // WithListenPort specifies a TCP listen port on localhost, which will be used for the redirect_uri and to handle the // authorization code callback. By default, a random high port will be chosen which requires the authorization server // to support wildcard port numbers as described by https://tools.ietf.org/html/rfc8252: -// The authorization server MUST allow any port to be specified at the -// time of the request for loopback IP redirect URIs, to accommodate -// clients that obtain an available ephemeral port from the operating -// system at the time of the request. +// +// The authorization server MUST allow any port to be specified at the +// time of the request for loopback IP redirect URIs, to accommodate +// clients that obtain an available ephemeral port from the operating +// system at the time of the request. func WithListenPort(port uint16) Option { return func(h *handlerState) error { h.listenAddr = fmt.Sprintf("localhost:%d", port) diff --git a/test/library/iotest.go b/test/library/iotest.go index dcb0e6959..daf2ed4e8 100644 --- a/test/library/iotest.go +++ b/test/library/iotest.go @@ -33,7 +33,6 @@ func (l *testlogReader) Read(p []byte) (n int, err error) { return } -//nolint: gochecknoglobals var tokenLike = regexp.MustCompile(`(?mi)[a-zA-Z0-9._-]{30,}|[a-zA-Z0-9]{20,}`) func maskTokens(in []byte) string { From d60c184424f60b751e317249551cfe33890470e7 Mon Sep 17 00:00:00 2001 From: Margo Crawford Date: Tue, 1 Dec 2020 17:18:32 -0800 Subject: [PATCH 34/57] Add pkce and openidconnect storage - Also refactor authorizationcode_test Signed-off-by: Ryan Richard --- .../authorizationcode/authorizationcode.go | 4 +- .../authorizationcode_test.go | 309 ++++++++++-------- .../openidconnect/openidconnect.go | 124 +++++++ .../openidconnect/openidconnect_test.go | 209 ++++++++++++ .../oidc/callback/callback_handler_test.go | 43 ++- internal/oidc/kube_storage.go | 22 +- test/library/iotest.go | 1 + 7 files changed, 546 insertions(+), 166 deletions(-) create mode 100644 internal/fositestorage/openidconnect/openidconnect.go create mode 100644 internal/fositestorage/openidconnect/openidconnect_test.go diff --git a/internal/fositestorage/authorizationcode/authorizationcode.go b/internal/fositestorage/authorizationcode/authorizationcode.go index 0c522985b..8aca618ad 100644 --- a/internal/fositestorage/authorizationcode/authorizationcode.go +++ b/internal/fositestorage/authorizationcode/authorizationcode.go @@ -20,7 +20,7 @@ import ( ) const ( - ErrInvalidAuthorizeRequestData = constable.Error("authorization request data must not be nil") + ErrInvalidAuthorizeRequestData = constable.Error("authorization request data must be present") ErrInvalidAuthorizeRequestVersion = constable.Error("authorization request data has wrong version") authorizeCodeStorageVersion = "1" @@ -119,7 +119,7 @@ func (a *authorizeCodeStorage) getSession(ctx context.Context, signature string) ErrInvalidAuthorizeRequestVersion, signature, version, authorizeCodeStorageVersion) } - if session.Request == nil { + if session.Request.ID == "" { return nil, "", fmt.Errorf("malformed authorization code session for %s: %w", signature, ErrInvalidAuthorizeRequestData) } diff --git a/internal/fositestorage/authorizationcode/authorizationcode_test.go b/internal/fositestorage/authorizationcode/authorizationcode_test.go index 9d03995a5..38f3e1a95 100644 --- a/internal/fositestorage/authorizationcode/authorizationcode_test.go +++ b/internal/fositestorage/authorizationcode/authorizationcode_test.go @@ -16,8 +16,8 @@ import ( fuzz "github.com/google/gofuzz" "github.com/ory/fosite" - "github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/handler/openid" + "github.com/pkg/errors" "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" corev1 "k8s.io/api/core/v1" @@ -29,7 +29,9 @@ import ( "go.pinniped.dev/internal/fositestorage" ) -func TestAuthorizeCodeStorage(t *testing.T) { +const namespace = "test-ns" + +func TestAuthorizationCodeStorage(t *testing.T) { ctx := context.Background() secretsGVR := schema.GroupVersionResource{ Group: "", @@ -37,151 +39,186 @@ func TestAuthorizeCodeStorage(t *testing.T) { Resource: "secrets", } - const namespace = "test-ns" - - type mocker interface { - AddReactor(verb, resource string, reaction coretesting.ReactionFunc) - PrependReactor(verb, resource string, reaction coretesting.ReactionFunc) - Tracker() coretesting.ObjectTracker - } - - tests := []struct { - name string - mocks func(*testing.T, mocker) - run func(*testing.T, oauth2.AuthorizeCodeStorage) error - wantActions []coretesting.Action - wantSecrets []corev1.Secret - wantErr string - }{ - { - name: "create, get, invalidate standard flow", - mocks: nil, - run: func(t *testing.T, storage oauth2.AuthorizeCodeStorage) error { - request := &fosite.Request{ - ID: "abcd-1", - RequestedAt: time.Time{}, - Client: &fosite.DefaultOpenIDConnectClient{ - DefaultClient: &fosite.DefaultClient{ - ID: "pinny", - Secret: nil, - RedirectURIs: nil, - GrantTypes: nil, - ResponseTypes: nil, - Scopes: nil, - Audience: nil, - Public: true, - }, - JSONWebKeysURI: "where", - JSONWebKeys: nil, - TokenEndpointAuthMethod: "something", - RequestURIs: nil, - RequestObjectSigningAlgorithm: "", - TokenEndpointAuthSigningAlgorithm: "", - }, - RequestedScope: nil, - GrantedScope: nil, - Form: url.Values{"key": []string{"val"}}, - Session: &openid.DefaultSession{ - Claims: nil, - Headers: nil, - ExpiresAt: nil, - Username: "snorlax", - Subject: "panda", - }, - RequestedAudience: nil, - GrantedAudience: nil, - } - err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request) - require.NoError(t, err) - - newRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil) - require.NoError(t, err) - require.Equal(t, request, newRequest) - - return storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature") - }, - wantActions: []coretesting.Action{ - coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", - ResourceVersion: "", - Labels: map[string]string{ - "storage.pinniped.dev": "authcode", - }, - }, - Data: map[string][]byte{ - "pinniped-storage-data": []byte(`{"active":true,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), - "pinniped-storage-version": []byte("1"), - }, - Type: "storage.pinniped.dev/authcode", - }), - coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), - coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), - coretesting.NewUpdateAction(secretsGVR, namespace, &corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", - ResourceVersion: "", - Labels: map[string]string{ - "storage.pinniped.dev": "authcode", - }, - }, - Data: map[string][]byte{ - "pinniped-storage-data": []byte(`{"active":false,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), - "pinniped-storage-version": []byte("1"), - }, - Type: "storage.pinniped.dev/authcode", - }), - }, - wantSecrets: []corev1.Secret{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", - Namespace: namespace, - ResourceVersion: "", - Labels: map[string]string{ - "storage.pinniped.dev": "authcode", - }, - }, - Data: map[string][]byte{ - "pinniped-storage-data": []byte(`{"active":false,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), - "pinniped-storage-version": []byte("1"), - }, - Type: "storage.pinniped.dev/authcode", + wantActions := []coretesting.Action{ + coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "authcode", }, }, - wantErr: "", + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"active":true,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/authcode", + }), + coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), + coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), + coretesting.NewUpdateAction(secretsGVR, namespace, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "authcode", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"active":false,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/authcode", + }), + } + + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + ID: "abcd-1", + RequestedAt: time.Time{}, + Client: &fosite.DefaultOpenIDConnectClient{ + DefaultClient: &fosite.DefaultClient{ + ID: "pinny", + Secret: nil, + RedirectURIs: nil, + GrantTypes: nil, + ResponseTypes: nil, + Scopes: nil, + Audience: nil, + Public: true, + }, + JSONWebKeysURI: "where", + JSONWebKeys: nil, + TokenEndpointAuthMethod: "something", + RequestURIs: nil, + RequestObjectSigningAlgorithm: "", + TokenEndpointAuthSigningAlgorithm: "", }, + RequestedScope: nil, + GrantedScope: nil, + Form: url.Values{"key": []string{"val"}}, + Session: &openid.DefaultSession{ + Claims: nil, + Headers: nil, + ExpiresAt: nil, + Username: "snorlax", + Subject: "panda", + }, + RequestedAudience: nil, + GrantedAudience: nil, } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() + err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request) + require.NoError(t, err) - client := fake.NewSimpleClientset() - if tt.mocks != nil { - tt.mocks(t, client) - } - secrets := client.CoreV1().Secrets(namespace) - storage := New(secrets) + newRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil) + require.NoError(t, err) + require.Equal(t, request, newRequest) - err := tt.run(t, storage) + err = storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature") + require.NoError(t, err) - require.Equal(t, tt.wantErr, errString(err)) - require.Equal(t, tt.wantActions, client.Actions()) - - actualSecrets, err := secrets.List(ctx, metav1.ListOptions{}) - require.NoError(t, err) - require.Equal(t, tt.wantSecrets, actualSecrets.Items) - }) - } + require.Equal(t, wantActions, client.Actions()) } -func errString(err error) string { - if err == nil { - return "" +func TestGetNotFound(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + _, notFoundErr := storage.GetAuthorizeCodeSession(ctx, "non-existent-signature", nil) + require.EqualError(t, notFoundErr, "not_found") + require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound)) +} + +func TestWrongVersion(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "authcode", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version", "active": true}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/authcode", + } + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil) + + require.EqualError(t, err, "authorization request data has wrong version: authorization code session for fancy-signature has version not-the-right-version instead of 1") +} + +func TestNilSessionRequest(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "authcode", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value", "version":"1", "active": true}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/authcode", } - return err.Error() + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil) + require.EqualError(t, err, "malformed authorization code session for fancy-signature: authorization request data must be present") +} + +func TestCreateWithNilRequester(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + err := storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", nil) + require.EqualError(t, err, "requester must be of type fosite.Request") +} + +func TestCreateWithWrongRequesterDataTypes(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + Session: nil, + Client: &fosite.DefaultOpenIDConnectClient{}, + } + err := storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", request) + require.EqualError(t, err, "requester's session must be of type openid.DefaultSession") + + request = &fosite.Request{ + Session: &openid.DefaultSession{}, + Client: nil, + } + err = storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", request) + require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient") } // TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession asserts that we can correctly round trip our authorize code session. diff --git a/internal/fositestorage/openidconnect/openidconnect.go b/internal/fositestorage/openidconnect/openidconnect.go new file mode 100644 index 000000000..797d21a81 --- /dev/null +++ b/internal/fositestorage/openidconnect/openidconnect.go @@ -0,0 +1,124 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package openidconnect + +import ( + "context" + "fmt" + "strings" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "k8s.io/apimachinery/pkg/api/errors" + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" + + "go.pinniped.dev/internal/constable" + "go.pinniped.dev/internal/crud" + "go.pinniped.dev/internal/fositestorage" +) + +const ( + ErrInvalidOIDCRequestVersion = constable.Error("oidc request data has wrong version") + ErrInvalidOIDCRequestData = constable.Error("oidc request data must be present") + ErrMalformedAuthorizationCode = constable.Error("malformed authorization code") + + oidcStorageVersion = "1" +) + +var _ openid.OpenIDConnectRequestStorage = &openIDConnectRequestStorage{} + +type openIDConnectRequestStorage struct { + storage crud.Storage +} + +type session struct { + Request *fosite.Request `json:"request"` + Version string `json:"version"` +} + +func New(secrets corev1client.SecretInterface) openid.OpenIDConnectRequestStorage { + return &openIDConnectRequestStorage{storage: crud.New("oidc", secrets)} +} + +func (a *openIDConnectRequestStorage) CreateOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) error { + signature, err := getSignature(authcode) + if err != nil { + return err + } + + request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester) + if err != nil { + return err + } + + _, err = a.storage.Create(ctx, signature, &session{Request: request, Version: oidcStorageVersion}) + return err +} + +func (a *openIDConnectRequestStorage) GetOpenIDConnectSession(ctx context.Context, authcode string, _ fosite.Requester) (fosite.Requester, error) { + signature, err := getSignature(authcode) + if err != nil { + return nil, err + } + + session, _, err := a.getSession(ctx, signature) + + if err != nil { + return nil, err + } + + return session.Request, err +} + +func (a *openIDConnectRequestStorage) DeleteOpenIDConnectSession(ctx context.Context, authcode string) error { + signature, err := getSignature(authcode) + if err != nil { + return err + } + + return a.storage.Delete(ctx, signature) +} + +func (a *openIDConnectRequestStorage) getSession(ctx context.Context, signature string) (*session, string, error) { + session := newValidEmptyOIDCSession() + rv, err := a.storage.Get(ctx, signature, session) + + if errors.IsNotFound(err) { + return nil, "", fosite.ErrNotFound.WithCause(err).WithDebug(err.Error()) + } + + if err != nil { + return nil, "", fmt.Errorf("failed to get oidc session for %s: %w", signature, err) + } + + if version := session.Version; version != oidcStorageVersion { + return nil, "", fmt.Errorf("%w: oidc session for %s has version %s instead of %s", + ErrInvalidOIDCRequestVersion, signature, version, oidcStorageVersion) + } + + if session.Request.ID == "" { + return nil, "", fmt.Errorf("malformed oidc session for %s: %w", signature, ErrInvalidOIDCRequestData) + } + + return session, rv, nil +} + +func newValidEmptyOIDCSession() *session { + return &session{ + Request: &fosite.Request{ + Client: &fosite.DefaultOpenIDConnectClient{}, + Session: &openid.DefaultSession{}, + }, + } +} + +func getSignature(authorizationCode string) (string, error) { + split := strings.Split(authorizationCode, ".") + + if len(split) != 2 { + return "", ErrMalformedAuthorizationCode + } + + return split[1], nil +} diff --git a/internal/fositestorage/openidconnect/openidconnect_test.go b/internal/fositestorage/openidconnect/openidconnect_test.go new file mode 100644 index 000000000..976828ed3 --- /dev/null +++ b/internal/fositestorage/openidconnect/openidconnect_test.go @@ -0,0 +1,209 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package openidconnect + +import ( + "context" + "net/url" + "testing" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/kubernetes/fake" + coretesting "k8s.io/client-go/testing" +) + +const namespace = "test-ns" + +func TestOpenIdConnectStorage(t *testing.T) { + ctx := context.Background() + secretsGVR := schema.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "secrets", + } + + wantActions := []coretesting.Action{ + coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "oidc", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/oidc", + }), + coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-oidc-pwu5zs7lekbhnln2w4"), + coretesting.NewDeleteAction(secretsGVR, namespace, "pinniped-storage-oidc-pwu5zs7lekbhnln2w4"), + } + + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + ID: "abcd-1", + RequestedAt: time.Time{}, + Client: &fosite.DefaultOpenIDConnectClient{ + DefaultClient: &fosite.DefaultClient{ + ID: "pinny", + Secret: nil, + RedirectURIs: nil, + GrantTypes: nil, + ResponseTypes: nil, + Scopes: nil, + Audience: nil, + Public: true, + }, + JSONWebKeysURI: "where", + JSONWebKeys: nil, + TokenEndpointAuthMethod: "something", + RequestURIs: nil, + RequestObjectSigningAlgorithm: "", + TokenEndpointAuthSigningAlgorithm: "", + }, + RequestedScope: nil, + GrantedScope: nil, + Form: url.Values{"key": []string{"val"}}, + Session: &openid.DefaultSession{ + Claims: nil, + Headers: nil, + ExpiresAt: nil, + Username: "snorlax", + Subject: "panda", + }, + RequestedAudience: nil, + GrantedAudience: nil, + } + err := storage.CreateOpenIDConnectSession(ctx, "fancy-code.fancy-signature", request) + require.NoError(t, err) + + newRequest, err := storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil) + require.NoError(t, err) + require.Equal(t, request, newRequest) + + err = storage.DeleteOpenIDConnectSession(ctx, "fancy-code.fancy-signature") + require.NoError(t, err) + + require.Equal(t, wantActions, client.Actions()) +} + +func TestGetNotFound(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + _, notFoundErr := storage.GetOpenIDConnectSession(ctx, "authcode.non-existent-signature", nil) + require.EqualError(t, notFoundErr, "not_found") + require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound)) +} + +func TestWrongVersion(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "oidc", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/oidc", + } + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil) + + require.EqualError(t, err, "oidc request data has wrong version: oidc session for fancy-signature has version not-the-right-version instead of 1") +} + +func TestNilSessionRequest(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "oidc", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value","version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/oidc", + } + + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil) + require.EqualError(t, err, "malformed oidc session for fancy-signature: oidc request data must be present") +} + +func TestCreateWithNilRequester(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + err := storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", nil) + require.EqualError(t, err, "requester must be of type fosite.Request") +} + +func TestCreateWithWrongRequesterDataTypes(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + Session: nil, + Client: &fosite.DefaultOpenIDConnectClient{}, + } + err := storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", request) + require.EqualError(t, err, "requester's session must be of type openid.DefaultSession") + + request = &fosite.Request{ + Session: &openid.DefaultSession{}, + Client: nil, + } + err = storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", request) + require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient") +} + +func TestAuthcodeHasNoDot(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + err := storage.CreateOpenIDConnectSession(ctx, "all-one-part", nil) + require.EqualError(t, err, "malformed authorization code") +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 6c12394ac..e73eb7594 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -479,13 +479,18 @@ func TestCallbackEndpoint(t *testing.T) { } require.Len(t, client.Actions(), expectedNumberOfCreatedSecrets) + actualSecretNames := []string{} + for i := range client.Actions() { + actualAction := client.Actions()[i].(kubetesting.CreateActionImpl) + require.Equal(t, "create", actualAction.GetVerb()) + require.Equal(t, schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}, actualAction.GetResource()) + actualSecret := actualAction.GetObject().(*corev1.Secret) + require.Empty(t, actualSecret.Namespace) // because the secrets client is already scoped to a namespace + actualSecretNames = append(actualSecretNames, actualSecret.Name) + } + // One authcode should have been stored. - actualAction := client.Actions()[0].(kubetesting.CreateActionImpl) - require.Equal(t, "create", actualAction.GetVerb()) - require.Equal(t, schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}, actualAction.GetResource()) - actualSecret := actualAction.GetObject().(*corev1.Secret) - require.True(t, strings.HasPrefix(actualSecret.Name, "pinniped-storage-authcode-")) - require.Empty(t, actualSecret.Namespace) // because the secrets client is already scoped to a namespace + requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-authcode-") storedRequestFromAuthcode, storedSessionFromAuthcode := validateAuthcodeStorage( t, @@ -498,12 +503,7 @@ func TestCallbackEndpoint(t *testing.T) { ) // One PKCE should have been stored. - actualAction = client.Actions()[1].(kubetesting.CreateActionImpl) - require.Equal(t, "create", actualAction.GetVerb()) - require.Equal(t, schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}, actualAction.GetResource()) - actualSecret = actualAction.GetObject().(*corev1.Secret) - require.True(t, strings.HasPrefix(actualSecret.Name, "pinniped-storage-pkce-")) - require.Empty(t, actualSecret.Namespace) // because the secrets client is already scoped to a namespace + requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-pkce-") validatePKCEStorage( t, @@ -517,12 +517,7 @@ func TestCallbackEndpoint(t *testing.T) { // One IDSession should have been stored, if the downstream actually requested the "openid" scope if test.wantGrantedOpenidScope { - actualAction = client.Actions()[2].(kubetesting.CreateActionImpl) - require.Equal(t, "create", actualAction.GetVerb()) - require.Equal(t, schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}, actualAction.GetResource()) - actualSecret = actualAction.GetObject().(*corev1.Secret) - require.True(t, strings.HasPrefix(actualSecret.Name, "pinniped-storage-idsession-")) - require.Empty(t, actualSecret.Namespace) // because the secrets client is already scoped to a namespace + requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-oidc") validateIDSessionStorage( t, @@ -847,3 +842,15 @@ func castStoredAuthorizeRequest(t *testing.T, storedAuthorizeRequest fosite.Requ return storedRequest, storedSession } + +func requireAnyStringHasPrefix(t *testing.T, stringList []string, prefix string) { + t.Helper() + + containsPrefix := false + for i := range stringList { + if strings.HasPrefix(stringList[i], prefix) { + containsPrefix = true + } + } + require.Truef(t, containsPrefix, "list %v did not contain any strings with prefix %s", stringList, prefix) +} diff --git a/internal/oidc/kube_storage.go b/internal/oidc/kube_storage.go index 388cd7719..405a6ade8 100644 --- a/internal/oidc/kube_storage.go +++ b/internal/oidc/kube_storage.go @@ -7,16 +7,16 @@ import ( "context" "time" - fositepkce "github.com/ory/fosite/handler/pkce" - - "go.pinniped.dev/internal/fositestorage/pkce" - "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" + "github.com/ory/fosite/handler/openid" + fositepkce "github.com/ory/fosite/handler/pkce" corev1client "k8s.io/client-go/kubernetes/typed/core/v1" "go.pinniped.dev/internal/constable" "go.pinniped.dev/internal/fositestorage/authorizationcode" + "go.pinniped.dev/internal/fositestorage/openidconnect" + "go.pinniped.dev/internal/fositestorage/pkce" ) const errKubeStorageNotImplemented = constable.Error("KubeStorage does not implement this method. It should not have been called.") @@ -24,12 +24,14 @@ const errKubeStorageNotImplemented = constable.Error("KubeStorage does not imple type KubeStorage struct { authorizationCodeStorage oauth2.AuthorizeCodeStorage pkceStorage fositepkce.PKCERequestStorage + oidcStorage openid.OpenIDConnectRequestStorage } func NewKubeStorage(secrets corev1client.SecretInterface) *KubeStorage { return &KubeStorage{ authorizationCodeStorage: authorizationcode.New(secrets), pkceStorage: pkce.New(secrets), + oidcStorage: openidconnect.New(secrets), } } @@ -65,16 +67,16 @@ func (KubeStorage) DeleteAccessTokenSession(_ context.Context, _ string) (err er return errKubeStorageNotImplemented } -func (KubeStorage) CreateOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) error { - return nil +func (k KubeStorage) CreateOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) error { + return k.oidcStorage.CreateOpenIDConnectSession(ctx, authcode, requester) } -func (KubeStorage) GetOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) (fosite.Requester, error) { - return nil, errKubeStorageNotImplemented +func (k KubeStorage) GetOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) (fosite.Requester, error) { + return k.oidcStorage.GetOpenIDConnectSession(ctx, authcode, requester) } -func (KubeStorage) DeleteOpenIDConnectSession(_ context.Context, _ string) error { - return errKubeStorageNotImplemented +func (k KubeStorage) DeleteOpenIDConnectSession(ctx context.Context, authcode string) error { + return k.oidcStorage.DeleteOpenIDConnectSession(ctx, authcode) } func (k KubeStorage) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { diff --git a/test/library/iotest.go b/test/library/iotest.go index daf2ed4e8..dcb0e6959 100644 --- a/test/library/iotest.go +++ b/test/library/iotest.go @@ -33,6 +33,7 @@ func (l *testlogReader) Read(p []byte) (n int, err error) { return } +//nolint: gochecknoglobals var tokenLike = regexp.MustCompile(`(?mi)[a-zA-Z0-9._-]{30,}|[a-zA-Z0-9]{20,}`) func maskTokens(in []byte) string { From c23c54f50033637396e51c7bc03748936e31c3ad Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Tue, 1 Dec 2020 17:01:22 -0600 Subject: [PATCH 35/57] Add an explicit `Path=/;` to our CSRF cookie, per the spec. > [...] a cookie named "__Host-cookie1" MUST contain a "Path" attribute with a value of "/". https://tools.ietf.org/html/draft-ietf-httpbis-cookie-prefixes-00#section-3.2 Signed-off-by: Matt Moyer --- internal/oidc/auth/auth_handler.go | 1 + internal/oidc/auth/auth_handler_test.go | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index 57bd548f3..f3a305f18 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -227,6 +227,7 @@ func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken HttpOnly: true, SameSite: http.SameSiteStrictMode, Secure: true, + Path: "/", }) return nil diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index b56ff262c..73176d082 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -301,7 +301,7 @@ func TestAuthorizationEndpoint(t *testing.T) { cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: happyGetRequestPath, - csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue, + csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue + " ", wantStatus: http.StatusFound, wantContentType: "text/html; charset=utf-8", wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue, "")), @@ -751,7 +751,7 @@ func TestAuthorizationEndpoint(t *testing.T) { if test.wantCSRFValueInCookieHeader != "" { require.Len(t, rsp.Header().Values("Set-Cookie"), 1) actualCookie := rsp.Header().Get("Set-Cookie") - regex := regexp.MustCompile("__Host-pinniped-csrf=([^;]+); HttpOnly; Secure; SameSite=Strict") + regex := regexp.MustCompile("__Host-pinniped-csrf=([^;]+); Path=/; HttpOnly; Secure; SameSite=Strict") submatches := regex.FindStringSubmatch(actualCookie) require.Len(t, submatches, 2) captured := submatches[1] From 4fe691de920a6a9aaff338704f06d61ce2c2a680 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Wed, 2 Dec 2020 10:27:20 -0600 Subject: [PATCH 36/57] Save an http.Client with each upstreamoidc.ProviderConfig object. This allows the token exchange request to be performed with the correct TLS configuration. We go to a bit of extra work to make sure the `http.Client` object is cached between reconcile operations so that connection pooling works as expected. Signed-off-by: Matt Moyer --- .../upstreamwatcher/upstreamwatcher.go | 34 ++++++++++++------- internal/upstreamoidc/upstreamoidc.go | 9 ++--- pkg/oidcclient/login.go | 6 ++-- pkg/oidcclient/login_test.go | 8 ++--- 4 files changed, 33 insertions(+), 24 deletions(-) diff --git a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go index 4955646b1..7faa4d9c2 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go @@ -70,15 +70,21 @@ type IDPCache interface { // lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration. type lruValidatorCache struct{ cache *cache.Expiring } -func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider { - if result, ok := c.cache.Get(c.cacheKey(spec)); ok { - return result.(*oidc.Provider) - } - return nil +type lruValidatorCacheEntry struct { + provider *oidc.Provider + client *http.Client } -func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) { - c.cache.Set(c.cacheKey(spec), provider, validatorCacheTTL) +func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client) { + if result, ok := c.cache.Get(c.cacheKey(spec)); ok { + entry := result.(*lruValidatorCacheEntry) + return entry.provider, entry.client + } + return nil, nil +} + +func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider, client *http.Client) { + c.cache.Set(c.cacheKey(spec), &lruValidatorCacheEntry{provider: provider, client: client}, validatorCacheTTL) } func (c *lruValidatorCache) cacheKey(spec *v1alpha1.UpstreamOIDCProviderSpec) interface{} { @@ -97,8 +103,8 @@ type controller struct { providers idpinformers.UpstreamOIDCProviderInformer secrets corev1informers.SecretInformer validatorCache interface { - getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider - putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) + getProvider(*v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client) + putProvider(*v1alpha1.UpstreamOIDCProviderSpec, *oidc.Provider, *http.Client) } } @@ -224,6 +230,7 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res // If everything is valid, update the result and set the condition to true. result.Config.ClientID = string(clientID) + result.Config.ClientSecret = string(clientSecret) return &v1alpha1.Condition{ Type: typeClientCredsValid, Status: v1alpha1.ConditionTrue, @@ -234,8 +241,8 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res // validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition. func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition { - // Get the provider (from cache if possible). - discoveredProvider := c.validatorCache.getProvider(&upstream.Spec) + // Get the provider and HTTP Client from cache if possible. + discoveredProvider, httpClient := c.validatorCache.getProvider(&upstream.Spec) // If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache. if discoveredProvider == nil { @@ -248,7 +255,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst Message: err.Error(), } } - httpClient := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} + httpClient = &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} discoveredProvider, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), upstream.Spec.Issuer) if err != nil { @@ -261,7 +268,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst } // Update the cache with the newly discovered value. - c.validatorCache.putProvider(&upstream.Spec, discoveredProvider) + c.validatorCache.putProvider(&upstream.Spec, discoveredProvider, httpClient) } // Parse out and validate the discovered authorize endpoint. @@ -286,6 +293,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst // If everything is valid, update the result and set the condition to true. result.Config.Endpoint = discoveredProvider.Endpoint() result.Provider = discoveredProvider + result.Client = httpClient return &v1alpha1.Condition{ Type: typeOIDCDiscoverySucceeded, Status: v1alpha1.ConditionTrue, diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 2957e9e30..4af7efdb7 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -20,8 +20,8 @@ import ( "go.pinniped.dev/pkg/oidcclient/pkce" ) -func New(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { - return &ProviderConfig{Config: config, Provider: provider} +func New(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { + return &ProviderConfig{Config: config, Provider: provider, Client: client} } // ProviderConfig holds the active configuration of an upstream OIDC provider. @@ -33,6 +33,7 @@ type ProviderConfig struct { Provider interface { Verifier(*oidc.Config) *oidc.IDTokenVerifier } + Client *http.Client } func (p *ProviderConfig) GetName() string { @@ -61,7 +62,7 @@ func (p *ProviderConfig) GetGroupsClaim() string { } func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { - tok, err := p.Config.Exchange(ctx, authcode, pkceCodeVerifier.Verifier()) + tok, err := p.Config.Exchange(oidc.ClientContext(ctx, p.Client), authcode, pkceCodeVerifier.Verifier()) if err != nil { return oidctypes.Token{}, nil, err } @@ -74,7 +75,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e if !hasIDTok { return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token") } - validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(ctx, idTok) + validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(oidc.ClientContext(ctx, p.Client), idTok) if err != nil { return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) } diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index fbbe23a91..2b21e0807 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -64,7 +64,7 @@ type handlerState struct { generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) openURL func(string) error - getProvider func(*oauth2.Config, *oidc.Provider) provider.UpstreamOIDCIdentityProviderI + getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI callbacks chan callbackResult } @@ -295,7 +295,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // some providers do not include one, so we skip the nonce validation here (but not other validations). - token, _, err := h.getProvider(h.oauth2Config, h.provider).ValidateToken(ctx, refreshed, "") + token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "") if err != nil { return nil, err } @@ -328,7 +328,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req // Exchange the authorization code for access, ID, and refresh tokens and perform required // validations on the returned ID token. - token, _, err := h.getProvider(h.oauth2Config, h.provider).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce) + token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce) if err != nil { return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) } diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 280dfd0ad..374d90e3f 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -238,7 +238,7 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). @@ -277,7 +277,7 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). @@ -522,7 +522,7 @@ func TestHandleAuthCodeCallback(t *testing.T) { wantHTTPStatus: http.StatusBadRequest, opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). @@ -538,7 +538,7 @@ func TestHandleAuthCodeCallback(t *testing.T) { query: "state=test-state&code=valid", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). From fde56164cd374a35ce5e977ebef0b922e60daa3f Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Wed, 2 Dec 2020 10:36:07 -0600 Subject: [PATCH 37/57] Add a `redirectURI` parameter to ExchangeAuthcodeAndValidateTokens() method. We missed this in the original interface specification, but the `grant_type=authorization_code` requires it, per RFC6749 (https://tools.ietf.org/html/rfc6749#section-4.1.3). Signed-off-by: Matt Moyer --- .../mockupstreamoidcidentityprovider.go | 8 ++++---- internal/oidc/callback/callback_handler.go | 2 ++ .../oidc/callback/callback_handler_test.go | 5 ++++- internal/oidc/oidctestutil/oidc.go | 3 +++ .../provider/dynamic_upstream_idp_provider.go | 1 + internal/oidc/provider/manager/manager.go | 19 +++++++++++++++++-- internal/upstreamoidc/upstreamoidc.go | 9 +++++++-- internal/upstreamoidc/upstreamoidc_test.go | 2 +- pkg/oidcclient/login.go | 9 ++++++++- pkg/oidcclient/login_test.go | 8 ++++++-- 10 files changed, 53 insertions(+), 13 deletions(-) diff --git a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go index e3887b827..93085f4bd 100644 --- a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go +++ b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go @@ -43,9 +43,9 @@ func (m *MockUpstreamOIDCIdentityProviderI) EXPECT() *MockUpstreamOIDCIdentityPr } // ExchangeAuthcodeAndValidateTokens mocks base method -func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { +func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce, arg4 string) (oidctypes.Token, map[string]interface{}, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(oidctypes.Token) ret1, _ := ret[1].(map[string]interface{}) ret2, _ := ret[2].(error) @@ -53,9 +53,9 @@ func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(ar } // ExchangeAuthcodeAndValidateTokens indicates an expected call of ExchangeAuthcodeAndValidateTokens -func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ExchangeAuthcodeAndValidateTokens(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ExchangeAuthcodeAndValidateTokens(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeAuthcodeAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ExchangeAuthcodeAndValidateTokens), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeAuthcodeAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ExchangeAuthcodeAndValidateTokens), arg0, arg1, arg2, arg3, arg4) } // GetAuthorizationURL mocks base method diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index f237726b6..4add765e6 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -42,6 +42,7 @@ func NewHandler( idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, stateDecoder, cookieDecoder oidc.Decoder, + redirectURI string, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { state, err := validateRequest(r, stateDecoder, cookieDecoder) @@ -77,6 +78,7 @@ func NewHandler( authcode(r), state.PKCECode, state.Nonce, + redirectURI, ) if err != nil { plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", upstreamIDPConfig.GetName()) diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index e73eb7594..ead11693c 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -43,6 +43,8 @@ const ( happyUpstreamAuthcode = "upstream-auth-code" + happyUpstreamRedirectURI = "https://example.com/callback" + happyDownstreamState = "some-downstream-state-with-at-least-32-bytes" happyDownstreamCSRF = "test-csrf" happyDownstreamPKCE = "test-pkce" @@ -105,6 +107,7 @@ func TestCallbackEndpoint(t *testing.T) { Authcode: happyUpstreamAuthcode, PKCECodeVerifier: pkce.Code(happyDownstreamPKCE), ExpectedIDTokenNonce: nonce.Nonce(happyDownstreamNonce), + RedirectURI: happyUpstreamRedirectURI, } // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it @@ -433,7 +436,7 @@ func TestCallbackEndpoint(t *testing.T) { oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret) idpListGetter := oidctestutil.NewIDPListGetter(&test.idp) - subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) + subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI) req := httptest.NewRequest(test.method, test.path, nil) if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) diff --git a/internal/oidc/oidctestutil/oidc.go b/internal/oidc/oidctestutil/oidc.go index eafd567f5..5b214e5c5 100644 --- a/internal/oidc/oidctestutil/oidc.go +++ b/internal/oidc/oidctestutil/oidc.go @@ -24,6 +24,7 @@ type ExchangeAuthcodeAndValidateTokenArgs struct { Authcode string PKCECodeVerifier pkce.Code ExpectedIDTokenNonce nonce.Nonce + RedirectURI string } type TestUpstreamOIDCIdentityProvider struct { @@ -73,6 +74,7 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, + redirectURI string, ) (oidctypes.Token, map[string]interface{}, error) { if u.exchangeAuthcodeAndValidateTokensArgs == nil { u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) @@ -83,6 +85,7 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( Authcode: authcode, PKCECodeVerifier: pkceCodeVerifier, ExpectedIDTokenNonce: expectedIDTokenNonce, + RedirectURI: redirectURI, }) return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce) } diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index 8ef1e5dbb..be25ffe81 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -42,6 +42,7 @@ type UpstreamOIDCIdentityProviderI interface { authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, + redirectURI string, ) (tokens oidctypes.Token, parsedIDTokenClaims map[string]interface{}, err error) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index 42c828c16..6bac2c600 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -83,10 +83,25 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { encoder.SetSerializer(securecookie.JSONEncoder{}) authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath - m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder) + m.providerHandlers[authURL] = auth.NewHandler( + incomingProvider.Issuer(), + m.idpListGetter, + oauthHelper, + csrftoken.Generate, + pkce.Generate, + nonce.Generate, + encoder, + encoder, + ) callbackURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.CallbackEndpointPath - m.providerHandlers[callbackURL] = callback.NewHandler(m.idpListGetter, oauthHelper, encoder, encoder) + m.providerHandlers[callbackURL] = callback.NewHandler( + m.idpListGetter, + oauthHelper, + encoder, + encoder, + incomingProvider.Issuer()+oidc.CallbackEndpointPath, + ) plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) } diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 4af7efdb7..a789cb85a 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -61,8 +61,13 @@ func (p *ProviderConfig) GetGroupsClaim() string { return p.GroupsClaim } -func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { - tok, err := p.Config.Exchange(oidc.ClientContext(ctx, p.Client), authcode, pkceCodeVerifier.Verifier()) +func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (oidctypes.Token, map[string]interface{}, error) { + tok, err := p.Config.Exchange( + oidc.ClientContext(ctx, p.Client), + authcode, + pkceCodeVerifier.Verifier(), + oauth2.SetAuthURLParam("redirect_uri", redirectURI), + ) if err != nil { return oidctypes.Token{}, nil, err } diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go index 3f3eed2e1..541d502fc 100644 --- a/internal/upstreamoidc/upstreamoidc_test.go +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -181,7 +181,7 @@ func TestProviderConfig(t *testing.T) { ctx := context.Background() - tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce) + tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce, "https://example.com/callback") if tt.wantErr != "" { require.EqualError(t, err, tt.wantErr) require.Equal(t, oidctypes.Token{}, tok) diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 2b21e0807..0df346228 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -328,7 +328,14 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req // Exchange the authorization code for access, ID, and refresh tokens and perform required // validations on the returned ID token. - token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce) + token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). + ExchangeAuthcodeAndValidateTokens( + r.Context(), + params.Get("code"), + h.pkce, + h.nonce, + h.oauth2Config.RedirectURL, + ) if err != nil { return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) } diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 374d90e3f..96d790ba1 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -488,6 +488,8 @@ func TestLogin(t *testing.T) { } func TestHandleAuthCodeCallback(t *testing.T) { + const testRedirectURI = "http://127.0.0.1:12324/callback" + tests := []struct { name string method string @@ -522,10 +524,11 @@ func TestHandleAuthCodeCallback(t *testing.T) { wantHTTPStatus: http.StatusBadRequest, opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). - ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error")) return mock } @@ -538,10 +541,11 @@ func TestHandleAuthCodeCallback(t *testing.T) { query: "state=test-state&code=valid", opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). - ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil) return mock } From fe0481c30432188b5d491dbcd52b011a60f4aee8 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Wed, 2 Dec 2020 10:47:01 -0600 Subject: [PATCH 38/57] In integration test env, deploy a ClusterIP service and register that with Dex. Signed-off-by: Matt Moyer --- hack/lib/tilt/Tiltfile | 1 + hack/prepare-for-integration-tests.sh | 3 ++- test/deploy/dex/dex.yaml | 3 +-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/hack/lib/tilt/Tiltfile b/hack/lib/tilt/Tiltfile index e657e9676..0c176e2c8 100644 --- a/hack/lib/tilt/Tiltfile +++ b/hack/lib/tilt/Tiltfile @@ -103,6 +103,7 @@ k8s_yaml(local([ '--data-value-yaml', 'service_http_nodeport_nodeport=31234', '--data-value-yaml', 'service_https_nodeport_port=443', '--data-value-yaml', 'service_https_nodeport_nodeport=31243', + '--data-value-yaml', 'service_https_clusterip_port=443', '--data-value-yaml', 'custom_labels={mySupervisorCustomLabelName: mySupervisorCustomLabelValue}', ])) # Tell tilt to watch all of those files for changes. diff --git a/hack/prepare-for-integration-tests.sh b/hack/prepare-for-integration-tests.sh index 11e1fbf8b..97bdcceb7 100755 --- a/hack/prepare-for-integration-tests.sh +++ b/hack/prepare-for-integration-tests.sh @@ -230,6 +230,7 @@ if ! tilt_mode; then --data-value-yaml 'service_http_nodeport_nodeport=31234' \ --data-value-yaml 'service_https_nodeport_port=443' \ --data-value-yaml 'service_https_nodeport_nodeport=31243' \ + --data-value-yaml 'service_https_clusterip_port=443' \ >"$manifest" kapp deploy --yes --app "$supervisor_app_name" --diff-changes --file "$manifest" @@ -302,7 +303,7 @@ export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_ISSUER=https://dex.dex.svc.cluster export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_ISSUER_CA_BUNDLE="${test_ca_bundle_pem}" export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CLIENT_ID=pinniped-supervisor export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CLIENT_SECRET=pinniped-supervisor-secret -export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CALLBACK_URL=https://127.0.0.1:12345/some/path/callback +export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CALLBACK_URL=https://pinniped-supervisor-clusterip.supervisor.svc.cluster.local/some/path/callback export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_USERNAME=pinny@example.com export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_PASSWORD=password diff --git a/test/deploy/dex/dex.yaml b/test/deploy/dex/dex.yaml index bd078f249..6a5ecfecc 100644 --- a/test/deploy/dex/dex.yaml +++ b/test/deploy/dex/dex.yaml @@ -28,8 +28,7 @@ staticClients: name: 'Pinniped Supervisor' secret: pinniped-supervisor-secret redirectURIs: - - #@ "http://127.0.0.1:" + str(data.values.ports.cli) + "/callback" - - #@ "http://[::1]:" + str(data.values.ports.cli) + "/callback" + - https://pinniped-supervisor-clusterip.supervisor.svc.cluster.local/some/path/callback enablePasswordDB: true staticPasswords: - username: "pinny" From 22953cdb7849201ced46a4b3c9a56e1cd4be865a Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Wed, 2 Dec 2020 14:33:07 -0600 Subject: [PATCH 39/57] Add a CA.Pool() method to ./internal/certauthority. This is convenient for at least one test and is simple enough to write and test. Signed-off-by: Matt Moyer --- internal/certauthority/certauthority.go | 7 +++++++ internal/certauthority/certauthority_test.go | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/internal/certauthority/certauthority.go b/internal/certauthority/certauthority.go index 6d3cff84c..87bdd7845 100644 --- a/internal/certauthority/certauthority.go +++ b/internal/certauthority/certauthority.go @@ -136,6 +136,13 @@ func (c *CA) Bundle() []byte { return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.caCertBytes}) } +// Pool returns the current CA signing bundle as a *x509.CertPool. +func (c *CA) Pool() *x509.CertPool { + pool := x509.NewCertPool() + pool.AppendCertsFromPEM(c.Bundle()) + return pool +} + // Issue a new server certificate for the given identity and duration. func (c *CA) Issue(subject pkix.Name, dnsNames []string, ips []net.IP, ttl time.Duration) (*tls.Certificate, error) { // Choose a random 128 bit serial number. diff --git a/internal/certauthority/certauthority_test.go b/internal/certauthority/certauthority_test.go index 10e74743c..4c1fdf8ed 100644 --- a/internal/certauthority/certauthority_test.go +++ b/internal/certauthority/certauthority_test.go @@ -182,6 +182,16 @@ func TestBundle(t *testing.T) { }) } +func TestPool(t *testing.T) { + t.Run("success", func(t *testing.T) { + ca, err := New(pkix.Name{CommonName: "test"}, 1*time.Hour) + require.NoError(t, err) + + got := ca.Pool() + require.Len(t, got.Subjects(), 1) + }) +} + type errSigner struct { pubkey crypto.PublicKey err error From 545c26e5fe1244ce78e5e5747e0aeec14927edce Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Wed, 2 Dec 2020 15:29:54 -0600 Subject: [PATCH 40/57] Refactor browser-related test functions to a `./test/library/browsertest` package. Signed-off-by: Matt Moyer --- test/integration/cli_test.go | 130 ++------------------ test/library/browsertest/browsertest.go | 150 ++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 122 deletions(-) create mode 100644 test/library/browsertest/browsertest.go diff --git a/test/integration/cli_test.go b/test/integration/cli_test.go index 48965f76c..bb6663e53 100644 --- a/test/integration/cli_test.go +++ b/test/integration/cli_test.go @@ -20,7 +20,6 @@ import ( "testing" "time" - "github.com/sclevine/agouti" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "gopkg.in/square/go-jose.v2" @@ -29,6 +28,7 @@ import ( "go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient/filesession" "go.pinniped.dev/test/library" + "go.pinniped.dev/test/library/browsertest" ) func TestCLIGetKubeconfig(t *testing.T) { @@ -107,80 +107,14 @@ func runPinnipedCLIGetKubeconfig(t *testing.T, pinnipedExe, token, namespaceName return string(output) } -type loginProviderPatterns struct { - Name string - IssuerPattern *regexp.Regexp - LoginPagePattern *regexp.Regexp - UsernameSelector string - PasswordSelector string - LoginButtonSelector string -} - -func getLoginProvider(t *testing.T) *loginProviderPatterns { - t.Helper() - issuer := library.IntegrationEnv(t).CLITestUpstream.Issuer - for _, p := range []loginProviderPatterns{ - { - Name: "Okta", - IssuerPattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`), - LoginPagePattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`), - UsernameSelector: "input#okta-signin-username", - PasswordSelector: "input#okta-signin-password", - LoginButtonSelector: "input#okta-signin-submit", - }, - { - Name: "Dex", - IssuerPattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex.*\z`), - LoginPagePattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex/auth/local.+\z`), - UsernameSelector: "input#login", - PasswordSelector: "input#password", - LoginButtonSelector: "button#submit-login", - }, - } { - if p.IssuerPattern.MatchString(issuer) { - return &p - } - } - require.Failf(t, "could not find login provider for issuer %q", issuer) - return nil -} - func TestCLILoginOIDC(t *testing.T) { env := library.IntegrationEnv(t) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - // Find the login CSS selectors for the test issuer, or fail fast. - loginProvider := getLoginProvider(t) - // Start the browser driver. - t.Logf("opening browser driver") - caps := agouti.NewCapabilities() - if env.Proxy != "" { - t.Logf("configuring Chrome to use proxy %q", env.Proxy) - caps = caps.Proxy(agouti.ProxyConfig{ - ProxyType: "manual", - HTTPProxy: env.Proxy, - SSLProxy: env.Proxy, - NoProxy: "127.0.0.1", - }) - } - agoutiDriver := agouti.ChromeDriver( - agouti.Desired(caps), - agouti.ChromeOptions("args", []string{ - "--no-sandbox", - "--ignore-certificate-errors", - "--headless", // Comment out this line to see the tests happen in a visible browser window. - }), - // Uncomment this to see stdout/stderr from chromedriver. - // agouti.Debug, - ) - require.NoError(t, agoutiDriver.Start()) - t.Cleanup(func() { require.NoError(t, agoutiDriver.Stop()) }) - page, err := agoutiDriver.NewPage(agouti.Browser("chrome")) - require.NoError(t, err) - require.NoError(t, page.Reset()) + page := browsertest.Open(t) // Build pinniped CLI. t.Logf("building CLI binary") @@ -261,28 +195,18 @@ func TestCLILoginOIDC(t *testing.T) { t.Logf("navigating to login page") require.NoError(t, page.Navigate(loginURL)) - // Expect to be redirected to the login page. - t.Logf("waiting for redirect to %s login page", loginProvider.Name) - waitForURL(t, page, loginProvider.LoginPagePattern) + // Expect to be redirected to the upstream provider and log in. + browsertest.LoginToUpstream(t, page, env.CLITestUpstream) - // Wait for the login page to be rendered. - waitForVisibleElements(t, page, loginProvider.UsernameSelector, loginProvider.PasswordSelector, loginProvider.LoginButtonSelector) - - // Fill in the username and password and click "submit". - t.Logf("logging into %s", loginProvider.Name) - require.NoError(t, page.First(loginProvider.UsernameSelector).Fill(env.CLITestUpstream.Username)) - require.NoError(t, page.First(loginProvider.PasswordSelector).Fill(env.CLITestUpstream.Password)) - require.NoError(t, page.First(loginProvider.LoginButtonSelector).Click()) - - // Wait for the login to happen and us be redirected back to a localhost callback. - t.Logf("waiting for redirect to localhost callback") + // Expect to be redirected to the localhost callback. + t.Logf("waiting for redirect to callback") callbackURLPattern := regexp.MustCompile(`\A` + regexp.QuoteMeta(env.CLITestUpstream.CallbackURL) + `\?.+\z`) - waitForURL(t, page, callbackURLPattern) + browsertest.WaitForURL(t, page, callbackURLPattern) // Wait for the "pre" element that gets rendered for a `text/plain` page, and // assert that it contains the success message. t.Logf("verifying success page") - waitForVisibleElements(t, page, "pre") + browsertest.WaitForVisibleElements(t, page, "pre") msg, err := page.First("pre").Text() require.NoError(t, err) require.Equal(t, "you have been logged in and may now close this tab", msg) @@ -360,44 +284,6 @@ func TestCLILoginOIDC(t *testing.T) { require.NotEqual(t, credOutput2.Status.Token, credOutput3.Status.Token) } -func waitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) { - t.Helper() - require.Eventually(t, - func() bool { - for _, sel := range selectors { - vis, err := page.First(sel).Visible() - if !(err == nil && vis) { - return false - } - } - return true - }, - 10*time.Second, - 100*time.Millisecond, - ) -} - -func waitForURL(t *testing.T, page *agouti.Page, pat *regexp.Regexp) { - var lastURL string - require.Eventuallyf(t, - func() bool { - url, err := page.URL() - if err == nil && pat.MatchString(url) { - return true - } - if url != lastURL { - t.Logf("saw URL %s", url) - lastURL = url - } - return false - }, - 10*time.Second, - 100*time.Millisecond, - "expected to browse to %s, but never got there", - pat, - ) -} - func readAndExpectEmpty(r io.Reader) (err error) { var remainder bytes.Buffer _, err = io.Copy(&remainder, r) diff --git a/test/library/browsertest/browsertest.go b/test/library/browsertest/browsertest.go new file mode 100644 index 000000000..9a50e7808 --- /dev/null +++ b/test/library/browsertest/browsertest.go @@ -0,0 +1,150 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package browsertest provides integration test helpers for our browser-based tests. +package browsertest + +import ( + "regexp" + "testing" + "time" + + "github.com/sclevine/agouti" + "github.com/stretchr/testify/require" + + "go.pinniped.dev/test/library" +) + +// Open a webdriver-driven browser and returns an *agouti.Page to control it. The browser will be automatically +// closed at the end of the current test. It is configured for test purposes with the correct HTTP proxy and +// in a mode that ignore certificate errors. +func Open(t *testing.T) *agouti.Page { + t.Logf("opening browser driver") + env := library.IntegrationEnv(t) + caps := agouti.NewCapabilities() + if env.Proxy != "" { + t.Logf("configuring Chrome to use proxy %q", env.Proxy) + caps = caps.Proxy(agouti.ProxyConfig{ + ProxyType: "manual", + HTTPProxy: env.Proxy, + SSLProxy: env.Proxy, + NoProxy: "127.0.0.1", + }) + } + agoutiDriver := agouti.ChromeDriver( + agouti.Desired(caps), + agouti.ChromeOptions("args", []string{ + "--no-sandbox", + "--ignore-certificate-errors", + "--headless", // Comment out this line to see the tests happen in a visible browser window. + }), + // Uncomment this to see stdout/stderr from chromedriver. + // agouti.Debug, + ) + require.NoError(t, agoutiDriver.Start()) + t.Cleanup(func() { require.NoError(t, agoutiDriver.Stop()) }) + page, err := agoutiDriver.NewPage(agouti.Browser("chrome")) + require.NoError(t, err) + require.NoError(t, page.Reset()) + return page +} + +// WaitForVisibleElements expects the page to contain all the the elements specified by the selectors. It waits for this +// to occur and times out, failing the test, if they never appear. +func WaitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) { + t.Helper() + require.Eventually(t, + func() bool { + for _, sel := range selectors { + vis, err := page.First(sel).Visible() + if !(err == nil && vis) { + return false + } + } + return true + }, + 10*time.Second, + 100*time.Millisecond, + ) +} + +// WaitForURL expects the page to eventually navigate to a URL matching the specified pattern. It waits for this +// to occur and times out, failing the test, if it never does. +func WaitForURL(t *testing.T, page *agouti.Page, pat *regexp.Regexp) { + var lastURL string + require.Eventuallyf(t, + func() bool { + url, err := page.URL() + if err == nil && pat.MatchString(url) { + return true + } + if url != lastURL { + t.Logf("saw URL %s", url) + lastURL = url + } + return false + }, + 10*time.Second, + 100*time.Millisecond, + "expected to browse to %s, but never got there", + pat, + ) +} + +// LoginToUpstream expects the page to be redirected to one of several known upstream IDPs. +// It knows how to enter the test username/password and submit the upstream login form. +func LoginToUpstream(t *testing.T, page *agouti.Page, upstream library.TestOIDCUpstream) { + t.Helper() + + type config struct { + Name string + IssuerPattern *regexp.Regexp + LoginPagePattern *regexp.Regexp + UsernameSelector string + PasswordSelector string + LoginButtonSelector string + } + + // Lookup the provider by matching on the issuer URL. + var cfg *config + for _, p := range []*config{ + { + Name: "Okta", + IssuerPattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`), + LoginPagePattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`), + UsernameSelector: "input#okta-signin-username", + PasswordSelector: "input#okta-signin-password", + LoginButtonSelector: "input#okta-signin-submit", + }, + { + Name: "Dex", + IssuerPattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex.*\z`), + LoginPagePattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex/auth/local.+\z`), + UsernameSelector: "input#login", + PasswordSelector: "input#password", + LoginButtonSelector: "button#submit-login", + }, + } { + if p.IssuerPattern.MatchString(upstream.Issuer) { + cfg = p + break + } + } + if cfg == nil { + require.Failf(t, "could not find login provider for issuer %q", upstream.Issuer) + return + } + + // Expect to be redirected to the login page. + t.Logf("waiting for redirect to %s login page", cfg.Name) + WaitForURL(t, page, cfg.LoginPagePattern) + + // Wait for the login page to be rendered. + WaitForVisibleElements(t, page, cfg.UsernameSelector, cfg.PasswordSelector, cfg.LoginButtonSelector) + + // Fill in the username and password and click "submit". + t.Logf("logging into %s", cfg.Name) + require.NoError(t, page.First(cfg.UsernameSelector).Fill(upstream.Username)) + require.NoError(t, page.First(cfg.PasswordSelector).Fill(upstream.Password)) + require.NoError(t, page.First(cfg.LoginButtonSelector).Click()) +} From 273ac62ec2660a78fb30ae2dd15b916d804863fc Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Wed, 2 Dec 2020 15:32:54 -0600 Subject: [PATCH 41/57] Extend the test client helpers in ./test/library/client.go. This adds a few new "create test object" helpers and extends `CreateTestOIDCProvider()` to optionally wait for the created OIDCProvider to enter some expected status condition. Signed-off-by: Matt Moyer --- test/integration/supervisor_discovery_test.go | 14 +- test/integration/supervisor_keys_test.go | 2 +- test/integration/supervisor_login_test.go | 4 +- test/integration/supervisor_upstream_test.go | 80 +----------- test/library/client.go | 121 +++++++++++++++--- 5 files changed, 114 insertions(+), 107 deletions(-) diff --git a/test/integration/supervisor_discovery_test.go b/test/integration/supervisor_discovery_test.go index 32c4c0046..b029cc174 100644 --- a/test/integration/supervisor_discovery_test.go +++ b/test/integration/supervisor_discovery_test.go @@ -111,7 +111,7 @@ func TestSupervisorOIDCDiscovery(t *testing.T) { // When the same issuer is added twice, both issuers are marked as duplicates, and neither provider is serving. config6Duplicate1, _ := requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear(ctx, t, scheme, addr, caBundle, issuer6, client) - config6Duplicate2 := library.CreateTestOIDCProvider(ctx, t, issuer6, "") + config6Duplicate2 := library.CreateTestOIDCProvider(ctx, t, issuer6, "", "") requireStatus(t, client, ns, config6Duplicate1.Name, v1alpha1.DuplicateOIDCProviderStatusCondition) requireStatus(t, client, ns, config6Duplicate2.Name, v1alpha1.DuplicateOIDCProviderStatusCondition) requireDiscoveryEndpointsAreNotFound(t, scheme, addr, caBundle, issuer6) @@ -136,7 +136,7 @@ func TestSupervisorOIDCDiscovery(t *testing.T) { } // When we create a provider with an invalid issuer, the status is set to invalid. - badConfig := library.CreateTestOIDCProvider(ctx, t, badIssuer, "") + badConfig := library.CreateTestOIDCProvider(ctx, t, badIssuer, "", "") requireStatus(t, client, ns, badConfig.Name, v1alpha1.InvalidOIDCProviderStatusCondition) requireDiscoveryEndpointsAreNotFound(t, scheme, addr, caBundle, badIssuer) requireDeletingOIDCProviderCausesDiscoveryEndpointsToDisappear(t, badConfig, client, ns, scheme, addr, caBundle, badIssuer) @@ -162,7 +162,7 @@ func TestSupervisorTLSTerminationWithSNI(t *testing.T) { certSecretName1 := "integration-test-cert-1" // Create an OIDCProvider with a spec.tls.secretName. - oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuer1, certSecretName1) + oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuer1, certSecretName1, "") requireStatus(t, pinnipedClient, oidcProvider1.Namespace, oidcProvider1.Name, v1alpha1.SuccessOIDCProviderStatusCondition) // The spec.tls.secretName Secret does not exist, so the endpoints should fail with TLS errors. @@ -198,7 +198,7 @@ func TestSupervisorTLSTerminationWithSNI(t *testing.T) { certSecretName2 := "integration-test-cert-2" // Create an OIDCProvider with a spec.tls.secretName. - oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuer2, certSecretName2) + oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuer2, certSecretName2, "") requireStatus(t, pinnipedClient, oidcProvider2.Namespace, oidcProvider2.Name, v1alpha1.SuccessOIDCProviderStatusCondition) // Create the Secret. @@ -241,7 +241,7 @@ func TestSupervisorTLSTerminationWithDefaultCerts(t *testing.T) { issuerUsingHostname := fmt.Sprintf("%s://%s/issuer1", scheme, address) // Create an OIDCProvider without a spec.tls.secretName. - oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuerUsingIPAddress, "") + oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuerUsingIPAddress, "", "") requireStatus(t, pinnipedClient, oidcProvider1.Namespace, oidcProvider1.Name, v1alpha1.SuccessOIDCProviderStatusCondition) // There is no default TLS cert and the spec.tls.secretName was not set, so the endpoints should fail with TLS errors. @@ -255,7 +255,7 @@ func TestSupervisorTLSTerminationWithDefaultCerts(t *testing.T) { // Create an OIDCProvider with a spec.tls.secretName. certSecretName := "integration-test-cert-1" - oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuerUsingHostname, certSecretName) + oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuerUsingHostname, certSecretName, "") requireStatus(t, pinnipedClient, oidcProvider2.Namespace, oidcProvider2.Name, v1alpha1.SuccessOIDCProviderStatusCondition) // Create the Secret. @@ -428,7 +428,7 @@ func requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear( client pinnipedclientset.Interface, ) (*v1alpha1.OIDCProvider, *ExpectedJWKSResponseFormat) { t.Helper() - newOIDCProvider := library.CreateTestOIDCProvider(ctx, t, issuerName, "") + newOIDCProvider := library.CreateTestOIDCProvider(ctx, t, issuerName, "", "") jwksResult := requireDiscoveryEndpointsAreWorking(t, supervisorScheme, supervisorAddress, supervisorCABundle, issuerName, nil) requireStatus(t, client, newOIDCProvider.Namespace, newOIDCProvider.Name, v1alpha1.SuccessOIDCProviderStatusCondition) return newOIDCProvider, jwksResult diff --git a/test/integration/supervisor_keys_test.go b/test/integration/supervisor_keys_test.go index 17e6a5803..d59c713e1 100644 --- a/test/integration/supervisor_keys_test.go +++ b/test/integration/supervisor_keys_test.go @@ -27,7 +27,7 @@ func TestSupervisorOIDCKeys(t *testing.T) { defer cancel() // Create our OPC under test. - opc := library.CreateTestOIDCProvider(ctx, t, "", "") + opc := library.CreateTestOIDCProvider(ctx, t, "", "", "") // Ensure a secret is created with the OPC's JWKS. var updatedOPC *configv1alpha1.OIDCProvider diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index 0ce937fbb..c3ff63ecb 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -92,10 +92,10 @@ func TestSupervisorLogin(t *testing.T) { CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorTestUpstream.CABundle)), }, Client: idpv1alpha1.OIDCClient{ - SecretName: makeTestClientCredsSecret(t, env.SupervisorTestUpstream.ClientID, env.SupervisorTestUpstream.ClientSecret).Name, + SecretName: library.CreateClientCredsSecret(t, env.SupervisorTestUpstream.ClientID, env.SupervisorTestUpstream.ClientSecret).Name, }, } - upstream := makeTestUpstream(t, spec, idpv1alpha1.PhaseReady) + upstream := library.CreateTestUpstreamOIDCProvider(t, spec, idpv1alpha1.PhaseReady) // Make request to authorize endpoint - should pass, since we now have an upstream. req, err = http.NewRequestWithContext(ctx, http.MethodGet, downstreamAuthURL, nil) diff --git a/test/integration/supervisor_upstream_test.go b/test/integration/supervisor_upstream_test.go index e38f2b17e..dd3fa5289 100644 --- a/test/integration/supervisor_upstream_test.go +++ b/test/integration/supervisor_upstream_test.go @@ -4,13 +4,10 @@ package integration import ( - "context" "encoding/base64" "testing" - "time" "github.com/stretchr/testify/require" - corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1" @@ -28,7 +25,7 @@ func TestSupervisorUpstreamOIDCDiscovery(t *testing.T) { SecretName: "does-not-exist", }, } - upstream := makeTestUpstream(t, spec, v1alpha1.PhaseError) + upstream := library.CreateTestUpstreamOIDCProvider(t, spec, v1alpha1.PhaseError) expectUpstreamConditions(t, upstream, []v1alpha1.Condition{ { Type: "ClientCredentialsValid", @@ -56,10 +53,10 @@ func TestSupervisorUpstreamOIDCDiscovery(t *testing.T) { AdditionalScopes: []string{"email", "profile"}, }, Client: v1alpha1.OIDCClient{ - SecretName: makeTestClientCredsSecret(t, "test-client-id", "test-client-secret").Name, + SecretName: library.CreateClientCredsSecret(t, "test-client-id", "test-client-secret").Name, }, } - upstream := makeTestUpstream(t, spec, v1alpha1.PhaseReady) + upstream := library.CreateTestUpstreamOIDCProvider(t, spec, v1alpha1.PhaseReady) expectUpstreamConditions(t, upstream, []v1alpha1.Condition{ { Type: "ClientCredentialsValid", @@ -87,74 +84,3 @@ func expectUpstreamConditions(t *testing.T, upstream *v1alpha1.UpstreamOIDCProvi } require.ElementsMatch(t, expected, normalized) } - -func makeTestClientCredsSecret(t *testing.T, clientID string, clientSecret string) *corev1.Secret { - t.Helper() - env := library.IntegrationEnv(t) - client := library.NewClientset(t) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - created, err := client.CoreV1().Secrets(env.SupervisorNamespace).Create(ctx, &corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: env.SupervisorNamespace, - GenerateName: "test-client-creds-", - Labels: map[string]string{"pinniped.dev/test": ""}, - Annotations: map[string]string{"pinniped.dev/testName": t.Name()}, - }, - Type: "secrets.pinniped.dev/oidc-client", - StringData: map[string]string{ - "clientID": clientID, - "clientSecret": clientSecret, - }, - }, metav1.CreateOptions{}) - require.NoError(t, err) - t.Cleanup(func() { - err := client.CoreV1().Secrets(env.SupervisorNamespace).Delete(context.Background(), created.Name, metav1.DeleteOptions{}) - require.NoError(t, err) - }) - t.Logf("created test client credentials Secret %s", created.Name) - return created -} - -func makeTestUpstream(t *testing.T, spec v1alpha1.UpstreamOIDCProviderSpec, expectedPhase v1alpha1.UpstreamOIDCProviderPhase) *v1alpha1.UpstreamOIDCProvider { - t.Helper() - env := library.IntegrationEnv(t) - client := library.NewSupervisorClientset(t) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - // Create the UpstreamOIDCProvider using GenerateName to get a random name. - created, err := client.IDPV1alpha1(). - UpstreamOIDCProviders(env.SupervisorNamespace). - Create(ctx, &v1alpha1.UpstreamOIDCProvider{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: env.SupervisorNamespace, - GenerateName: "test-upstream-", - Labels: map[string]string{"pinniped.dev/test": ""}, - Annotations: map[string]string{"pinniped.dev/testName": t.Name()}, - }, - Spec: spec, - }, metav1.CreateOptions{}) - require.NoError(t, err) - - // Always clean this up after this point. - t.Cleanup(func() { - err := client.IDPV1alpha1(). - UpstreamOIDCProviders(env.SupervisorNamespace). - Delete(context.Background(), created.Name, metav1.DeleteOptions{}) - require.NoError(t, err) - }) - t.Logf("created test UpstreamOIDCProvider %s", created.Name) - - // Wait for the UpstreamOIDCProvider to enter the expected phase (or time out). - var result *v1alpha1.UpstreamOIDCProvider - require.Eventuallyf(t, func() bool { - var err error - result, err = client.IDPV1alpha1(). - UpstreamOIDCProviders(created.Namespace).Get(ctx, created.Name, metav1.GetOptions{}) - require.NoError(t, err) - return result.Status.Phase == expectedPhase - }, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectedPhase) - return result -} diff --git a/test/library/client.go b/test/library/client.go index f5aba3a6f..d95f04263 100644 --- a/test/library/client.go +++ b/test/library/client.go @@ -25,6 +25,7 @@ import ( auth1alpha1 "go.pinniped.dev/generated/1.19/apis/concierge/authentication/v1alpha1" configv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/config/v1alpha1" + idpv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1" conciergeclientset "go.pinniped.dev/generated/1.19/client/concierge/clientset/versioned" supervisorclientset "go.pinniped.dev/generated/1.19/client/supervisor/clientset/versioned" @@ -140,12 +141,8 @@ func CreateTestWebhookAuthenticator(ctx context.Context, t *testing.T) corev1.Ty defer cancel() webhook, err := webhooks.Create(createContext, &auth1alpha1.WebhookAuthenticator{ - ObjectMeta: metav1.ObjectMeta{ - GenerateName: "test-webhook-", - Labels: map[string]string{"pinniped.dev/test": ""}, - Annotations: map[string]string{"pinniped.dev/testName": t.Name()}, - }, - Spec: testEnv.TestWebhook, + ObjectMeta: testObjectMeta(t, "webhook"), + Spec: testEnv.TestWebhook, }, metav1.CreateOptions{}) require.NoError(t, err, "could not create test WebhookAuthenticator") t.Logf("created test WebhookAuthenticator %s/%s", webhook.Namespace, webhook.Name) @@ -172,7 +169,7 @@ func CreateTestWebhookAuthenticator(ctx context.Context, t *testing.T) corev1.Ty // // If the provided issuer is not the empty string, then it will be used for the // OIDCProvider.Spec.Issuer field. Else, a random issuer will be generated. -func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecretName string) *configv1alpha1.OIDCProvider { +func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer string, certSecretName string, expectStatus configv1alpha1.OIDCProviderStatusCondition) *configv1alpha1.OIDCProvider { t.Helper() testEnv := IntegrationEnv(t) @@ -180,18 +177,12 @@ func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecre defer cancel() if issuer == "" { - var err error - issuer, err = randomIssuer() - require.NoError(t, err) + issuer = randomIssuer(t) } opcs := NewSupervisorClientset(t).ConfigV1alpha1().OIDCProviders(testEnv.SupervisorNamespace) opc, err := opcs.Create(createContext, &configv1alpha1.OIDCProvider{ - ObjectMeta: metav1.ObjectMeta{ - GenerateName: "test-oidc-provider-", - Labels: map[string]string{"pinniped.dev/test": ""}, - Annotations: map[string]string{"pinniped.dev/testName": t.Name()}, - }, + ObjectMeta: testObjectMeta(t, "oidc-provider"), Spec: configv1alpha1.OIDCProviderSpec{ Issuer: issuer, TLS: &configv1alpha1.OIDCProviderTLSSpec{SecretName: certSecretName}, @@ -213,13 +204,103 @@ func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecre } }) + // If we're not expecting any particular status, just return the new OIDCProvider immediately. + if expectStatus == "" { + return opc + } + + // Wait for the OIDCProvider to enter the expected phase (or time out). + var result *configv1alpha1.OIDCProvider + require.Eventuallyf(t, func() bool { + var err error + result, err = opcs.Get(ctx, opc.Name, metav1.GetOptions{}) + require.NoError(t, err) + return result.Status.Status == expectStatus + }, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectStatus) + return opc } -func randomIssuer() (string, error) { +func randomIssuer(t *testing.T) string { var buf [8]byte - if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil { - return "", fmt.Errorf("could not generate random state: %w", err) - } - return fmt.Sprintf("http://test-issuer-%s.pinniped.dev", hex.EncodeToString(buf[:])), nil + _, err := io.ReadFull(rand.Reader, buf[:]) + require.NoError(t, err) + return fmt.Sprintf("http://test-issuer-%s.pinniped.dev", hex.EncodeToString(buf[:])) +} + +func CreateTestSecret(t *testing.T, namespace string, baseName string, secretType string, stringData map[string]string) *corev1.Secret { + t.Helper() + client := NewClientset(t) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + created, err := client.CoreV1().Secrets(namespace).Create(ctx, &corev1.Secret{ + ObjectMeta: testObjectMeta(t, baseName), + Type: corev1.SecretType(secretType), + StringData: stringData, + }, metav1.CreateOptions{}) + require.NoError(t, err) + + t.Cleanup(func() { + err := client.CoreV1().Secrets(namespace).Delete(context.Background(), created.Name, metav1.DeleteOptions{}) + require.NoError(t, err) + }) + t.Logf("created test Secret %s", created.Name) + return created +} + +func CreateClientCredsSecret(t *testing.T, clientID string, clientSecret string) *corev1.Secret { + t.Helper() + env := IntegrationEnv(t) + return CreateTestSecret(t, + env.SupervisorNamespace, + "test-client-creds-", + "secrets.pinniped.dev/oidc-client", + map[string]string{ + "clientID": clientID, + "clientSecret": clientSecret, + }, + ) +} + +func CreateTestUpstreamOIDCProvider(t *testing.T, spec idpv1alpha1.UpstreamOIDCProviderSpec, expectedPhase idpv1alpha1.UpstreamOIDCProviderPhase) *idpv1alpha1.UpstreamOIDCProvider { + t.Helper() + env := IntegrationEnv(t) + client := NewSupervisorClientset(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + // Create the UpstreamOIDCProvider using GenerateName to get a random name. + upstreams := client.IDPV1alpha1().UpstreamOIDCProviders(env.SupervisorNamespace) + + created, err := upstreams.Create(ctx, &idpv1alpha1.UpstreamOIDCProvider{ + ObjectMeta: testObjectMeta(t, "upstream"), + Spec: spec, + }, metav1.CreateOptions{}) + require.NoError(t, err) + + // Always clean this up after this point. + t.Cleanup(func() { + err := upstreams.Delete(context.Background(), created.Name, metav1.DeleteOptions{}) + require.NoError(t, err) + }) + t.Logf("created test UpstreamOIDCProvider %s", created.Name) + + // Wait for the UpstreamOIDCProvider to enter the expected phase (or time out). + var result *idpv1alpha1.UpstreamOIDCProvider + require.Eventuallyf(t, func() bool { + var err error + result, err = upstreams.Get(ctx, created.Name, metav1.GetOptions{}) + require.NoError(t, err) + return result.Status.Phase == expectedPhase + }, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectedPhase) + return result +} + +func testObjectMeta(t *testing.T, baseName string) metav1.ObjectMeta { + return metav1.ObjectMeta{ + GenerateName: fmt.Sprintf("test-%s-", baseName), + Labels: map[string]string{"pinniped.dev/test": ""}, + Annotations: map[string]string{"pinniped.dev/testName": t.Name()}, + } } From 0ccf14801e1e592db45969bcb997025e0838b43a Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Wed, 2 Dec 2020 15:43:17 -0600 Subject: [PATCH 42/57] Expose the MaskTokens function so other test code can use it. This is just a small helper to make test output more readable. Signed-off-by: Matt Moyer --- test/library/iotest.go | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/test/library/iotest.go b/test/library/iotest.go index dcb0e6959..6e2f1e58b 100644 --- a/test/library/iotest.go +++ b/test/library/iotest.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "regexp" + "strings" "testing" ) @@ -26,18 +27,22 @@ func (l *testlogReader) Read(p []byte) (n int, err error) { l.t.Helper() n, err = l.r.Read(p) if err != nil { - l.t.Logf("%s > %q: %v", l.name, maskTokens(p[0:n]), err) + l.t.Logf("%s > %q: %v", l.name, MaskTokens(string(p[0:n])), err) } else { - l.t.Logf("%s > %q", l.name, maskTokens(p[0:n])) + l.t.Logf("%s > %q", l.name, MaskTokens(string(p[0:n]))) } return } -//nolint: gochecknoglobals -var tokenLike = regexp.MustCompile(`(?mi)[a-zA-Z0-9._-]{30,}|[a-zA-Z0-9]{20,}`) - -func maskTokens(in []byte) string { - return tokenLike.ReplaceAllStringFunc(string(in), func(t string) string { +// MaskTokens makes a best-effort attempt to mask out things that look like secret tokens in test output. +// The goal is more to have readable test output than for any security reason. +func MaskTokens(in string) string { + var tokenLike = regexp.MustCompile(`(?mi)[a-zA-Z0-9._-]{30,}|[a-zA-Z0-9]{20,}`) + return tokenLike.ReplaceAllStringFunc(in, func(t string) string { + // This is a silly heuristic, but things with multiple dots are more likely hostnames that we don't want masked. + if strings.Count(t, ".") >= 4 { + return t + } return fmt.Sprintf("[...%d bytes...]", len(t)) }) } From f40144e1a9c1878ae192e06a8f7a16c95b542289 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Wed, 2 Dec 2020 15:50:42 -0600 Subject: [PATCH 43/57] Update TestSupervisorLogin to test the callback flow using a browser. Signed-off-by: Matt Moyer --- test/integration/supervisor_login_test.go | 314 +++++++++------------- 1 file changed, 129 insertions(+), 185 deletions(-) diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index c3ff63ecb..85d2538ef 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -6,12 +6,12 @@ package integration import ( "context" "crypto/tls" - "crypto/x509" + "crypto/x509/pkix" "encoding/base64" - "fmt" "net/http" + "net/http/httptest" "net/url" - "path" + "regexp" "strings" "testing" "time" @@ -20,216 +20,160 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" + configv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/config/v1alpha1" idpv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1" + "go.pinniped.dev/internal/certauthority" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/state" "go.pinniped.dev/test/library" + "go.pinniped.dev/test/library/browsertest" ) func TestSupervisorLogin(t *testing.T) { env := library.IntegrationEnv(t) - client := library.NewSupervisorClientset(t) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - tests := []struct { - Scheme string - Address string - CABundle string - }{ - {Scheme: "http", Address: env.SupervisorHTTPAddress}, - {Scheme: "https", Address: env.SupervisorHTTPSIngressAddress, CABundle: env.SupervisorHTTPSIngressCABundle}, - } + // Infer the downstream issuer URL from the callback associated with the upstream test client registration. + issuerURL, err := url.Parse(env.SupervisorTestUpstream.CallbackURL) + require.NoError(t, err) + require.True(t, strings.HasSuffix(issuerURL.Path, "/callback")) + issuerURL.Path = strings.TrimSuffix(issuerURL.Path, "/callback") + t.Logf("testing with downstream issuer URL %s", issuerURL.String()) - for _, test := range tests { - scheme := test.Scheme - addr := test.Address - caBundle := test.CABundle - - if addr == "" { - // Both cases are not required, so when one is empty skip it. - continue - } - - // Create downstream OIDC provider (i.e., update supervisor with OIDC provider). - path := getDownstreamIssuerPathFromUpstreamRedirectURI(t, env.SupervisorTestUpstream.CallbackURL) - issuer := fmt.Sprintf("https://%s%s", addr, path) - _, _ = requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear( - ctx, - t, - scheme, - addr, - caBundle, - issuer, - client, - ) - - // Create HTTP client. - httpClient := newHTTPClient(t, caBundle, nil) - httpClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { - // Don't follow any redirects right now, since we simply want to validate that our auth endpoint - // redirects us. - return http.ErrUseLastResponse - } - - // Declare the downstream auth endpoint url we will use. - downstreamAuthURL := makeDownstreamAuthURL(t, scheme, addr, path) - - // Make request to auth endpoint - should fail, since we have no upstreams. - req, err := http.NewRequestWithContext(ctx, http.MethodGet, downstreamAuthURL, nil) - require.NoError(t, err) - rsp, err := httpClient.Do(req) - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusUnprocessableEntity, rsp.StatusCode) - - // Create upstream OIDC provider. - spec := idpv1alpha1.UpstreamOIDCProviderSpec{ - Issuer: env.SupervisorTestUpstream.Issuer, - TLS: &idpv1alpha1.TLSSpec{ - CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorTestUpstream.CABundle)), - }, - Client: idpv1alpha1.OIDCClient{ - SecretName: library.CreateClientCredsSecret(t, env.SupervisorTestUpstream.ClientID, env.SupervisorTestUpstream.ClientSecret).Name, - }, - } - upstream := library.CreateTestUpstreamOIDCProvider(t, spec, idpv1alpha1.PhaseReady) - - // Make request to authorize endpoint - should pass, since we now have an upstream. - req, err = http.NewRequestWithContext(ctx, http.MethodGet, downstreamAuthURL, nil) - require.NoError(t, err) - rsp, err = httpClient.Do(req) - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusFound, rsp.StatusCode) - requireValidRedirectLocation( - ctx, - t, - upstream.Spec.Issuer, - env.SupervisorTestUpstream.ClientID, - env.SupervisorTestUpstream.CallbackURL, - rsp.Header.Get("Location"), - ) - } -} - -func getDownstreamIssuerPathFromUpstreamRedirectURI(t *testing.T, upstreamRedirectURI string) string { - // We need to construct the downstream issuer path from the upstream redirect URI since the two - // are related, and the upstream redirect URI is supplied via a static test environment - // variable. The upstream redirect URI should be something like - // https://supervisor.com/some/supervisor/path/callback - // and therefore the downstream issuer should be something like - // https://supervisor.com/some/supervisor/path - // since the /callback endpoint is placed at the root of the downstream issuer path. - upstreamRedirectURL, err := url.Parse(upstreamRedirectURI) + // Generate a CA bundle with which to serve this provider. + t.Logf("generating test CA") + ca, err := certauthority.New(pkix.Name{CommonName: "Downstream Test CA"}, 1*time.Hour) require.NoError(t, err) - redirectURIPathWithoutLastSegment, lastUpstreamRedirectURIPathSegment := path.Split(upstreamRedirectURL.Path) - require.Equalf( - t, - "callback", - lastUpstreamRedirectURIPathSegment, - "expected upstream redirect URI (%q) to follow supervisor callback path conventions (i.e., end in /callback)", - upstreamRedirectURI, + // Create an HTTP client that can reach the downstream discovery endpoint using the CA certs. + httpClient := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: ca.Pool()}, + Proxy: func(req *http.Request) (*url.URL, error) { + if env.Proxy == "" { + return nil, nil + } + return url.Parse(env.Proxy) + }, + }} + + // Use the CA to issue a TLS server cert. + t.Logf("issuing test certificate") + tlsCert, err := ca.Issue( + pkix.Name{CommonName: issuerURL.Hostname()}, + []string{issuerURL.Hostname()}, + nil, + 1*time.Hour, + ) + require.NoError(t, err) + certPEM, keyPEM, err := certauthority.ToPEM(tlsCert) + require.NoError(t, err) + + // Write the serving cert to a secret. + certSecret := library.CreateTestSecret(t, + env.SupervisorNamespace, + "oidc-provider-tls", + "kubernetes.io/tls", + map[string]string{"tls.crt": string(certPEM), "tls.key": string(keyPEM)}, ) - if strings.HasSuffix(redirectURIPathWithoutLastSegment, "/") { - redirectURIPathWithoutLastSegment = redirectURIPathWithoutLastSegment[:len(redirectURIPathWithoutLastSegment)-1] - } + // Create the downstream OIDCProvider and expect it to go into the success status condition. + downstream := library.CreateTestOIDCProvider(ctx, t, + issuerURL.String(), + certSecret.Name, + configv1alpha1.SuccessOIDCProviderStatusCondition, + ) - return redirectURIPathWithoutLastSegment -} + // Create upstream OIDC provider and wait for it to become ready. + library.CreateTestUpstreamOIDCProvider(t, idpv1alpha1.UpstreamOIDCProviderSpec{ + Issuer: env.SupervisorTestUpstream.Issuer, + TLS: &idpv1alpha1.TLSSpec{ + CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorTestUpstream.CABundle)), + }, + AuthorizationConfig: idpv1alpha1.OIDCAuthorizationConfig{ + AdditionalScopes: []string{"email", "profile"}, + }, + Client: idpv1alpha1.OIDCClient{ + SecretName: library.CreateClientCredsSecret(t, env.SupervisorTestUpstream.ClientID, env.SupervisorTestUpstream.ClientSecret).Name, + }, + }, idpv1alpha1.PhaseReady) -func makeDownstreamAuthURL(t *testing.T, scheme, addr, path string) string { - t.Helper() + // Perform OIDC discovery for our downstream. + discovery, err := oidc.NewProvider(oidc.ClientContext(ctx, httpClient), downstream.Spec.Issuer) + require.NoError(t, err) + + // Start a callback server on localhost. + localCallbackServer := startLocalCallbackServer(t) + + // Form the OAuth2 configuration corresponding to our CLI client. downstreamOAuth2Config := oauth2.Config{ // This is the hardcoded public client that the supervisor supports. - ClientID: "pinniped-cli", - Endpoint: oauth2.Endpoint{ - AuthURL: fmt.Sprintf("%s://%s%s/oauth2/authorize", scheme, addr, path), - }, - // This is the hardcoded downstream redirect URI that the supervisor supports. - RedirectURL: "http://127.0.0.1/callback", + ClientID: "pinniped-cli", + Endpoint: discovery.Endpoint(), + RedirectURL: localCallbackServer.URL, Scopes: []string{"openid"}, } - state, nonce, pkce := generateAuthRequestParams(t) - return downstreamOAuth2Config.AuthCodeURL( - state.String(), - nonce.Param(), - pkce.Challenge(), - pkce.Method(), + + // Build a valid downstream authorize URL for the supervisor. + stateParam, err := state.Generate() + require.NoError(t, err) + nonceParam, err := nonce.Generate() + require.NoError(t, err) + pkceParam, err := pkce.Generate() + require.NoError(t, err) + downstreamAuthorizeURL := downstreamOAuth2Config.AuthCodeURL( + stateParam.String(), + nonceParam.Param(), + pkceParam.Challenge(), + pkceParam.Method(), ) + + // Open the web browser and navigate to the downstream authorize URL. + page := browsertest.Open(t) + t.Logf("opening browser to downstream authorize URL %s", library.MaskTokens(downstreamAuthorizeURL)) + require.NoError(t, page.Navigate(downstreamAuthorizeURL)) + + // Expect to be redirected to the upstream provider and log in. + browsertest.LoginToUpstream(t, page, env.SupervisorTestUpstream) + + // Wait for the login to happen and us be redirected back to a localhost callback. + t.Logf("waiting for redirect to callback") + callbackURLPattern := regexp.MustCompile(`\A` + regexp.QuoteMeta(localCallbackServer.URL) + `\?.+\z`) + browsertest.WaitForURL(t, page, callbackURLPattern) + + // Expect that our callback handler was invoked. + callback := localCallbackServer.waitForCallback(10 * time.Second) + t.Logf("got callback request: %s", library.MaskTokens(callback.URL.String())) + require.Equal(t, stateParam.String(), callback.URL.Query().Get("state")) + require.Equal(t, "openid", callback.URL.Query().Get("scope")) + require.NotEmpty(t, callback.URL.Query().Get("code")) } -func generateAuthRequestParams(t *testing.T) (state.State, nonce.Nonce, pkce.Code) { - t.Helper() - state, err := state.Generate() - require.NoError(t, err) - nonce, err := nonce.Generate() - require.NoError(t, err) - pkce, err := pkce.Generate() - require.NoError(t, err) - return state, nonce, pkce +func startLocalCallbackServer(t *testing.T) *localCallbackServer { + // Handle the callback by sending the *http.Request object back through a channel. + callbacks := make(chan *http.Request, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callbacks <- r + })) + server.URL = server.URL + "/callback" + t.Cleanup(server.Close) + t.Cleanup(func() { close(callbacks) }) + return &localCallbackServer{Server: server, t: t, callbacks: callbacks} } -func requireValidRedirectLocation( - ctx context.Context, - t *testing.T, - issuer, clientID, redirectURI, actualLocation string, -) { - t.Helper() - env := library.IntegrationEnv(t) +type localCallbackServer struct { + *httptest.Server + t *testing.T + callbacks <-chan *http.Request +} - // Do OIDC discovery on our test issuer to get auth endpoint. - transport := http.Transport{} - if env.Proxy != "" { - transport.Proxy = func(_ *http.Request) (*url.URL, error) { - return url.Parse(env.Proxy) - } +func (s *localCallbackServer) waitForCallback(timeout time.Duration) *http.Request { + select { + case callback := <-s.callbacks: + return callback + case <-time.After(timeout): + require.Fail(s.t, "timed out waiting for callback request") + return nil } - if env.SupervisorTestUpstream.CABundle != "" { - transport.TLSClientConfig = &tls.Config{RootCAs: x509.NewCertPool()} - transport.TLSClientConfig.RootCAs.AppendCertsFromPEM([]byte(env.SupervisorTestUpstream.CABundle)) - } - - ctx = oidc.ClientContext(ctx, &http.Client{Transport: &transport}) - upstreamProvider, err := oidc.NewProvider(ctx, issuer) - require.NoError(t, err) - - // Parse expected upstream auth URL. - expectedLocationURL, err := url.Parse( - (&oauth2.Config{ - ClientID: clientID, - Endpoint: upstreamProvider.Endpoint(), - RedirectURL: redirectURI, - Scopes: []string{"openid"}, - }).AuthCodeURL("", oauth2.AccessTypeOffline), - ) - require.NoError(t, err) - - // Parse actual upstream auth URL. - actualLocationURL, err := url.Parse(actualLocation) - require.NoError(t, err) - - // First make some assertions on the query values. Note that we will not be able to know what - // certain query values are since they may be random (e.g., state, pkce, nonce). - expectedLocationQuery := expectedLocationURL.Query() - actualLocationQuery := actualLocationURL.Query() - require.NotEmpty(t, actualLocationQuery.Get("state")) - actualLocationQuery.Del("state") - require.NotEmpty(t, actualLocationQuery.Get("code_challenge")) - actualLocationQuery.Del("code_challenge") - require.NotEmpty(t, actualLocationQuery.Get("code_challenge_method")) - actualLocationQuery.Del("code_challenge_method") - require.NotEmpty(t, actualLocationQuery.Get("nonce")) - actualLocationQuery.Del("nonce") - require.Equal(t, expectedLocationQuery, actualLocationQuery) - - // Zero-out query values, since we made specific assertions about those above, and assert that the - // URL's are equal otherwise. - expectedLocationURL.RawQuery = "" - actualLocationURL.RawQuery = "" - require.Equal(t, expectedLocationURL, actualLocationURL) } From ae9bdc1d61bb04be32f9110022b95127e35966bc Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Wed, 2 Dec 2020 16:11:22 -0600 Subject: [PATCH 44/57] Fix a lint warning by simplifying this append operation. Signed-off-by: Matt Moyer --- test/integration/supervisor_login_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index 85d2538ef..00f68dd11 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -156,7 +156,7 @@ func startLocalCallbackServer(t *testing.T) *localCallbackServer { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callbacks <- r })) - server.URL = server.URL + "/callback" + server.URL += "/callback" t.Cleanup(server.Close) t.Cleanup(func() { close(callbacks) }) return &localCallbackServer{Server: server, t: t, callbacks: callbacks} From c32013228970133691a71575a61ce4240c638daf Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Wed, 2 Dec 2020 14:10:41 -0800 Subject: [PATCH 45/57] Back-fill some more unit tests on authorizationcode_test.go --- .../authorizationcode_test.go | 56 +++++++++++++++++-- 1 file changed, 50 insertions(+), 6 deletions(-) diff --git a/internal/fositestorage/authorizationcode/authorizationcode_test.go b/internal/fositestorage/authorizationcode/authorizationcode_test.go index 38f3e1a95..616eb2de1 100644 --- a/internal/fositestorage/authorizationcode/authorizationcode_test.go +++ b/internal/fositestorage/authorizationcode/authorizationcode_test.go @@ -8,12 +8,16 @@ import ( "crypto/ed25519" "crypto/x509" "encoding/json" + "fmt" "math/rand" "net/url" "strings" "testing" "time" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" + fuzz "github.com/google/gofuzz" "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" @@ -24,7 +28,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/kubernetes/fake" - coretesting "k8s.io/client-go/testing" + kubetesting "k8s.io/client-go/testing" "go.pinniped.dev/internal/fositestorage" ) @@ -39,8 +43,8 @@ func TestAuthorizationCodeStorage(t *testing.T) { Resource: "secrets", } - wantActions := []coretesting.Action{ - coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ + wantActions := []kubetesting.Action{ + kubetesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", ResourceVersion: "", @@ -54,9 +58,9 @@ func TestAuthorizationCodeStorage(t *testing.T) { }, Type: "storage.pinniped.dev/authcode", }), - coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), - coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), - coretesting.NewUpdateAction(secretsGVR, namespace, &corev1.Secret{ + kubetesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), + kubetesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), + kubetesting.NewUpdateAction(secretsGVR, namespace, &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", ResourceVersion: "", @@ -121,6 +125,11 @@ func TestAuthorizationCodeStorage(t *testing.T) { require.NoError(t, err) require.Equal(t, wantActions, client.Actions()) + + // Doing a Get on an invalidated session should still return the session, but also return an error. + invalidatedRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil) + require.EqualError(t, err, "authorization code session for fancy-signature has already been used: Authorization code has ben invalidated") + require.Equal(t, "abcd-1", invalidatedRequest.GetID()) } func TestGetNotFound(t *testing.T) { @@ -134,6 +143,41 @@ func TestGetNotFound(t *testing.T) { require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound)) } +func TestInvalidateWhenNotFound(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + notFoundErr := storage.InvalidateAuthorizeCodeSession(ctx, "non-existent-signature") + require.EqualError(t, notFoundErr, "not_found") + require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound)) +} + +func TestInvalidateWhenConflictOnUpdateHappens(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + client.PrependReactor("update", "secrets", func(_ kubetesting.Action) (bool, runtime.Object, error) { + return true, nil, apierrors.NewConflict(schema.GroupResource{ + Group: "", + Resource: "secrets", + }, "some-secret-name", fmt.Errorf("there was a conflict")) + }) + + request := &fosite.Request{ + ID: "some-request-id", + Client: &fosite.DefaultOpenIDConnectClient{}, + Session: &openid.DefaultSession{}, + } + err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request) + require.NoError(t, err) + err = storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature") + require.EqualError(t, err, `The request could not be completed due to concurrent access: failed to update authcode for signature fancy-signature at resource version : Operation cannot be fulfilled on secrets "some-secret-name": there was a conflict`) +} + func TestWrongVersion(t *testing.T) { ctx := context.Background() client := fake.NewSimpleClientset() From 6ed9107df02ca409b7bc61631b247d9050b3ed23 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Wed, 2 Dec 2020 14:20:03 -0800 Subject: [PATCH 46/57] Remove a couple of todos that will be resolved in Slack conversations --- internal/crud/crud.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/crud/crud.go b/internal/crud/crud.go index 82d320826..9e15581d7 100644 --- a/internal/crud/crud.go +++ b/internal/crud/crud.go @@ -30,7 +30,7 @@ const ( ErrSecretTypeMismatch = constable.Error("secret storage data has incorrect type") ErrSecretLabelMismatch = constable.Error("secret storage data has incorrect label") - ErrSecretVersionMismatch = constable.Error("secret storage data has incorrect version") // TODO do we need this? + ErrSecretVersionMismatch = constable.Error("secret storage data has incorrect version") ) type Storage interface { @@ -139,7 +139,7 @@ func (s *secretsStorage) toSecret(signature, resourceVersion string, data JSON) Labels: map[string]string{ secretLabelKey: s.resource, // make it easier to find this stuff via kubectl }, - OwnerReferences: nil, // TODO we should set this to make sure stuff gets clean up + OwnerReferences: nil, }, Data: map[string][]byte{ secretDataKey: buf, From 879525faac9aed4be45d6175028d604efa056eba Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Wed, 2 Dec 2020 17:20:24 -0600 Subject: [PATCH 47/57] Clean up the browsertest package a bit. Signed-off-by: Matt Moyer --- test/library/browsertest/browsertest.go | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/test/library/browsertest/browsertest.go b/test/library/browsertest/browsertest.go index 9a50e7808..d7da81420 100644 --- a/test/library/browsertest/browsertest.go +++ b/test/library/browsertest/browsertest.go @@ -15,6 +15,11 @@ import ( "go.pinniped.dev/test/library" ) +const ( + operationTimeout = 10 * time.Second + operationPollingInterval = 100 * time.Millisecond +) + // Open a webdriver-driven browser and returns an *agouti.Page to control it. The browser will be automatically // closed at the end of the current test. It is configured for test purposes with the correct HTTP proxy and // in a mode that ignore certificate errors. @@ -53,7 +58,8 @@ func Open(t *testing.T) *agouti.Page { // to occur and times out, failing the test, if they never appear. func WaitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) { t.Helper() - require.Eventually(t, + + require.Eventuallyf(t, func() bool { for _, sel := range selectors { vis, err := page.First(sel).Visible() @@ -63,8 +69,10 @@ func WaitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string } return true }, - 10*time.Second, - 100*time.Millisecond, + operationTimeout, + operationPollingInterval, + "expected to have a page with selectors %v, but it never loaded", + selectors, ) } @@ -84,8 +92,8 @@ func WaitForURL(t *testing.T, page *agouti.Page, pat *regexp.Regexp) { } return false }, - 10*time.Second, - 100*time.Millisecond, + operationTimeout, + operationPollingInterval, "expected to browse to %s, but never got there", pat, ) From 64ef53402dc1a0488be41e462686b9c70e314720 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Wed, 2 Dec 2020 18:07:52 -0600 Subject: [PATCH 48/57] In TestSupervisorLogin, wrap the discovery request in an `Eventually()`. Signed-off-by: Matt Moyer --- test/integration/supervisor_login_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index 00f68dd11..92e7c086d 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/coreos/go-oidc" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -100,7 +101,11 @@ func TestSupervisorLogin(t *testing.T) { }, idpv1alpha1.PhaseReady) // Perform OIDC discovery for our downstream. - discovery, err := oidc.NewProvider(oidc.ClientContext(ctx, httpClient), downstream.Spec.Issuer) + var discovery *oidc.Provider + assert.Eventually(t, func() bool { + discovery, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), downstream.Spec.Issuer) + return err == nil + }, 60*time.Second, 1*time.Second) require.NoError(t, err) // Start a callback server on localhost. From 95093ab0af59ca5b14038ca9069b3324863b1002 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Wed, 2 Dec 2020 17:39:45 -0800 Subject: [PATCH 49/57] Use kube storage for the supervisor callback endpoint's fosite sessions --- cmd/pinniped-supervisor/main.go | 7 +- internal/oidc/provider/manager/manager.go | 42 +++++---- .../oidc/provider/manager/manager_test.go | 88 ++++++++++++++++--- 3 files changed, 107 insertions(+), 30 deletions(-) diff --git a/cmd/pinniped-supervisor/main.go b/cmd/pinniped-supervisor/main.go index d2bfc7f50..31f5dff80 100644 --- a/cmd/pinniped-supervisor/main.go +++ b/cmd/pinniped-supervisor/main.go @@ -196,7 +196,12 @@ func run(serverInstallationNamespace string, cfg *supervisor.Config) error { dynamicUpstreamIDPProvider := provider.NewDynamicUpstreamIDPProvider() // OIDC endpoints will be served by the oidProvidersManager, and any non-OIDC paths will fallback to the healthMux. - oidProvidersManager := manager.NewManager(healthMux, dynamicJWKSProvider, dynamicUpstreamIDPProvider) + oidProvidersManager := manager.NewManager( + healthMux, + dynamicJWKSProvider, + dynamicUpstreamIDPProvider, + kubeClient.CoreV1().Secrets(serverInstallationNamespace), + ) startControllers( ctx, diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index 6bac2c600..b42382731 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/gorilla/securecookie" + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/auth" @@ -32,18 +33,25 @@ type Manager struct { nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs + secretsClient corev1client.SecretInterface } // NewManager returns an empty Manager. // nextHandler will be invoked for any requests that could not be handled by this manager's providers. // dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data. // idpListGetter will be used as an in-memory cache of currently configured upstream IDPs. -func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter oidc.IDPListGetter) *Manager { +func NewManager( + nextHandler http.Handler, + dynamicJWKSProvider jwks.DynamicJWKSProvider, + idpListGetter oidc.IDPListGetter, + secretsClient corev1client.SecretInterface, +) *Manager { return &Manager{ providerHandlers: make(map[string]http.Handler), nextHandler: nextHandler, dynamicJWKSProvider: dynamicJWKSProvider, idpListGetter: idpListGetter, + secretsClient: secretsClient, } } @@ -63,15 +71,17 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { m.providerHandlers = make(map[string]http.Handler) for _, incomingProvider := range oidcProviders { - wellKnownURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.WellKnownEndpointPath - m.providerHandlers[wellKnownURL] = discovery.NewHandler(incomingProvider.Issuer()) + issuer := incomingProvider.Issuer() + issuerHostWithPath := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() - jwksURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.JWKSEndpointPath - m.providerHandlers[jwksURL] = jwks.NewHandler(incomingProvider.Issuer(), m.dynamicJWKSProvider) + fositeHMACSecretForThisProvider := []byte("some secret - must have at least 32 bytes") // TODO replace this secret // Use NullStorage for the authorize endpoint because we do not actually want to store anything until // the upstream callback endpoint is called later. - oauthHelper := oidc.FositeOauth2Helper(oidc.NullStorage{}, incomingProvider.Issuer(), []byte("some secret - must have at least 32 bytes")) // TODO replace this secret + oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, fositeHMACSecretForThisProvider) + + // For all the other endpoints, make another oauth helper with exactly the same settings except use real storage. + oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, fositeHMACSecretForThisProvider) // TODO use different codecs for the state and the cookie, because: // 1. we would like to state to have an embedded expiration date while the cookie does not need that @@ -82,11 +92,14 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { var encoder = securecookie.New(encoderHashKey, encoderBlockKey) encoder.SetSerializer(securecookie.JSONEncoder{}) - authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath - m.providerHandlers[authURL] = auth.NewHandler( - incomingProvider.Issuer(), + m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuer) + + m.providerHandlers[(issuerHostWithPath + oidc.JWKSEndpointPath)] = jwks.NewHandler(issuer, m.dynamicJWKSProvider) + + m.providerHandlers[(issuerHostWithPath + oidc.AuthorizationEndpointPath)] = auth.NewHandler( + issuer, m.idpListGetter, - oauthHelper, + oauthHelperWithNullStorage, csrftoken.Generate, pkce.Generate, nonce.Generate, @@ -94,16 +107,15 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { encoder, ) - callbackURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.CallbackEndpointPath - m.providerHandlers[callbackURL] = callback.NewHandler( + m.providerHandlers[(issuerHostWithPath + oidc.CallbackEndpointPath)] = callback.NewHandler( m.idpListGetter, - oauthHelper, + oauthHelperWithKubeStorage, encoder, encoder, - incomingProvider.Issuer()+oidc.CallbackEndpointPath, + issuer+oidc.CallbackEndpointPath, ) - plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) + plog.Debug("oidc provider manager added or updated issuer", "issuer", issuer) } } diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index 44ac63980..a3f8090dc 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -4,6 +4,7 @@ package manager import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -15,6 +16,7 @@ import ( "github.com/sclevine/spec" "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" + "k8s.io/client-go/kubernetes/fake" "go.pinniped.dev/internal/here" "go.pinniped.dev/internal/oidc" @@ -22,6 +24,9 @@ import ( "go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" + "go.pinniped.dev/pkg/oidcclient/pkce" ) func TestManager(t *testing.T) { @@ -32,6 +37,7 @@ func TestManager(t *testing.T) { nextHandler http.HandlerFunc fallbackHandlerWasCalled bool dynamicJWKSProvider jwks.DynamicJWKSProvider + kubeClient *fake.Clientset ) const ( @@ -66,7 +72,7 @@ func TestManager(t *testing.T) { r.Equal(expectedIssuerInResponse, parsedDiscoveryResult.Issuer) } - requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) { + requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) (string, string) { recorder := httptest.NewRecorder() subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.AuthorizationEndpointPath+requestURLSuffix)) @@ -81,18 +87,58 @@ func TestManager(t *testing.T) { "actual location %s did not start with expected prefix %s", actualLocation, expectedRedirectLocationPrefix, ) + + parsedLocation, err := url.Parse(actualLocation) + r.NoError(err) + redirectStateParam := parsedLocation.Query().Get("state") + r.NotEmpty(redirectStateParam) + + cookieValueAndDirectivesSplit := strings.SplitN(recorder.Header().Get("Set-Cookie"), ";", 2) + r.Len(cookieValueAndDirectivesSplit, 2) + cookieKeyValueSplit := strings.Split(cookieValueAndDirectivesSplit[0], "=") + r.Len(cookieKeyValueSplit, 2) + csrfCookieName := cookieKeyValueSplit[0] + r.Equal("__Host-pinniped-csrf", csrfCookieName) + csrfCookieValue := cookieKeyValueSplit[1] + r.NotEmpty(csrfCookieValue) + + // Return the important parts of the response so we can use them in our next request to the callback endpoint + return csrfCookieValue, redirectStateParam } - requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix string) { + requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix, csrfCookieValue string) { recorder := httptest.NewRecorder() - subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.CallbackEndpointPath+requestURLSuffix)) + numberOfKubeActionsBeforeThisRequest := len(kubeClient.Actions()) + + getRequest := newGetRequest(requestIssuer + oidc.CallbackEndpointPath + requestURLSuffix) + getRequest.AddCookie(&http.Cookie{ + Name: "__Host-pinniped-csrf", + Value: csrfCookieValue, + }) + subject.ServeHTTP(recorder, getRequest) r.False(fallbackHandlerWasCalled) - // Minimal check to ensure that the right endpoint was called - when we don't send a CSRF - // cookie to the callback endpoint, the callback endpoint responds with a 403. - r.Equal(http.StatusForbidden, recorder.Code) + // Check just enough of the response to ensure that we wired up the callback endpoint correctly. + // The endpoint's own unit tests cover everything else. + r.Equal(http.StatusFound, recorder.Code) + actualLocation := recorder.Header().Get("Location") + r.True( + strings.HasPrefix(actualLocation, downstreamRedirectURL), + "actual location %s did not start with expected prefix %s", + actualLocation, downstreamRedirectURL, + ) + parsedLocation, err := url.Parse(actualLocation) + r.NoError(err) + actualLocationQueryParams := parsedLocation.Query() + r.Contains(actualLocationQueryParams, "code") + r.Equal("openid", actualLocationQueryParams.Get("scope")) + r.Equal("some-state-value-that-is-32-byte", actualLocationQueryParams.Get("state")) + + // Make sure that we wired up the callback endpoint to use kube storage for fosite sessions. + r.Equal(len(kubeClient.Actions()), numberOfKubeActionsBeforeThisRequest+3, + "did not perform any kube actions during the callback request, but should have") } requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) { @@ -126,9 +172,22 @@ func TestManager(t *testing.T) { ClientID: "test-client-id", AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, Scopes: []string{"test-scope"}, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { + return oidctypes.Token{}, + map[string]interface{}{ + "iss": "https://some-issuer.com", + "sub": "some-subject", + "username": "test-username", + "groups": "test-group1", + }, + nil + }, }) - subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter) + kubeClient = fake.NewSimpleClientset() + secretsClient := kubeClient.CoreV1().Secrets("some-namespace") + + subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter, secretsClient) }) when("given no providers via SetProviders()", func() { @@ -191,19 +250,20 @@ func TestManager(t *testing.T) { // Hostnames are case-insensitive, so test that we can handle that. requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) - requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) + csrfCookieValue, upstreamStateParam := + requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) callbackRequestParams := "?" + url.Values{ - "code": []string{"some-code"}, - "state": []string{"some-state-value"}, + "code": []string{"some-fake-code"}, + "state": []string{upstreamStateParam}, }.Encode() - requireCallbackRequestToBeHandled(issuer1, callbackRequestParams) - requireCallbackRequestToBeHandled(issuer2, callbackRequestParams) + requireCallbackRequestToBeHandled(issuer1, callbackRequestParams, csrfCookieValue) + requireCallbackRequestToBeHandled(issuer2, callbackRequestParams, csrfCookieValue) // // Hostnames are case-insensitive, so test that we can handle that. - requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams) - requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams) + requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams, csrfCookieValue) + requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams, csrfCookieValue) } when("given some valid providers via SetProviders()", func() { From 1d44a0cdfa01411c0d20e6409eefa197ab364bad Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Thu, 3 Dec 2020 09:34:46 -0600 Subject: [PATCH 50/57] Add a small integration test library to dump pod logs on test failures. Signed-off-by: Matt Moyer --- test/library/dumplogs.go | 49 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 test/library/dumplogs.go diff --git a/test/library/dumplogs.go b/test/library/dumplogs.go new file mode 100644 index 000000000..33f33694d --- /dev/null +++ b/test/library/dumplogs.go @@ -0,0 +1,49 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package library + +import ( + "bufio" + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// DumpLogs is meant to be called in a `defer` to dump the logs of components in the cluster on a test failure. +func DumpLogs(t *testing.T, namespace string) { + // Only trigger on failed tests. + if !t.Failed() { + return + } + + kubeClient := NewClientset(t) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + logTailLines := int64(40) + pods, err := kubeClient.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{}) + require.NoError(t, err) + + for _, pod := range pods.Items { + for _, container := range pod.Status.ContainerStatuses { + t.Logf("pod %s/%s container %s restarted %d times:", pod.Namespace, pod.Name, container.Name, container.RestartCount) + req := kubeClient.CoreV1().Pods(namespace).GetLogs(pod.Name, &corev1.PodLogOptions{ + Container: container.Name, + TailLines: &logTailLines, + }) + logReader, err := req.Stream(ctx) + require.NoError(t, err) + + scanner := bufio.NewScanner(logReader) + for scanner.Scan() { + t.Logf("%s/%s/%s > %s", pod.Namespace, pod.Name, container.Name, scanner.Text()) + } + require.NoError(t, scanner.Err()) + } + } +} From d7b1ab8e43175e13270251a0dadbb2e3f81363a8 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Thu, 3 Dec 2020 09:35:28 -0600 Subject: [PATCH 51/57] Try to capture more logs from the TestSupervisorLogin test. Signed-off-by: Matt Moyer --- test/integration/supervisor_login_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index 92e7c086d..5df74651c 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -33,6 +33,10 @@ import ( func TestSupervisorLogin(t *testing.T) { env := library.IntegrationEnv(t) + + // If anything in this test crashes, dump out the supervisor pod logs. + defer library.DumpLogs(t, env.SupervisorNamespace) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -92,9 +96,6 @@ func TestSupervisorLogin(t *testing.T) { TLS: &idpv1alpha1.TLSSpec{ CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorTestUpstream.CABundle)), }, - AuthorizationConfig: idpv1alpha1.OIDCAuthorizationConfig{ - AdditionalScopes: []string{"email", "profile"}, - }, Client: idpv1alpha1.OIDCClient{ SecretName: library.CreateClientCredsSecret(t, env.SupervisorTestUpstream.ClientID, env.SupervisorTestUpstream.ClientSecret).Name, }, From 954591d2db1bf6b175f5a3db879a5dfcc852c214 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Thu, 3 Dec 2020 10:25:26 -0600 Subject: [PATCH 52/57] Add some debugging logs to our proxy client code. Signed-off-by: Matt Moyer --- test/integration/supervisor_login_test.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index 5df74651c..3b5e50a9c 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -57,9 +57,13 @@ func TestSupervisorLogin(t *testing.T) { TLSClientConfig: &tls.Config{RootCAs: ca.Pool()}, Proxy: func(req *http.Request) (*url.URL, error) { if env.Proxy == "" { + t.Logf("passing request for %s with no proxy", req.URL) return nil, nil } - return url.Parse(env.Proxy) + proxyURL, err := url.Parse(env.Proxy) + require.NoError(t, err) + t.Logf("passing request for %s through proxy %s", req.URL, proxyURL.String()) + return proxyURL, nil }, }} From cb5e4948151791a79d1bec432543942b7a291f39 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Thu, 3 Dec 2020 11:28:48 -0600 Subject: [PATCH 53/57] Dump out proxy access logs in TestSupervisorLogin. Signed-off-by: Matt Moyer --- test/deploy/dex/proxy.yaml | 16 ++++++++++++++++ test/integration/supervisor_login_test.go | 3 ++- test/library/dumplogs.go | 4 ++-- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/test/deploy/dex/proxy.yaml b/test/deploy/dex/proxy.yaml index 1d7d8a665..be4d4878a 100644 --- a/test/deploy/dex/proxy.yaml +++ b/test/deploy/dex/proxy.yaml @@ -20,6 +20,9 @@ spec: labels: app: proxy spec: + volumes: + - name: log-dir + emptyDir: {} containers: - name: proxy image: docker.io/getpinniped/test-forward-proxy @@ -34,6 +37,9 @@ spec: limits: cpu: "10m" memory: "64Mi" + volumeMounts: + - name: log-dir + mountPath: "/var/log/squid/" readinessProbe: tcpSocket: port: http @@ -41,6 +47,16 @@ spec: timeoutSeconds: 5 periodSeconds: 5 failureThreshold: 2 + - name: accesslogs + image: debian:10.6-slim + command: + - "/bin/sh" + - "-c" + args: + - tail -F /var/log/squid/access.log + volumeMounts: + - name: log-dir + mountPath: "/var/log/squid/" --- apiVersion: v1 kind: Service diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index 3b5e50a9c..f666ef1a3 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -35,7 +35,8 @@ func TestSupervisorLogin(t *testing.T) { env := library.IntegrationEnv(t) // If anything in this test crashes, dump out the supervisor pod logs. - defer library.DumpLogs(t, env.SupervisorNamespace) + defer library.DumpLogs(t, env.SupervisorNamespace, "") + defer library.DumpLogs(t, "dex", "app=proxy") ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() diff --git a/test/library/dumplogs.go b/test/library/dumplogs.go index 33f33694d..dc727e039 100644 --- a/test/library/dumplogs.go +++ b/test/library/dumplogs.go @@ -15,7 +15,7 @@ import ( ) // DumpLogs is meant to be called in a `defer` to dump the logs of components in the cluster on a test failure. -func DumpLogs(t *testing.T, namespace string) { +func DumpLogs(t *testing.T, namespace string, labelSelector string) { // Only trigger on failed tests. if !t.Failed() { return @@ -26,7 +26,7 @@ func DumpLogs(t *testing.T, namespace string) { defer cancel() logTailLines := int64(40) - pods, err := kubeClient.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{}) + pods, err := kubeClient.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{LabelSelector: labelSelector}) require.NoError(t, err) for _, pod := range pods.Items { From 408fbe4f76aa0dba5508bd636b5253e378c28f94 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Thu, 3 Dec 2020 12:45:56 -0600 Subject: [PATCH 54/57] Parameterize the `supervisor_redirect_uri` in the test env Dex. Signed-off-by: Matt Moyer --- hack/lib/tilt/Tiltfile | 5 ++++- hack/prepare-for-integration-tests.sh | 4 ++++ test/deploy/dex/dex.yaml | 2 +- test/deploy/dex/values.yaml | 2 ++ 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/hack/lib/tilt/Tiltfile b/hack/lib/tilt/Tiltfile index 0c176e2c8..675b61a58 100644 --- a/hack/lib/tilt/Tiltfile +++ b/hack/lib/tilt/Tiltfile @@ -23,7 +23,10 @@ local_resource( # # Render the IDP installation manifest using ytt. -k8s_yaml(local(['ytt','--file', '../../../test/deploy/dex'])) +k8s_yaml(local(['ytt', + '--file', '../../../test/deploy/dex', + '--data-value', 'supervisor_redirect_uri=https://pinniped-supervisor-clusterip.supervisor.svc.cluster.local/some/path/callback', +])) # Tell tilt to watch all of those files for changes. watch_file('../../../test/deploy/dex') diff --git a/hack/prepare-for-integration-tests.sh b/hack/prepare-for-integration-tests.sh index 97bdcceb7..4634330c7 100755 --- a/hack/prepare-for-integration-tests.sh +++ b/hack/prepare-for-integration-tests.sh @@ -184,6 +184,10 @@ if ! tilt_mode; then log_note "Deploying Dex to the cluster..." ytt --file . >"$manifest" + ytt --file . \ + --data-value "supervisor_redirect_uri=https://pinniped-supervisor-clusterip.supervisor.svc.cluster.local/some/path/callback" \ + >"$manifest" + kubectl apply --dry-run=client -f "$manifest" # Validate manifest schema. kapp deploy --yes --app dex --diff-changes --file "$manifest" diff --git a/test/deploy/dex/dex.yaml b/test/deploy/dex/dex.yaml index 6a5ecfecc..cee9f3828 100644 --- a/test/deploy/dex/dex.yaml +++ b/test/deploy/dex/dex.yaml @@ -28,7 +28,7 @@ staticClients: name: 'Pinniped Supervisor' secret: pinniped-supervisor-secret redirectURIs: - - https://pinniped-supervisor-clusterip.supervisor.svc.cluster.local/some/path/callback + - #@ data.values.supervisor_redirect_uri enablePasswordDB: true staticPasswords: - username: "pinny" diff --git a/test/deploy/dex/values.yaml b/test/deploy/dex/values.yaml index 27022cdb5..8bb90da52 100644 --- a/test/deploy/dex/values.yaml +++ b/test/deploy/dex/values.yaml @@ -15,3 +15,5 @@ ports: #! External port where the proxy ends up exposed on localhost during tests. This value comes from #! our Kind configuration which maps 127.0.0.1:12346 to port 31235 on the Kind worker node. local: 12346 + +supervisor_redirect_uri: "" \ No newline at end of file From 8563c05bafd31470b54799f675db4ad8c17a548a Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Thu, 3 Dec 2020 13:22:27 -0600 Subject: [PATCH 55/57] Tweak these timeouts to be a bit faster (and retrigger CI). Signed-off-by: Matt Moyer --- test/integration/supervisor_login_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index f666ef1a3..e926950e5 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -111,7 +111,7 @@ func TestSupervisorLogin(t *testing.T) { assert.Eventually(t, func() bool { discovery, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), downstream.Spec.Issuer) return err == nil - }, 60*time.Second, 1*time.Second) + }, 30*time.Second, 200*time.Millisecond) require.NoError(t, err) // Start a callback server on localhost. From 9455a66be8552a086f18f45daac3eb49aa4b7a9d Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Thu, 3 Dec 2020 13:56:24 -0600 Subject: [PATCH 56/57] This trailing dash is now taken care of by the library method. Signed-off-by: Matt Moyer --- test/library/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/library/client.go b/test/library/client.go index d95f04263..b151b86b5 100644 --- a/test/library/client.go +++ b/test/library/client.go @@ -254,7 +254,7 @@ func CreateClientCredsSecret(t *testing.T, clientID string, clientSecret string) env := IntegrationEnv(t) return CreateTestSecret(t, env.SupervisorNamespace, - "test-client-creds-", + "test-client-creds", "secrets.pinniped.dev/oidc-client", map[string]string{ "clientID": clientID, From c8abc79d9b515f92e83ac84fbdbc80c484ebae0f Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Thu, 3 Dec 2020 14:24:26 -0600 Subject: [PATCH 57/57] Fix this comment (and retrigger CI). Signed-off-by: Matt Moyer --- test/integration/supervisor_login_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index e926950e5..8835752ff 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -34,7 +34,7 @@ import ( func TestSupervisorLogin(t *testing.T) { env := library.IntegrationEnv(t) - // If anything in this test crashes, dump out the supervisor pod logs. + // If anything in this test crashes, dump out the supervisor and proxy pod logs. defer library.DumpLogs(t, env.SupervisorNamespace, "") defer library.DumpLogs(t, "dex", "app=proxy")