diff --git a/internal/federationdomain/endpoints/callback/callback_handler_github_test.go b/internal/federationdomain/endpoints/callback/callback_handler_github_test.go index 435ed1f1e..1d506436d 100644 --- a/internal/federationdomain/endpoints/callback/callback_handler_github_test.go +++ b/internal/federationdomain/endpoints/callback/callback_handler_github_test.go @@ -22,7 +22,9 @@ import ( supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake" "go.pinniped.dev/internal/federationdomain/endpoints/jwks" "go.pinniped.dev/internal/federationdomain/oidc" + "go.pinniped.dev/internal/federationdomain/oidcclientvalidator" "go.pinniped.dev/internal/federationdomain/storage" + "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/oidctestutil" @@ -35,7 +37,7 @@ var ( githubUpstreamUsername = "some-github-login" githubUpstreamGroups = []string{"org1/team1", "org2/team2"} githubDownstreamSubject = fmt.Sprintf("https://github.com?idpName=%s&sub=%s", githubIDPName, githubUpstreamUsername) - githubUpstreamAccessToken = "some-opaque-access-token-from-github" + githubUpstreamAccessToken = "some-opaque-access-token-from-github" //nolint:gosec // this is not a credential happyDownstreamGitHubCustomSessionData = &psession.CustomSessionData{ Username: githubUpstreamUsername, @@ -74,6 +76,16 @@ func TestCallbackEndpointWithGitHubIdentityProviders(t *testing.T) { RedirectURI: happyUpstreamRedirectURI, } + // TODO: when we merge this file back into callback_handler_test.go, we do not need to copy this function + // because it is already in callback_handler_test.go + addFullyCapableDynamicClientAndSecretToKubeResources := func(t *testing.T, supervisorClient *supervisorfake.Clientset, kubeClient *fake.Clientset) { + oidcClient, secret := testutil.FullyCapableOIDCClientAndStorageSecret(t, + "some-namespace", downstreamDynamicClientID, downstreamDynamicClientUID, downstreamRedirectURI, nil, + []string{testutil.HashedPassword1AtGoMinCost}, oidcclientvalidator.Validate) + require.NoError(t, supervisorClient.Tracker().Add(oidcClient)) + require.NoError(t, kubeClient.Tracker().Add(secret)) + } + tests := []struct { name string @@ -95,14 +107,18 @@ func TestCallbackEndpointWithGitHubIdentityProviders(t *testing.T) { wantDownstreamPKCEChallengeMethod string wantDownstreamCustomSessionData *psession.CustomSessionData wantDownstreamAdditionalClaims map[string]interface{} - - wantAuthcodeExchangeCall *expectedAuthcodeExchange + wantGitHubAuthcodeExchangeCall *expectedGitHubAuthcodeExchange }{ { name: "GitHub IDP: GET with good state and cookie and successful upstream token exchange returns 303 to downstream client callback", idps: testidplister.NewUpstreamIDPListerBuilder().WithGitHub( happyGitHubUpstream(). WithAccessToken(githubUpstreamAccessToken). + WithUser(&upstreamprovider.GitHubUser{ + Username: githubUpstreamUsername, + Groups: githubUpstreamGroups, + DownstreamSubject: githubDownstreamSubject, + }). Build()), method: http.MethodGet, path: newRequestPath().WithState( @@ -125,7 +141,7 @@ func TestCallbackEndpointWithGitHubIdentityProviders(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamGitHubCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantGitHubAuthcodeExchangeCall: &expectedGitHubAuthcodeExchange{ performedByUpstreamName: githubIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -135,6 +151,11 @@ func TestCallbackEndpointWithGitHubIdentityProviders(t *testing.T) { idps: testidplister.NewUpstreamIDPListerBuilder().WithGitHub( happyGitHubUpstream(). WithAccessToken(githubUpstreamAccessToken). + WithUser(&upstreamprovider.GitHubUser{ + Username: githubUpstreamUsername, + Groups: githubUpstreamGroups, + DownstreamSubject: githubDownstreamSubject, + }). Build()), method: http.MethodGet, kubeResources: addFullyCapableDynamicClientAndSecretToKubeResources, @@ -163,7 +184,7 @@ func TestCallbackEndpointWithGitHubIdentityProviders(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamGitHubCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantGitHubAuthcodeExchangeCall: &expectedGitHubAuthcodeExchange{ performedByUpstreamName: githubIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -204,13 +225,12 @@ func TestCallbackEndpointWithGitHubIdentityProviders(t *testing.T) { testutil.RequireSecurityHeadersWithFormPostPageCSPs(t, rsp) - require.NotNil(t, test.wantAuthcodeExchangeCall, "wantAuthcodeExchangeCall is required for testing purposes") + require.NotNil(t, test.wantGitHubAuthcodeExchangeCall, "wantOIDCAuthcodeExchangeCall is required for testing purposes") - test.wantAuthcodeExchangeCall.args.Ctx = reqContext - test.idps.RequireExactlyOneCallToExchangeAuthcodeAndValidateTokens(t, - test.wantAuthcodeExchangeCall.performedByUpstreamName, - idpdiscoveryv1alpha1.IDPTypeGitHub, - test.wantAuthcodeExchangeCall.args, + test.wantGitHubAuthcodeExchangeCall.args.Ctx = reqContext + test.idps.RequireExactlyOneGitHubAuthcodeExchange(t, + test.wantGitHubAuthcodeExchangeCall.performedByUpstreamName, + test.wantGitHubAuthcodeExchangeCall.args, ) require.Equal(t, http.StatusSeeOther, rsp.Code) diff --git a/internal/federationdomain/endpoints/callback/callback_handler_test.go b/internal/federationdomain/endpoints/callback/callback_handler_test.go index 019a32223..7818e7f0a 100644 --- a/internal/federationdomain/endpoints/callback/callback_handler_test.go +++ b/internal/federationdomain/endpoints/callback/callback_handler_test.go @@ -20,7 +20,6 @@ import ( "k8s.io/client-go/kubernetes/fake" configv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1" - idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1" supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake" "go.pinniped.dev/internal/federationdomain/endpoints/jwks" "go.pinniped.dev/internal/federationdomain/oidc" @@ -93,6 +92,7 @@ var ( happyDownstreamRequestParamsQueryForDynamicClient = shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"client_id": downstreamDynamicClientID}, ) + happyDownstreamRequestParamsForDynamicClient = happyDownstreamRequestParamsQueryForDynamicClient.Encode() happyDownstreamCustomSessionData = &psession.CustomSessionData{ Username: oidcUpstreamUsername, @@ -107,7 +107,6 @@ var ( UpstreamSubject: oidcUpstreamSubject, }, } - happyDownstreamCustomSessionDataWithUsernameAndGroups = func(wantDownstreamUsername, wantUpstreamUsername string, wantUpstreamGroups []string) *psession.CustomSessionData { copyOfCustomSession := *happyDownstreamCustomSessionData copyOfOIDC := *(happyDownstreamCustomSessionData.OIDC) @@ -130,14 +129,6 @@ var ( UpstreamSubject: oidcUpstreamSubject, }, } - - addFullyCapableDynamicClientAndSecretToKubeResources = func(t *testing.T, supervisorClient *supervisorfake.Clientset, kubeClient *fake.Clientset) { - oidcClient, secret := testutil.FullyCapableOIDCClientAndStorageSecret(t, - "some-namespace", downstreamDynamicClientID, downstreamDynamicClientUID, downstreamRedirectURI, nil, - []string{testutil.HashedPassword1AtGoMinCost}, oidcclientvalidator.Validate) - require.NoError(t, supervisorClient.Tracker().Add(oidcClient)) - require.NoError(t, kubeClient.Tracker().Add(secret)) - } ) func TestCallbackEndpoint(t *testing.T) { @@ -162,13 +153,13 @@ func TestCallbackEndpoint(t *testing.T) { happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) happyState := happyUpstreamStateParam().Build(t, happyStateCodec) - happyStateForDynamicClient := happyUpstreamStateParam().WithAuthorizeRequestParams(happyDownstreamRequestParamsQueryForDynamicClient.Encode()).Build(t, happyStateCodec) + happyStateForDynamicClient := happyUpstreamStateParamForDynamicClient().Build(t, happyStateCodec) encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyDownstreamCSRF) require.NoError(t, err) happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue - happyExchangeAndValidateTokensArgs := &oidctestutil.ExchangeAuthcodeArgs{ + happyExchangeAndValidateTokensArgs := &oidctestutil.ExchangeAuthcodeAndValidateTokenArgs{ Authcode: happyUpstreamAuthcode, PKCECodeVerifier: oidcpkce.Code(happyDownstreamPKCE), ExpectedIDTokenNonce: nonce.Nonce(happyDownstreamNonce), @@ -178,6 +169,14 @@ func TestCallbackEndpoint(t *testing.T) { // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it happyDownstreamRedirectLocationRegexp := downstreamRedirectURI + `\?code=([^&]+)&scope=openid\+username\+groups&state=` + happyDownstreamState + addFullyCapableDynamicClientAndSecretToKubeResources := func(t *testing.T, supervisorClient *supervisorfake.Clientset, kubeClient *fake.Clientset) { + oidcClient, secret := testutil.FullyCapableOIDCClientAndStorageSecret(t, + "some-namespace", downstreamDynamicClientID, downstreamDynamicClientUID, downstreamRedirectURI, nil, + []string{testutil.HashedPassword1AtGoMinCost}, oidcclientvalidator.Validate) + require.NoError(t, supervisorClient.Tracker().Add(oidcClient)) + require.NoError(t, kubeClient.Tracker().Add(secret)) + } + prefixUsernameAndGroupsPipeline := transformtestutil.NewPrefixingPipeline(t, transformationUsernamePrefix, transformationGroupsPrefix) rejectAuthPipeline := transformtestutil.NewRejectAllAuthPipeline(t) @@ -206,8 +205,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallengeMethod string wantDownstreamCustomSessionData *psession.CustomSessionData wantDownstreamAdditionalClaims map[string]interface{} - - wantAuthcodeExchangeCall *expectedAuthcodeExchange + wantOIDCAuthcodeExchangeCall *expectedOIDCAuthcodeExchange }{ { name: "GET with good state and cookie and successful upstream token exchange with response_mode=form_post returns 200 with HTML+JS form", @@ -235,7 +233,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -274,7 +272,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -302,7 +300,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -327,7 +325,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -351,7 +349,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamAccessTokenCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -386,7 +384,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -423,7 +421,7 @@ func TestCallbackEndpoint(t *testing.T) { UpstreamSubject: oidcUpstreamSubject, }, }, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -453,7 +451,7 @@ func TestCallbackEndpoint(t *testing.T) { oidcUpstreamIssuer+"?sub="+oidcUpstreamSubjectQueryEscaped, nil, ), - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -483,7 +481,7 @@ func TestCallbackEndpoint(t *testing.T) { "joe@whitehouse.gov", oidcUpstreamGroupMembership, ), - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -515,7 +513,7 @@ func TestCallbackEndpoint(t *testing.T) { "joe@whitehouse.gov", oidcUpstreamGroupMembership, ), - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -548,7 +546,7 @@ func TestCallbackEndpoint(t *testing.T) { "joe", oidcUpstreamGroupMembership, ), - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -565,7 +563,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: email_verified claim in upstream ID token has invalid format\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -579,7 +577,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: access token was returned by upstream provider but there was no userinfo endpoint\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -593,7 +591,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: neither access token nor refresh token returned by upstream provider\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -607,7 +605,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: neither access token nor refresh token returned by upstream provider\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -621,7 +619,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: neither access token nor refresh token returned by upstream provider\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -635,7 +633,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: neither access token nor refresh token returned by upstream provider\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -653,7 +651,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: email_verified claim in upstream ID token has false value\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -683,7 +681,7 @@ func TestCallbackEndpoint(t *testing.T) { oidcUpstreamSubject, oidcUpstreamGroupMembership, ), - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -713,7 +711,7 @@ func TestCallbackEndpoint(t *testing.T) { oidcUpstreamUsername, []string{"notAnArrayGroup1 notAnArrayGroup2"}, ), - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -743,7 +741,7 @@ func TestCallbackEndpoint(t *testing.T) { oidcUpstreamUsername, []string{"group1", "group2"}, ), - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -754,7 +752,7 @@ func TestCallbackEndpoint(t *testing.T) { kubeResources: addFullyCapableDynamicClientAndSecretToKubeResources, method: http.MethodGet, path: newRequestPath().WithState( - happyUpstreamStateParam(). + happyUpstreamStateParamForDynamicClient(). WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQueryForDynamicClient, map[string]string{"scope": "openid groups offline_access"}).Encode()). Build(t, happyStateCodec), @@ -773,7 +771,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -784,7 +782,7 @@ func TestCallbackEndpoint(t *testing.T) { kubeResources: addFullyCapableDynamicClientAndSecretToKubeResources, method: http.MethodGet, path: newRequestPath().WithState( - happyUpstreamStateParam(). + happyUpstreamStateParamForDynamicClient(). WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQueryForDynamicClient, map[string]string{"scope": "openid username offline_access"}).Encode()). Build(t, happyStateCodec), @@ -803,7 +801,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -846,7 +844,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -889,7 +887,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -918,7 +916,7 @@ func TestCallbackEndpoint(t *testing.T) { oidcUpstreamUsername, oidcUpstreamGroupMembership, ), - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1004,7 +1002,7 @@ func TestCallbackEndpoint(t *testing.T) { Build(t, happyStateCodec), ).String(), csrfCookie: happyCSRFCookie, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1169,7 +1167,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1199,7 +1197,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1228,7 +1226,7 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamCustomSessionData: happyDownstreamCustomSessionData, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1285,7 +1283,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusBadGateway, wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", wantContentType: htmlContentType, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1301,7 +1299,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantBody: "Unprocessable Entity: required claim in upstream ID token missing\n", wantContentType: htmlContentType, - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1331,7 +1329,7 @@ func TestCallbackEndpoint(t *testing.T) { oidcUpstreamUsername, nil, ), - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1347,7 +1345,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: required claim in upstream ID token has invalid format\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1363,7 +1361,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: required claim in upstream ID token is empty\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1379,7 +1377,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: required claim in upstream ID token missing\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1395,7 +1393,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: required claim in upstream ID token is empty\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1411,7 +1409,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: required claim in upstream ID token has invalid format\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1427,7 +1425,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: required claim in upstream ID token missing\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1443,7 +1441,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: required claim in upstream ID token is empty\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1459,7 +1457,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: required claim in upstream ID token has invalid format\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1475,7 +1473,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: required claim in upstream ID token has invalid format\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1491,7 +1489,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: required claim in upstream ID token has invalid format\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1507,7 +1505,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: required claim in upstream ID token has invalid format\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1522,7 +1520,7 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusUnprocessableEntity, wantContentType: htmlContentType, wantBody: "Unprocessable Entity: configured identity policy rejected this authentication: authentication was rejected by a configured policy\n", - wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + wantOIDCAuthcodeExchangeCall: &expectedOIDCAuthcodeExchange{ performedByUpstreamName: happyUpstreamIDPName, args: happyExchangeAndValidateTokensArgs, }, @@ -1563,15 +1561,14 @@ func TestCallbackEndpoint(t *testing.T) { testutil.RequireSecurityHeadersWithFormPostPageCSPs(t, rsp) - if test.wantAuthcodeExchangeCall != nil { - test.wantAuthcodeExchangeCall.args.Ctx = reqContext - test.idps.RequireExactlyOneCallToExchangeAuthcodeAndValidateTokens(t, - test.wantAuthcodeExchangeCall.performedByUpstreamName, - idpdiscoveryv1alpha1.IDPTypeOIDC, - test.wantAuthcodeExchangeCall.args, + if test.wantOIDCAuthcodeExchangeCall != nil { + test.wantOIDCAuthcodeExchangeCall.args.Ctx = reqContext + test.idps.RequireExactlyOneOIDCAuthcodeExchange(t, + test.wantOIDCAuthcodeExchangeCall.performedByUpstreamName, + test.wantOIDCAuthcodeExchangeCall.args, ) } else { - test.idps.RequireExactlyZeroCallsToExchangeAuthcodeAndValidateTokens(t) + test.idps.RequireExactlyZeroAuthcodeExchanges(t) } require.Equal(t, test.wantStatus, rsp.Code) @@ -1637,7 +1634,12 @@ func TestCallbackEndpoint(t *testing.T) { } } -type expectedAuthcodeExchange struct { +type expectedOIDCAuthcodeExchange struct { + performedByUpstreamName string + args *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs +} + +type expectedGitHubAuthcodeExchange struct { performedByUpstreamName string args *oidctestutil.ExchangeAuthcodeArgs } @@ -1699,6 +1701,12 @@ func happyUpstreamStateParam() *oidctestutil.UpstreamStateParamBuilder { } } +func happyUpstreamStateParamForDynamicClient() *oidctestutil.UpstreamStateParamBuilder { + p := happyUpstreamStateParam() + p.P = happyDownstreamRequestParamsForDynamicClient + return p +} + func happyUpstream() *oidctestutil.TestUpstreamOIDCIdentityProviderBuilder { return oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName(happyUpstreamIDPName). diff --git a/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider.go b/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider.go index 4efb1c048..c9de9036e 100644 --- a/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider.go +++ b/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider.go @@ -97,22 +97,29 @@ func (p *FederationDomainResolvedGitHubIdentityProvider) LoginFromCallback( _ nonce.Nonce, // GitHub does not support OIDC, therefore there is no ID token that could contain the "nonce". redirectURI string, ) (*resolvedprovider.Identity, *resolvedprovider.IdentityLoginExtras, error) { - token, _ := p.Provider.ExchangeAuthcode( - ctx, - authCode, - redirectURI, - ) + accessToken, err := p.Provider.ExchangeAuthcode(ctx, authCode, redirectURI) + if err != nil { + return nil, nil, fmt.Errorf("failed to exchange auth code using GitHub API: %w", err) + } + + user, err := p.Provider.GetUser(ctx, accessToken) + if err != nil { + return nil, nil, fmt.Errorf("failed to get user info from GitHub API: %w", err) + } return &resolvedprovider.Identity{ - UpstreamUsername: "some-github-login", - UpstreamGroups: []string{"org1/team1", "org2/team2"}, - DownstreamSubject: "https://github.com?idpName=upstream-github-idp-name&sub=some-github-login", + UpstreamUsername: user.Username, + UpstreamGroups: user.Groups, + DownstreamSubject: user.DownstreamSubject, IDPSpecificSessionData: &psession.GitHubSessionData{ - UpstreamAccessToken: token, + UpstreamAccessToken: accessToken, }, }, - &resolvedprovider.IdentityLoginExtras{}, - nil + &resolvedprovider.IdentityLoginExtras{ + DownstreamAdditionalClaims: nil, // not using this for GitHub + Warnings: nil, // not using this for GitHub + }, + nil // no error } func (p *FederationDomainResolvedGitHubIdentityProvider) UpstreamRefresh( diff --git a/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go b/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go index d54b4f57e..3f957ce37 100644 --- a/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go +++ b/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go @@ -4,6 +4,8 @@ package resolvedgithub import ( + "context" + "errors" "testing" "github.com/stretchr/testify/require" @@ -12,10 +14,11 @@ import ( idpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1" idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1" "go.pinniped.dev/internal/federationdomain/resolvedprovider" + "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/psession" + "go.pinniped.dev/internal/testutil/oidctestutil" "go.pinniped.dev/internal/testutil/transformtestutil" "go.pinniped.dev/internal/upstreamgithub" - "go.pinniped.dev/pkg/oidcclient/oidctypes" ) func TestFederationDomainResolvedGitHubIdentityProvider(t *testing.T) { @@ -59,11 +62,11 @@ func TestFederationDomainResolvedGitHubIdentityProvider(t *testing.T) { originalCustomSession := &psession.CustomSessionData{ Username: "fake-username", UpstreamUsername: "fake-upstream-username", - GitHub: &psession.GitHubSessionData{UpstreamAccessToken: &oidctypes.Token{AccessToken: &oidctypes.AccessToken{Token: "fake-upstream-access-token"}}}, + GitHub: &psession.GitHubSessionData{UpstreamAccessToken: "fake-upstream-access-token"}, } clonedCustomSession := subject.CloneIDPSpecificSessionDataFromSession(originalCustomSession) require.Equal(t, - &psession.GitHubSessionData{UpstreamAccessToken: &oidctypes.Token{AccessToken: &oidctypes.AccessToken{Token: "fake-upstream-access-token"}}}, + &psession.GitHubSessionData{UpstreamAccessToken: "fake-upstream-access-token"}, clonedCustomSession, ) require.NotSame(t, originalCustomSession, clonedCustomSession) @@ -72,11 +75,11 @@ func TestFederationDomainResolvedGitHubIdentityProvider(t *testing.T) { Username: "fake-username2", UpstreamUsername: "fake-upstream-username2", } - subject.ApplyIDPSpecificSessionDataToSession(customSessionToBeMutated, &psession.GitHubSessionData{UpstreamAccessToken: &oidctypes.Token{AccessToken: &oidctypes.AccessToken{Token: "OTHER-upstream-access-token"}}}) + subject.ApplyIDPSpecificSessionDataToSession(customSessionToBeMutated, &psession.GitHubSessionData{UpstreamAccessToken: "OTHER-upstream-access-token"}) require.Equal(t, &psession.CustomSessionData{ Username: "fake-username2", UpstreamUsername: "fake-upstream-username2", - GitHub: &psession.GitHubSessionData{UpstreamAccessToken: &oidctypes.Token{AccessToken: &oidctypes.AccessToken{Token: "OTHER-upstream-access-token"}}}, + GitHub: &psession.GitHubSessionData{UpstreamAccessToken: "OTHER-upstream-access-token"}, }, customSessionToBeMutated) redirectURL, err := subject.UpstreamAuthorizeRedirectURL( @@ -101,3 +104,139 @@ func TestFederationDomainResolvedGitHubIdentityProvider(t *testing.T) { redirectURL, ) } + +func TestLoginFromCallback(t *testing.T) { + uniqueCtx := context.WithValue(context.Background(), "some-unique-key", "some-value") //nolint:staticcheck // okay to use string key for test + + tests := []struct { + name string + provider *oidctestutil.TestUpstreamGitHubIdentityProvider + authcode string + redirectURI string + + wantExchangeAuthcodeCall bool + wantExchangeAuthcodeArgs *oidctestutil.ExchangeAuthcodeArgs + wantGetUserCall bool + wantGetUserArgs *oidctestutil.GetUserArgs + wantIdentity *resolvedprovider.Identity + wantExtras *resolvedprovider.IdentityLoginExtras + wantErr string + }{ + { + name: "happy path", + provider: oidctestutil.NewTestUpstreamGitHubIdentityProviderBuilder(). + WithAccessToken("fake-access-token"). + WithUser(&upstreamprovider.GitHubUser{ + Username: "fake-username", + Groups: []string{"fake-group1", "fake-group2"}, + DownstreamSubject: "https://fake-downstream-subject", + }). + Build(), + authcode: "fake-authcode", + redirectURI: "https://fake-redirect-uri", + wantExchangeAuthcodeCall: true, + wantExchangeAuthcodeArgs: &oidctestutil.ExchangeAuthcodeArgs{ + Ctx: uniqueCtx, + Authcode: "fake-authcode", + RedirectURI: "https://fake-redirect-uri", + }, + wantGetUserCall: true, + wantGetUserArgs: &oidctestutil.GetUserArgs{ + Ctx: uniqueCtx, + AccessToken: "fake-access-token", + }, + wantIdentity: &resolvedprovider.Identity{ + UpstreamUsername: "fake-username", + UpstreamGroups: []string{"fake-group1", "fake-group2"}, + DownstreamSubject: "https://fake-downstream-subject", + IDPSpecificSessionData: &psession.GitHubSessionData{ + UpstreamAccessToken: "fake-access-token", + }, + }, + wantExtras: &resolvedprovider.IdentityLoginExtras{}, + }, + { + name: "error while exchanging authcode", + provider: oidctestutil.NewTestUpstreamGitHubIdentityProviderBuilder(). + WithAuthcodeExchangeError(errors.New("fake authcode exchange error")). + Build(), + authcode: "fake-authcode", + redirectURI: "https://fake-redirect-uri", + wantExchangeAuthcodeCall: true, + wantExchangeAuthcodeArgs: &oidctestutil.ExchangeAuthcodeArgs{ + Ctx: uniqueCtx, + Authcode: "fake-authcode", + RedirectURI: "https://fake-redirect-uri", + }, + wantGetUserCall: false, + wantIdentity: nil, + wantExtras: nil, + wantErr: "failed to exchange auth code using GitHub API: fake authcode exchange error", + }, + { + name: "error while getting user info", + provider: oidctestutil.NewTestUpstreamGitHubIdentityProviderBuilder(). + WithAccessToken("fake-access-token"). + WithGetUserError(errors.New("fake user info error")). + Build(), + authcode: "fake-authcode", + redirectURI: "https://fake-redirect-uri", + wantExchangeAuthcodeCall: true, + wantExchangeAuthcodeArgs: &oidctestutil.ExchangeAuthcodeArgs{ + Ctx: uniqueCtx, + Authcode: "fake-authcode", + RedirectURI: "https://fake-redirect-uri", + }, + wantGetUserCall: true, + wantGetUserArgs: &oidctestutil.GetUserArgs{ + Ctx: uniqueCtx, + AccessToken: "fake-access-token", + }, + wantIdentity: nil, + wantExtras: nil, + wantErr: "failed to get user info from GitHub API: fake user info error", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + transforms := transformtestutil.NewRejectAllAuthPipeline(t) + + subject := FederationDomainResolvedGitHubIdentityProvider{ + DisplayName: "fake-display-name", + Provider: test.provider, + SessionProviderType: psession.ProviderTypeGitHub, + Transforms: transforms, + } + + identity, loginExtras, err := subject.LoginFromCallback(uniqueCtx, + test.authcode, + "pkce-will-be-ignored", + "nonce-will-be-ignored", + test.redirectURI, + ) + + if test.wantExchangeAuthcodeCall { + require.Equal(t, 1, test.provider.ExchangeAuthcodeCallCount()) + require.Equal(t, test.wantExchangeAuthcodeArgs, test.provider.ExchangeAuthcodeArgs(0)) + } else { + require.Zero(t, test.provider.ExchangeAuthcodeCallCount()) + } + + if test.wantGetUserCall { + require.Equal(t, 1, test.provider.GetUserCallCount()) + require.Equal(t, test.wantGetUserArgs, test.provider.GetUserArgs(0)) + } else { + require.Zero(t, test.provider.GetUserCallCount()) + } + + if test.wantErr == "" { + require.NoError(t, err) + } else { + require.EqualError(t, err, test.wantErr) + } + require.Equal(t, test.wantExtras, loginExtras) + require.Equal(t, test.wantIdentity, identity) + }) + } +} diff --git a/internal/federationdomain/upstreamprovider/upsteam_provider.go b/internal/federationdomain/upstreamprovider/upsteam_provider.go index be3229be3..cb52dcedf 100644 --- a/internal/federationdomain/upstreamprovider/upsteam_provider.go +++ b/internal/federationdomain/upstreamprovider/upsteam_provider.go @@ -127,6 +127,12 @@ type UpstreamLDAPIdentityProviderI interface { PerformRefresh(ctx context.Context, storedRefreshAttributes RefreshAttributes, idpDisplayName string) (groups []string, err error) } +type GitHubUser struct { + Username string // could be login name, id, or login:id + Groups []string // could be names or slugs + DownstreamSubject string // the whole downstream subject URI +} + type UpstreamGithubIdentityProviderI interface { UpstreamIdentityProviderI @@ -159,21 +165,11 @@ type UpstreamGithubIdentityProviderI interface { // It will never include a username or password in the authority section. GetAuthorizationURL() string - // TODO: This interface should be easily mockable to avoid all interactions with the actual server. - // What interactions with the server do we want to hide behind this interface? Something like this? - // ExchangeAuthcode(ctx, authcode, redirectURI) (AccessToken, error) - // GetUser(ctx, accessToken) (User, error) - // GetUserOrgs(ctx, accessToken) ([]Org, error) - // GetUserTeams(ctx, accessToken) ([]Team, error) - // Or maybe higher level interface like this? - // ExchangeAuthcode(ctx, authcode, redirectURI) (AccessToken, error) - // GetUser(ctx, accessToken) (User, error) // in this case User would include team and org info - // ExchangeAuthcode performs an upstream GitHub authorization code exchange. // Returns the raw access token. The access token expiry is not known. - ExchangeAuthcode( - ctx context.Context, - authcode string, - redirectURI string, - ) (string, error) + ExchangeAuthcode(ctx context.Context, authcode string, redirectURI string) (string, error) + + // GetUser calls the user, orgs, and teams APIs of GitHub using the accessToken. + // It validates any required org memberships. It returns a User or an error. + GetUser(ctx context.Context, accessToken string) (*GitHubUser, error) } diff --git a/internal/testutil/oidctestutil/testgithubprovider.go b/internal/testutil/oidctestutil/testgithubprovider.go index 20b97ad93..1f793aeac 100644 --- a/internal/testutil/oidctestutil/testgithubprovider.go +++ b/internal/testutil/oidctestutil/testgithubprovider.go @@ -13,6 +13,21 @@ import ( "go.pinniped.dev/internal/idtransform" ) +// ExchangeAuthcodeArgs is used to spy on calls to +// TestUpstreamGitHubIdentityProvider.ExchangeAuthcodeFunc(). +type ExchangeAuthcodeArgs struct { + Ctx context.Context + Authcode string + RedirectURI string +} + +// GetUserArgs is used to spy on calls to +// TestUpstreamGitHubIdentityProvider.GetUserFunc(). +type GetUserArgs struct { + Ctx context.Context + AccessToken string +} + type TestUpstreamGitHubIdentityProviderBuilder struct { name string resourceUID types.UID @@ -24,10 +39,10 @@ type TestUpstreamGitHubIdentityProviderBuilder struct { groupNameAttribute v1alpha1.GitHubGroupNameAttribute allowedOrganizations []string authorizationURL string - - // Assertions stuff - authcodeExchangeErr error - accessToken string + authcodeExchangeErr error + accessToken string + getUserErr error + getUserUser *upstreamprovider.GitHubUser } func (u *TestUpstreamGitHubIdentityProviderBuilder) WithName(value string) *TestUpstreamGitHubIdentityProviderBuilder { @@ -80,8 +95,18 @@ func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAccessToken(token string return u } -func (u *TestUpstreamGitHubIdentityProviderBuilder) WithEmptyAccessToken() *TestUpstreamGitHubIdentityProviderBuilder { - u.accessToken = "" +func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAuthcodeExchangeError(err error) *TestUpstreamGitHubIdentityProviderBuilder { + u.authcodeExchangeErr = err + return u +} + +func (u *TestUpstreamGitHubIdentityProviderBuilder) WithUser(user *upstreamprovider.GitHubUser) *TestUpstreamGitHubIdentityProviderBuilder { + u.getUserUser = user + return u +} + +func (u *TestUpstreamGitHubIdentityProviderBuilder) WithGetUserError(err error) *TestUpstreamGitHubIdentityProviderBuilder { + u.getUserErr = err return u } @@ -96,8 +121,8 @@ func (u *TestUpstreamGitHubIdentityProviderBuilder) Build() *TestUpstreamGitHubI } return &TestUpstreamGitHubIdentityProvider{ Name: u.name, - ResourceUID: u.resourceUID, ClientID: u.clientID, + ResourceUID: u.resourceUID, Scopes: u.scopes, DisplayNameForFederationDomain: u.displayNameForFederationDomain, TransformsForFederationDomain: u.transformsForFederationDomain, @@ -105,7 +130,12 @@ func (u *TestUpstreamGitHubIdentityProviderBuilder) Build() *TestUpstreamGitHubI GroupNameAttribute: u.groupNameAttribute, AllowedOrganizations: u.allowedOrganizations, AuthorizationURL: u.authorizationURL, - + GetUserFunc: func(ctx context.Context, accessToken string) (*upstreamprovider.GitHubUser, error) { + if u.getUserErr != nil { + return nil, u.getUserErr + } + return u.getUserUser, nil + }, ExchangeAuthcodeFunc: func(ctx context.Context, authcode string) (string, error) { if u.authcodeExchangeErr != nil { return "", u.authcodeExchangeErr @@ -130,16 +160,14 @@ type TestUpstreamGitHubIdentityProvider struct { GroupNameAttribute v1alpha1.GitHubGroupNameAttribute AllowedOrganizations []string AuthorizationURL string + GetUserFunc func(ctx context.Context, accessToken string) (*upstreamprovider.GitHubUser, error) + ExchangeAuthcodeFunc func(ctx context.Context, authcode string) (string, error) - authcodeExchangeErr error - - ExchangeAuthcodeFunc func( - ctx context.Context, - authcode string, - ) (string, error) - + // Fields for tracking actual calls make to mock functions. exchangeAuthcodeCallCount int exchangeAuthcodeArgs []*ExchangeAuthcodeArgs + getUserCallCount int + getUserArgs []*GetUserArgs } var _ upstreamprovider.UpstreamGithubIdentityProviderI = &TestUpstreamGitHubIdentityProvider{} @@ -203,3 +231,26 @@ func (u *TestUpstreamGitHubIdentityProvider) ExchangeAuthcodeArgs(call int) *Exc } return u.exchangeAuthcodeArgs[call] } + +func (u *TestUpstreamGitHubIdentityProvider) GetUser(ctx context.Context, accessToken string) (*upstreamprovider.GitHubUser, error) { + if u.getUserArgs == nil { + u.getUserArgs = make([]*GetUserArgs, 0) + } + u.getUserCallCount++ + u.getUserArgs = append(u.getUserArgs, &GetUserArgs{ + Ctx: ctx, + AccessToken: accessToken, + }) + return u.GetUserFunc(ctx, accessToken) +} + +func (u *TestUpstreamGitHubIdentityProvider) GetUserCallCount() int { + return u.getUserCallCount +} + +func (u *TestUpstreamGitHubIdentityProvider) GetUserArgs(call int) *GetUserArgs { + if u.getUserArgs == nil { + u.getUserArgs = make([]*GetUserArgs, 0) + } + return u.getUserArgs[call] +} diff --git a/internal/testutil/oidctestutil/testoidcprovider.go b/internal/testutil/oidctestutil/testoidcprovider.go index 9530db42e..489f6304c 100644 --- a/internal/testutil/oidctestutil/testoidcprovider.go +++ b/internal/testutil/oidctestutil/testoidcprovider.go @@ -18,9 +18,9 @@ import ( oidcpkce "go.pinniped.dev/pkg/oidcclient/pkce" ) -// ExchangeAuthcodeArgs is used to spy on calls to +// ExchangeAuthcodeAndValidateTokenArgs is used to spy on calls to // TestUpstreamOIDCIdentityProvider.ExchangeAuthcodeAndValidateTokensFunc(). -type ExchangeAuthcodeArgs struct { +type ExchangeAuthcodeAndValidateTokenArgs struct { Ctx context.Context Authcode string PKCECodeVerifier oidcpkce.Code @@ -101,7 +101,7 @@ type TestUpstreamOIDCIdentityProvider struct { // Fields for tracking actual calls make to mock functions. exchangeAuthcodeAndValidateTokensCallCount int - exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeArgs + exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs passwordCredentialsGrantAndValidateTokensCallCount int passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs performRefreshCallCount int @@ -180,10 +180,10 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( redirectURI string, ) (*oidctypes.Token, error) { if u.exchangeAuthcodeAndValidateTokensArgs == nil { - u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeArgs, 0) + u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) } u.exchangeAuthcodeAndValidateTokensCallCount++ - u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeArgs{ + u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeAndValidateTokenArgs{ Ctx: ctx, Authcode: authcode, PKCECodeVerifier: pkceCodeVerifier, @@ -197,9 +197,9 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensCall return u.exchangeAuthcodeAndValidateTokensCallCount } -func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeArgs { +func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeAndValidateTokenArgs { if u.exchangeAuthcodeAndValidateTokensArgs == nil { - u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeArgs, 0) + u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) } return u.exchangeAuthcodeAndValidateTokensArgs[call] } diff --git a/internal/testutil/testidplister/testidplister.go b/internal/testutil/testidplister/testidplister.go index 2823a5188..d7fed23eb 100644 --- a/internal/testutil/testidplister/testidplister.go +++ b/internal/testutil/testidplister/testidplister.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" - idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1" "go.pinniped.dev/internal/federationdomain/dynamicupstreamprovider" "go.pinniped.dev/internal/federationdomain/resolvedprovider" "go.pinniped.dev/internal/federationdomain/resolvedprovider/resolvedgithub" @@ -264,48 +263,59 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPasswordCredentialsG ) } -func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToExchangeAuthcodeAndValidateTokens( +func (b *UpstreamIDPListerBuilder) RequireExactlyOneOIDCAuthcodeExchange( + t *testing.T, + expectedPerformedByUpstreamName string, + expectedArgs *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs, +) { + t.Helper() + var actualArgs *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs + var actualNameOfUpstreamWhichMadeCall string + actualCallCount := 0 + for _, upstream := range b.upstreamOIDCIdentityProviders { + callCountOnThisUpstream := upstream.ExchangeAuthcodeAndValidateTokensCallCount() + actualCallCount += callCountOnThisUpstream + if callCountOnThisUpstream == 1 { + actualNameOfUpstreamWhichMadeCall = upstream.Name + actualArgs = upstream.ExchangeAuthcodeAndValidateTokensArgs(0) + } + } + require.Equal(t, 1, actualCallCount, + "expected exactly one call to OIDC ExchangeAuthcodeAndValidateTokens()", + ) + require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, + "OIDC ExchangeAuthcodeAndValidateTokens() was called on the wrong upstream name", + ) + require.Equal(t, expectedArgs, actualArgs) +} + +func (b *UpstreamIDPListerBuilder) RequireExactlyOneGitHubAuthcodeExchange( t *testing.T, expectedPerformedByUpstreamName string, - expectedPerformedByUpstreamType idpdiscoveryv1alpha1.IDPType, expectedArgs *oidctestutil.ExchangeAuthcodeArgs, ) { t.Helper() var actualArgs *oidctestutil.ExchangeAuthcodeArgs var actualNameOfUpstreamWhichMadeCall string - var actualTypeOfUpstreamWhichMadeCall idpdiscoveryv1alpha1.IDPType - actualCallCountAcrossAllOIDCAndGitHubUpstreams := 0 - for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { - callCountOnThisUpstream := upstreamOIDC.ExchangeAuthcodeAndValidateTokensCallCount() - actualCallCountAcrossAllOIDCAndGitHubUpstreams += callCountOnThisUpstream + actualCallCount := 0 + for _, upstream := range b.upstreamGitHubIdentityProviders { + callCountOnThisUpstream := upstream.ExchangeAuthcodeCallCount() + actualCallCount += callCountOnThisUpstream if callCountOnThisUpstream == 1 { - actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name - actualTypeOfUpstreamWhichMadeCall = idpdiscoveryv1alpha1.IDPTypeOIDC - actualArgs = upstreamOIDC.ExchangeAuthcodeAndValidateTokensArgs(0) + actualNameOfUpstreamWhichMadeCall = upstream.Name + actualArgs = upstream.ExchangeAuthcodeArgs(0) } } - for _, upstreamGitHub := range b.upstreamGitHubIdentityProviders { - callCountOnThisUpstream := upstreamGitHub.ExchangeAuthcodeCallCount() - actualCallCountAcrossAllOIDCAndGitHubUpstreams += callCountOnThisUpstream - if callCountOnThisUpstream == 1 { - actualNameOfUpstreamWhichMadeCall = upstreamGitHub.Name - actualTypeOfUpstreamWhichMadeCall = idpdiscoveryv1alpha1.IDPTypeGitHub - actualArgs = upstreamGitHub.ExchangeAuthcodeArgs(0) - } - } - require.Equal(t, 1, actualCallCountAcrossAllOIDCAndGitHubUpstreams, - "expected exactly one call to (OIDC) ExchangeAuthcodeAndValidateTokensCallCount() or (GitHub) ExchangeAuthcodeCallCount()", + require.Equal(t, 1, actualCallCount, + "expected exactly one call to GitHub ExchangeAuthcode()", ) require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, - "(OIDC) ExchangeAuthcodeAndValidateTokensCallCount() or (GitHub) ExchangeAuthcodeCallCount() was called on the wrong upstream name", - ) - require.Equal(t, expectedPerformedByUpstreamType, actualTypeOfUpstreamWhichMadeCall, - "(OIDC) ExchangeAuthcodeAndValidateTokensCallCount() or (GitHub) ExchangeAuthcodeCallCount() was called on the wrong upstream type", + "GitHub ExchangeAuthcode() was called on the wrong upstream name", ) require.Equal(t, expectedArgs, actualArgs) } -func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndValidateTokens(t *testing.T) { +func (b *UpstreamIDPListerBuilder) RequireExactlyZeroAuthcodeExchanges(t *testing.T) { t.Helper() actualCallCount := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { @@ -316,7 +326,7 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV } require.Equal(t, 0, actualCallCount, - "expected exactly zero calls to (OIDC) ExchangeAuthcodeAndValidateTokensCallCount() or (GitHub) ExchangeAuthcodeCallCount()", + "expected exactly zero calls to OIDC ExchangeAuthcodeAndValidateTokens() or GitHub ExchangeAuthcode()", ) } diff --git a/internal/upstreamgithub/upstreamgithub.go b/internal/upstreamgithub/upstreamgithub.go index f76a02aa5..67f8b683d 100644 --- a/internal/upstreamgithub/upstreamgithub.go +++ b/internal/upstreamgithub/upstreamgithub.go @@ -8,6 +8,7 @@ import ( "context" "net/http" + coreosoidc "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" "k8s.io/apimachinery/pkg/types" @@ -88,12 +89,37 @@ func (p *Provider) GetAuthorizationURL() string { return p.c.OAuth2Config.Endpoint.AuthURL } -func (p *Provider) ExchangeAuthcode(_ context.Context, _ string, _ string) (string, error) { - //TODO implement me - panic("implement me") +func (p *Provider) ExchangeAuthcode(ctx context.Context, authcode string, redirectURI string) (string, error) { + // TODO: write tests for this + panic("write some tests for this sketch of the implementation, maybe by running a test server in the unit tests") + //nolint:govet // this code is intentionally unreachable until we resolve the todos + tok, err := p.c.OAuth2Config.Exchange( + coreosoidc.ClientContext(ctx, p.c.HttpClient), + authcode, + oauth2.SetAuthURLParam("redirect_uri", redirectURI), + ) + if err != nil { + return "", err + } + return tok.AccessToken, nil } -// GetConfig returns the config. This is not part of the interface and is mostly just for testing. +func (p *Provider) GetUser(_ctx context.Context, _accessToken string) (*upstreamprovider.GitHubUser, error) { + // TODO Implement this to make several https calls to github to learn about the user, using a lower-level githubclient package. + // Pass the ctx, accessToken, p.c.HttpClient, and p.c.APIBaseURL to the lower-level package's functions. + // TODO: Reject the auth if the user does not belong to any of p.c.AllowedOrganizations (unless p.c.AllowedOrganizations is empty). + // TODO: Make use of p.c.UsernameAttribute and p.c.GroupNameAttribute when deciding the username and group names. + // TODO: Determine the downstream subject by first writing a helper in downstream_subject.go and then calling it here. + panic("implement me") + //nolint:govet // this code is intentionally unreachable until we resolve the todos + return &upstreamprovider.GitHubUser{ + Username: "TODO", + Groups: []string{"org/TODO"}, + DownstreamSubject: "TODO", + }, nil +} + +// GetConfig returns the config. This is not part of the UpstreamGithubIdentityProviderI interface and is just for testing. func (p *Provider) GetConfig() ProviderConfig { return p.c }