diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index a2521488a..ee9a56216 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -26,14 +26,9 @@ const ( // ID token if the upstream OIDC IDP did not tell us to use another claim. defaultUpstreamUsernameClaim = "sub" - // defaultUpstreamGroupsClaim is what we will use to extract the groups from an upstream OIDC ID - // token if the upstream OIDC IDP did not tell us to use another claim. - defaultUpstreamGroupsClaim = "groups" - // downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token // information. - // TODO: should this be per-issuer? Or per version? - downstreamGroupsClaim = "oidc.pinniped.dev/groups" + downstreamGroupsClaim = "groups" ) func NewHandler( @@ -56,6 +51,7 @@ func NewHandler( downstreamAuthParams, err := url.ParseQuery(state.AuthParams) if err != nil { + plog.Error("error reading state downstream auth params", err) return httperr.New(http.StatusBadRequest, "error reading state downstream auth params") } @@ -63,11 +59,11 @@ func NewHandler( reconstitutedAuthRequest := &http.Request{Form: downstreamAuthParams} authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), reconstitutedAuthRequest) if err != nil { + plog.Error("error using state downstream auth params", err) return httperr.New(http.StatusBadRequest, "error using state downstream auth params") } // Grant the openid scope only if it was requested. - // TODO: shouldn't we be potentially granting more scopes than just openid... grantOpenIDScopeIfRequested(authorizeRequester) _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( @@ -77,52 +73,18 @@ func NewHandler( state.Nonce, ) if err != nil { + plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", upstreamIDPConfig.GetName()) return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") } - var username string - usernameClaim := upstreamIDPConfig.GetUsernameClaim() - if usernameClaim == "" { - usernameClaim = defaultUpstreamUsernameClaim - } - usernameAsInterface, ok := idTokenClaims[usernameClaim] - if !ok { - panic(err) // TODO - } - username, ok = usernameAsInterface.(string) - if !ok { - panic(err) // TODO + username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + if err != nil { + return err } - var groups []string - groupsClaim := upstreamIDPConfig.GetGroupsClaim() - if groupsClaim == "" { - groupsClaim = defaultUpstreamGroupsClaim - } - groupsAsInterface, ok := idTokenClaims[groupsClaim] - if !ok { - panic(err) // TODO - } - groups, ok = groupsAsInterface.([]string) - if !ok { - panic(err) // TODO - } - - now := time.Now() - authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{ - Claims: &jwt.IDTokenClaims{ - Issuer: downstreamIssuer, - Subject: username, - Audience: []string{downstreamAuthParams.Get("client_id")}, - ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here - IssuedAt: now, // TODO test this - RequestedAt: now, // TODO test this - AuthTime: now, // TODO test this - Extra: map[string]interface{}{ - downstreamGroupsClaim: groups, - }, - }, - }) + groups := getGroupsFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + openIDSession := makeDownstreamSession(downstreamIssuer, downstreamAuthParams.Get("client_id"), username, groups) + authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession) if err != nil { panic(err) // TODO } @@ -222,3 +184,76 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { } } } + +func getUsernameFromUpstreamIDToken( + upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, + idTokenClaims map[string]interface{}, +) (string, error) { + usernameClaim := upstreamIDPConfig.GetUsernameClaim() + if usernameClaim == "" { + // TODO: if we use the default "sub" claim, maybe we should create the username with the issuer + // since the spec says the "sub" claim is only unique per issuer. + usernameClaim = defaultUpstreamUsernameClaim + } + + usernameAsInterface, ok := idTokenClaims[usernameClaim] + if !ok { + plog.Warning( + "no username claim in upstream ID token", + "upstreamName", upstreamIDPConfig.GetName(), + "configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(), + "usernameClaim", usernameClaim, + ) + return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token") + } + + username, ok := usernameAsInterface.(string) + if !ok { + panic("todo bbb") // TODO + } + + return username, nil +} + +func getGroupsFromUpstreamIDToken( + upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, + idTokenClaims map[string]interface{}, +) []string { + groupsClaim := upstreamIDPConfig.GetGroupsClaim() + if groupsClaim == "" { + return nil + } + + groupsAsInterface, ok := idTokenClaims[groupsClaim] + if !ok { + panic("todo ccc") // TODO + } + + groups, ok := groupsAsInterface.([]string) + if !ok { + panic("todo ddd") // TODO + } + + return groups +} + +func makeDownstreamSession(issuer, clientID, username string, groups []string) *openid.DefaultSession { + now := time.Now() + openIDSession := &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Issuer: issuer, + Subject: username, + Audience: []string{clientID}, + ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here + IssuedAt: now, // TODO test this + RequestedAt: now, // TODO test this + AuthTime: now, // TODO test this + }, + } + if groups != nil { + openIDSession.Claims.Extra = map[string]interface{}{ + downstreamGroupsClaim: groups, + } + } + return openIDSession +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index a37133718..6902773bb 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -29,6 +29,16 @@ import ( const ( happyUpstreamIDPName = "upstream-idp-name" + + upstreamSubject = "abc123-some-guid" + upstreamUsername = "test-pinniped-username" + + upstreamUsernameClaim = "the-user-claim" + upstreamGroupsClaim = "the-groups-claim" +) + +var ( + upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} ) func TestCallbackEndpoint(t *testing.T) { @@ -36,63 +46,15 @@ func TestCallbackEndpoint(t *testing.T) { downstreamIssuer = "https://my-downstream-issuer.com/path" downstreamRedirectURI = "http://127.0.0.1/callback" happyUpstreamAuthcode = "upstream-auth-code" - upstreamUsername = "test-pinniped-username" downstreamClientID = "pinniped-cli" ) - var ( - upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} - ) - - upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ - Name: happyUpstreamIDPName, - ClientID: "some-client-id", - UsernameClaim: "the-user-claim", - GroupsClaim: "the-groups-claim", - Scopes: []string{"scope1", "scope2"}, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { - return oidcclient.Token{}, - map[string]interface{}{ - "the-user-claim": upstreamUsername, - "the-groups-claim": upstreamGroupMembership, - "other-claim": "should be ignored", - }, - nil - }, - } - - defaultClaimsUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ - Name: happyUpstreamIDPName, - ClientID: "some-client-id", - Scopes: []string{"scope1", "scope2"}, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { - return oidcclient.Token{}, - map[string]interface{}{ - "sub": upstreamUsername, - "groups": upstreamGroupMembership, - "other-claim": "should be ignored", - }, - nil - }, - } - otherUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ Name: "other-upstream-idp-name", ClientID: "other-some-client-id", Scopes: []string{"other-scope1", "other-scope2"}, } - failedExchangeUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ - Name: happyUpstreamIDPName, - ClientID: upstreamOIDCIdentityProvider.ClientID, - UsernameClaim: upstreamOIDCIdentityProvider.UsernameClaim, - GroupsClaim: upstreamOIDCIdentityProvider.GroupsClaim, - Scopes: upstreamOIDCIdentityProvider.Scopes, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { - return oidcclient.Token{}, nil, errors.New("some exchange error") - }, - } - var stateEncoderHashKey = []byte("fake-hash-secret") var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES var cookieEncoderHashKey = []byte("fake-hash-secret2") @@ -210,16 +172,18 @@ func TestCallbackEndpoint(t *testing.T) { path string csrfCookie string - wantStatus int - wantBody string - wantRedirectLocationRegexp string - wantGrantedOpenidScope bool + wantStatus int + wantBody string + wantRedirectLocationRegexp string + wantGrantedOpenidScope bool + wantDownstreamIDTokenSubject string + wantDownstreamIDTokenGroups []string wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs }{ { name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), csrfCookie: happyCSRFCookie, @@ -227,11 +191,13 @@ func TestCallbackEndpoint(t *testing.T) { wantRedirectLocationRegexp: happyRedirectLocationRegexp, wantGrantedOpenidScope: true, wantBody: "", + wantDownstreamIDTokenSubject: upstreamUsername, + wantDownstreamIDTokenGroups: upstreamGroupMembership, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { - name: "upstream IDP uses default claims", - idp: defaultClaimsUpstreamOIDCIdentityProvider, + name: "upstream IDP provides no username or group claim, so we use default username claim and skip groups", + idp: happyUpstream().WithoutUsernameClaim().WithoutGroupsClaim().Build(), method: http.MethodGet, path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), csrfCookie: happyCSRFCookie, @@ -239,6 +205,7 @@ func TestCallbackEndpoint(t *testing.T) { wantRedirectLocationRegexp: happyRedirectLocationRegexp, wantGrantedOpenidScope: true, wantBody: "", + wantDownstreamIDTokenSubject: upstreamSubject, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, // TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes) @@ -290,7 +257,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState("this-will-not-decode").String(), csrfCookie: happyCSRFCookie, @@ -299,7 +266,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state's internal version does not match what we want", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(wrongVersionState).String(), csrfCookie: happyCSRFCookie, @@ -308,7 +275,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state's downstream auth params element is invalid", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(), csrfCookie: happyCSRFCookie, @@ -317,7 +284,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state's downstream auth params are missing required value (e.g., client_id)", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(missingClientIDState).String(), csrfCookie: happyCSRFCookie, @@ -326,12 +293,14 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state's downstream auth params does not contain openid scope", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(noOpenidScopeState).WithCode(happyUpstreamAuthcode).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, + wantDownstreamIDTokenSubject: upstreamUsername, + wantDownstreamIDTokenGroups: upstreamGroupMembership, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { @@ -345,7 +314,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "the CSRF cookie does not exist on request", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), wantStatus: http.StatusForbidden, @@ -353,7 +322,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", @@ -362,7 +331,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "cookie csrf value does not match state csrf value", - idp: upstreamOIDCIdentityProvider, + idp: happyUpstream().Build(), method: http.MethodGet, path: newRequestPath().WithState(wrongCSRFValueState).String(), csrfCookie: happyCSRFCookie, @@ -373,7 +342,7 @@ func TestCallbackEndpoint(t *testing.T) { // Upstream exchange { name: "upstream auth code exchange fails", - idp: failedExchangeUpstreamOIDCIdentityProvider, + idp: happyUpstream().WithoutUpstreamAuthcodeExchangeError(errors.New("some error")).Build(), method: http.MethodGet, path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), csrfCookie: happyCSRFCookie, @@ -381,6 +350,16 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, + { + name: "upstream ID token does not contain requested username claim", + idp: happyUpstream().WithoutIDTokenClaim(upstreamUsernameClaim).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: no username claim in upstream ID token\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, } for _, test := range tests { test := test @@ -457,9 +436,13 @@ func TestCallbackEndpoint(t *testing.T) { require.NotContains(t, storedRequest.GetGrantedScopes(), "openid") } require.Equal(t, downstreamIssuer, storedSession.Claims.Issuer) - require.Equal(t, upstreamUsername, storedSession.Claims.Subject) + require.Equal(t, test.wantDownstreamIDTokenSubject, storedSession.Claims.Subject) require.Equal(t, []string{downstreamClientID}, storedSession.Claims.Audience) - require.Equal(t, upstreamGroupMembership, storedSession.Claims.Extra["oidc.pinniped.dev/groups"]) + if test.wantDownstreamIDTokenGroups != nil { + require.Equal(t, test.wantDownstreamIDTokenGroups, storedSession.Claims.Extra["groups"]) + } else { + require.NotContains(t, storedSession.Claims.Extra, "groups") + } } else { require.Empty(t, rsp.Header().Values("Location")) } @@ -519,6 +502,63 @@ func (r *requestPath) String() string { return path + params.Encode() } +type upstreamOIDCIdentityProviderBuilder struct { + idToken map[string]interface{} + usernameClaim, groupsClaim string + authcodeExchangeErr error +} + +func happyUpstream() *upstreamOIDCIdentityProviderBuilder { + return &upstreamOIDCIdentityProviderBuilder{ + usernameClaim: upstreamUsernameClaim, + groupsClaim: upstreamGroupsClaim, + idToken: map[string]interface{}{ + "sub": upstreamSubject, + upstreamUsernameClaim: upstreamUsername, + upstreamGroupsClaim: upstreamGroupMembership, + "other-claim": "should be ignored", + }, + } +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutUsernameClaim() *upstreamOIDCIdentityProviderBuilder { + u.usernameClaim = "" + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutGroupsClaim() *upstreamOIDCIdentityProviderBuilder { + u.groupsClaim = "" + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithIDTokenClaim(name, value string) *upstreamOIDCIdentityProviderBuilder { + u.idToken[name] = value + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutIDTokenClaim(claim string) *upstreamOIDCIdentityProviderBuilder { + delete(u.idToken, claim) + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutUpstreamAuthcodeExchangeError(err error) *upstreamOIDCIdentityProviderBuilder { + u.authcodeExchangeErr = err + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) Build() testutil.TestUpstreamOIDCIdentityProvider { + return testutil.TestUpstreamOIDCIdentityProvider{ + Name: happyUpstreamIDPName, + ClientID: "some-client-id", + UsernameClaim: u.usernameClaim, + GroupsClaim: u.groupsClaim, + Scopes: []string{"scope1", "scope2"}, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { + return oidcclient.Token{}, u.idToken, u.authcodeExchangeErr + }, + } +} + func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string) url.Values { copied := url.Values{} for key, value := range query {