From 6fce1bd6bb64dd4e93098a2d3c5636f4b239b515 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Thu, 14 Jan 2021 17:21:41 -0500 Subject: [PATCH] Allow arrays of type interface and always set the groups claim to an array in the downstream token Signed-off-by: Margo Crawford --- internal/oidc/callback/callback_handler.go | 52 ++++++++++---- .../oidc/callback/callback_handler_test.go | 71 +++++++++++++------ test/integration/supervisor_login_test.go | 2 +- 3 files changed, 87 insertions(+), 38 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 417f1d2cc..3598c35c2 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -1,4 +1,4 @@ -// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 // Package callback provides a handler for the OIDC callback endpoint. @@ -255,10 +255,10 @@ func getSubjectAndUsernameFromUpstreamIDToken( func getGroupsFromUpstreamIDToken( upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, idTokenClaims map[string]interface{}, -) (interface{}, error) { +) ([]string, error) { groupsClaim := upstreamIDPConfig.GetGroupsClaim() if groupsClaim == "" { - return nil, nil + return []string{}, nil } groupsAsInterface, ok := idTokenClaims[groupsClaim] @@ -269,12 +269,11 @@ func getGroupsFromUpstreamIDToken( "configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(), "groupsClaim", groupsClaim, ) - return nil, nil // the upstream IDP may have omitted the claim if the user has no groups + return []string{}, nil // the upstream IDP may have omitted the claim if the user has no groups } - groupsAsArray, okAsArray := groupsAsInterface.([]string) - groupsAsString, okAsString := groupsAsInterface.(string) - if !okAsArray && !okAsString { + groupsAsArray, okAsArray := extractGroups(groupsAsInterface) + if !okAsArray { plog.Warning( "groups claim in upstream ID token has invalid format", "upstreamName", upstreamIDPConfig.GetName(), @@ -284,13 +283,38 @@ func getGroupsFromUpstreamIDToken( return nil, httperr.New(http.StatusUnprocessableEntity, "groups claim in upstream ID token has invalid format") } - if okAsArray { - return groupsAsArray, nil - } - return groupsAsString, nil + return groupsAsArray, nil } -func makeDownstreamSession(subject string, username string, groups interface{}) *openid.DefaultSession { +func extractGroups(groupsAsInterface interface{}) ([]string, bool) { + groupsAsString, okAsString := groupsAsInterface.(string) + if okAsString { + return []string{groupsAsString}, true + } + + groupsAsStringArray, okAsStringArray := groupsAsInterface.([]string) + if okAsStringArray { + return groupsAsStringArray, true + } + + groupsAsInterfaceArray, okAsArray := groupsAsInterface.([]interface{}) + if !okAsArray { + return nil, false + } + + groupsAsStrings := make([]string, len(groupsAsInterfaceArray)) + for i, groupAsInterface := range groupsAsInterfaceArray { + groupAsString, okAsString := groupAsInterface.(string) + if !okAsString { + return nil, false + } + groupsAsStrings[i] = groupAsString + } + + return groupsAsStrings, true +} + +func makeDownstreamSession(subject string, username string, groups []string) *openid.DefaultSession { now := time.Now().UTC() openIDSession := &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ @@ -301,9 +325,7 @@ func makeDownstreamSession(subject string, username string, groups interface{}) } openIDSession.Claims.Extra = map[string]interface{}{ oidc.DownstreamUsernameClaim: username, - } - if groups != nil { - openIDSession.Claims.Extra[oidc.DownstreamGroupsClaim] = groups + oidc.DownstreamGroupsClaim: groups, } return openIDSession } diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 3b5659f4d..3e7db1afa 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 package callback @@ -134,7 +134,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamGrantedScopes []string wantDownstreamIDTokenSubject string wantDownstreamIDTokenUsername string - wantDownstreamIDTokenGroups interface{} + wantDownstreamIDTokenGroups []string wantDownstreamRequestedScopes []string wantDownstreamNonce string wantDownstreamPKCEChallenge string @@ -172,7 +172,7 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "", wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, wantDownstreamIDTokenUsername: upstreamIssuer + "?sub=" + upstreamSubject, - wantDownstreamIDTokenGroups: nil, + wantDownstreamIDTokenGroups: []string{}, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamNonce: downstreamNonce, @@ -210,7 +210,26 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "", wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, wantDownstreamIDTokenUsername: upstreamUsername, - wantDownstreamIDTokenGroups: "notAnArrayGroup1 notAnArrayGroup2", + wantDownstreamIDTokenGroups: []string{"notAnArrayGroup1 notAnArrayGroup2"}, + wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantDownstreamGrantedScopes: happyDownstreamScopesGranted, + wantDownstreamNonce: downstreamNonce, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream IDP's configured groups claim in the ID token is a slice of interfaces", + idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, []interface{}{"group1", "group2"}).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, + wantBody: "", + wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, + wantDownstreamIDTokenUsername: upstreamUsername, + wantDownstreamIDTokenGroups: []string{"group1", "group2"}, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamNonce: downstreamNonce, @@ -437,6 +456,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamIDTokenUsername: upstreamUsername, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, + wantDownstreamIDTokenGroups: []string{}, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, @@ -482,6 +502,26 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, + { + name: "upstream ID token contains groups claim where one element is invalid", + idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, []interface{}{"foo", 7}).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream ID token contains groups claim with invalid null type", + idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, nil).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, } for _, test := range tests { test := test @@ -779,7 +819,7 @@ func validateAuthcodeStorage( wantDownstreamGrantedScopes []string, wantDownstreamIDTokenSubject string, wantDownstreamIDTokenUsername string, - wantDownstreamIDTokenGroups interface{}, + wantDownstreamIDTokenGroups []string, wantDownstreamRequestedScopes []string, ) (*fosite.Request, *openid.DefaultSession) { t.Helper() @@ -818,23 +858,10 @@ func validateAuthcodeStorage( // Check the user's identity, which are put into the downstream ID token's subject, username and groups claims. require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject) require.Equal(t, wantDownstreamIDTokenUsername, actualClaims.Extra["username"]) - if wantDownstreamIDTokenGroups != nil { //nolint:nestif // there are some nested if's here but its probably fine for a test - require.Len(t, actualClaims.Extra, 2) - wantArray, ok := wantDownstreamIDTokenGroups.([]string) - if ok { - require.ElementsMatch(t, wantArray, actualClaims.Extra["groups"]) - } else { - wantString, ok := wantDownstreamIDTokenGroups.(string) - if ok { - require.Equal(t, wantString, actualClaims.Extra["groups"]) - } else { - require.Fail(t, "wantDownstreamIDTokenGroups should be of type: either []string or string") - } - } - } else { - require.Len(t, actualClaims.Extra, 1) - require.NotContains(t, actualClaims.Extra, "groups") - } + require.Len(t, actualClaims.Extra, 2) + actualDownstreamIDTokenGroups := actualClaims.Extra["groups"] + require.NotNil(t, actualDownstreamIDTokenGroups) + require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualDownstreamIDTokenGroups) // Check the rest of the downstream ID token's claims. Fosite wants us to set these (in UTC time). testutil.RequireTimeInDelta(t, time.Now().UTC(), actualClaims.RequestedAt, timeComparisonFudgeFactor) diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index 60ecbb672..3c3b966e8 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -207,7 +207,7 @@ func TestSupervisorLogin(t *testing.T) { tokenResponse, err := downstreamOAuth2Config.Exchange(oidcHTTPClientContext, authcode, pkceParam.Verifier()) require.NoError(t, err) - expectedIDTokenClaims := []string{"iss", "exp", "sub", "aud", "auth_time", "iat", "jti", "nonce", "rat", "username"} + expectedIDTokenClaims := []string{"iss", "exp", "sub", "aud", "auth_time", "iat", "jti", "nonce", "rat", "username", "groups"} verifyTokenResponse(t, tokenResponse, discovery, downstreamOAuth2Config, env.SupervisorTestUpstream.Issuer, nonceParam, expectedIDTokenClaims) // token exchange on the original token