From b7f79f0adc365d9b21c4cf0bc9b8212cbc04beca Mon Sep 17 00:00:00 2001 From: Joshua Casey Date: Fri, 17 May 2024 11:21:23 -0500 Subject: [PATCH] Add github-specific tests in callback_handler_github_test.go Co-authored-by: Ryan Richard --- .../callback/callback_handler_github_test.go | 252 ++++++++++++++++++ .../callback/callback_handler_test.go | 41 ++- .../resolved_github_provider.go | 27 +- .../resolved_github_provider_test.go | 9 +- .../upstreamprovider/upsteam_provider.go | 8 + .../expected_upstream_state_param.go | 11 +- .../oidctestutil/testgithubprovider.go | 61 +++++ .../testutil/oidctestutil/testoidcprovider.go | 14 +- .../testutil/testidplister/testidplister.go | 41 ++- internal/upstreamgithub/upstreamgithub.go | 6 + 10 files changed, 418 insertions(+), 52 deletions(-) create mode 100644 internal/federationdomain/endpoints/callback/callback_handler_github_test.go diff --git a/internal/federationdomain/endpoints/callback/callback_handler_github_test.go b/internal/federationdomain/endpoints/callback/callback_handler_github_test.go new file mode 100644 index 000000000..435ed1f1e --- /dev/null +++ b/internal/federationdomain/endpoints/callback/callback_handler_github_test.go @@ -0,0 +1,252 @@ +// Copyright 2020-2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package callback + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "regexp" + "strings" + "testing" + + "github.com/gorilla/securecookie" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/fake" + + idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1" + supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake" + "go.pinniped.dev/internal/federationdomain/endpoints/jwks" + "go.pinniped.dev/internal/federationdomain/oidc" + "go.pinniped.dev/internal/federationdomain/storage" + "go.pinniped.dev/internal/psession" + "go.pinniped.dev/internal/testutil" + "go.pinniped.dev/internal/testutil/oidctestutil" + "go.pinniped.dev/internal/testutil/testidplister" +) + +var ( + githubIDPName = "upstream-github-idp-name" + githubIDPResourceUID = types.UID("upstream-github-idp-resource-uid") + githubUpstreamUsername = "some-github-login" + githubUpstreamGroups = []string{"org1/team1", "org2/team2"} + githubDownstreamSubject = fmt.Sprintf("https://github.com?idpName=%s&sub=%s", githubIDPName, githubUpstreamUsername) + githubUpstreamAccessToken = "some-opaque-access-token-from-github" + + happyDownstreamGitHubCustomSessionData = &psession.CustomSessionData{ + Username: githubUpstreamUsername, + UpstreamUsername: githubUpstreamUsername, + UpstreamGroups: githubUpstreamGroups, + ProviderUID: githubIDPResourceUID, + ProviderName: githubIDPName, + ProviderType: psession.ProviderTypeGitHub, + GitHub: &psession.GitHubSessionData{ + UpstreamAccessToken: githubUpstreamAccessToken, + }, + } +) + +func TestCallbackEndpointWithGitHubIdentityProviders(t *testing.T) { + require.Len(t, happyDownstreamState, 8, "we expect fosite to allow 8 byte state params, so we want to test that boundary case") + + var stateEncoderHashKey = []byte("fake-hash-secret") + var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES + var cookieEncoderHashKey = []byte("fake-hash-secret2") + var cookieEncoderBlockKey = []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES + require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey) + require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey) + + var happyStateCodec = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey) + happyStateCodec.SetSerializer(securecookie.JSONEncoder{}) + var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) + happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) + + encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyDownstreamCSRF) + require.NoError(t, err) + happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue + + happyExchangeAndValidateTokensArgs := &oidctestutil.ExchangeAuthcodeArgs{ + Authcode: happyUpstreamAuthcode, + RedirectURI: happyUpstreamRedirectURI, + } + + tests := []struct { + name string + + idps *testidplister.UpstreamIDPListerBuilder + kubeResources func(t *testing.T, supervisorClient *supervisorfake.Clientset, kubeClient *fake.Clientset) + method string + path string + csrfCookie string + + wantRedirectLocationRegexp string + wantDownstreamGrantedScopes []string + wantDownstreamIDTokenSubject string + wantDownstreamIDTokenUsername string + wantDownstreamIDTokenGroups []string + wantDownstreamRequestedScopes []string + wantDownstreamNonce string + wantDownstreamClientID string + wantDownstreamPKCEChallenge string + wantDownstreamPKCEChallengeMethod string + wantDownstreamCustomSessionData *psession.CustomSessionData + wantDownstreamAdditionalClaims map[string]interface{} + + wantAuthcodeExchangeCall *expectedAuthcodeExchange + }{ + { + name: "GitHub IDP: GET with good state and cookie and successful upstream token exchange returns 303 to downstream client callback", + idps: testidplister.NewUpstreamIDPListerBuilder().WithGitHub( + happyGitHubUpstream(). + WithAccessToken(githubUpstreamAccessToken). + Build()), + method: http.MethodGet, + path: newRequestPath().WithState( + happyUpstreamStateParam(). + WithUpstreamIDPName(githubIDPName). + WithUpstreamIDPType(idpdiscoveryv1alpha1.IDPTypeGitHub). + WithAuthorizeRequestParams( + happyDownstreamRequestParamsQuery.Encode(), + ).Build(t, happyStateCodec), + ).String(), + csrfCookie: happyCSRFCookie, + wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=` + regexp.QuoteMeta(strings.Join(happyDownstreamScopesGranted, "+")) + `&state=` + happyDownstreamState, + wantDownstreamIDTokenSubject: githubDownstreamSubject, + wantDownstreamIDTokenUsername: githubUpstreamUsername, + wantDownstreamIDTokenGroups: githubUpstreamGroups, + wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantDownstreamGrantedScopes: happyDownstreamScopesGranted, + wantDownstreamNonce: downstreamNonce, + wantDownstreamClientID: downstreamPinnipedClientID, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, + wantDownstreamCustomSessionData: happyDownstreamGitHubCustomSessionData, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: githubIDPName, + args: happyExchangeAndValidateTokensArgs, + }, + }, + { + name: "GitHub IDP: GET with good state and cookie and successful upstream token exchange with dynamic client returns 303 to downstream client callback, with dynamic client", + idps: testidplister.NewUpstreamIDPListerBuilder().WithGitHub( + happyGitHubUpstream(). + WithAccessToken(githubUpstreamAccessToken). + Build()), + method: http.MethodGet, + kubeResources: addFullyCapableDynamicClientAndSecretToKubeResources, + path: newRequestPath().WithState( + happyUpstreamStateParam(). + WithUpstreamIDPName(githubIDPName). + WithUpstreamIDPType(idpdiscoveryv1alpha1.IDPTypeGitHub). + WithAuthorizeRequestParams( + shallowCopyAndModifyQuery( + happyDownstreamRequestParamsQuery, + map[string]string{ + "client_id": downstreamDynamicClientID, + }, + ).Encode(), + ).Build(t, happyStateCodec), + ).String(), + csrfCookie: happyCSRFCookie, + wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=` + regexp.QuoteMeta(strings.Join(happyDownstreamScopesGranted, "+")) + `&state=` + happyDownstreamState, + wantDownstreamIDTokenSubject: githubDownstreamSubject, + wantDownstreamIDTokenUsername: githubUpstreamUsername, + wantDownstreamIDTokenGroups: githubUpstreamGroups, + wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantDownstreamGrantedScopes: happyDownstreamScopesGranted, + wantDownstreamNonce: downstreamNonce, + wantDownstreamClientID: downstreamDynamicClientID, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, + wantDownstreamCustomSessionData: happyDownstreamGitHubCustomSessionData, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: githubIDPName, + args: happyExchangeAndValidateTokensArgs, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + kubeClient := fake.NewSimpleClientset() + supervisorClient := supervisorfake.NewSimpleClientset() + secrets := kubeClient.CoreV1().Secrets("some-namespace") + oidcClientsClient := supervisorClient.ConfigV1alpha1().OIDCClients("some-namespace") + + if test.kubeResources != nil { + test.kubeResources(t, supervisorClient, kubeClient) + } + + // Configure fosite the same way that the production code would. + // Inject this into our test subject at the last second, so we get a fresh storage for every test. + timeoutsConfiguration := oidc.DefaultOIDCTimeoutsConfiguration() + // Use lower minimum required bcrypt cost than we would use in production to keep unit the tests fast. + oauthStore := storage.NewKubeStorage(secrets, oidcClientsClient, timeoutsConfiguration, bcrypt.MinCost) + hmacSecretFunc := func() []byte { return []byte("some secret - must have at least 32 bytes") } + require.GreaterOrEqual(t, len(hmacSecretFunc()), 32, "fosite requires that hmac secrets have at least 32 bytes") + jwksProviderIsUnused := jwks.NewDynamicJWKSProvider() + oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecretFunc, jwksProviderIsUnused, timeoutsConfiguration) + + subject := NewHandler(test.idps.BuildFederationDomainIdentityProvidersListerFinder(), oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI) + reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context") + req := httptest.NewRequest(test.method, test.path, nil).WithContext(reqContext) + if test.csrfCookie != "" { + req.Header.Set("Cookie", test.csrfCookie) + } + rsp := httptest.NewRecorder() + subject.ServeHTTP(rsp, req) + t.Logf("response: %#v", rsp) + t.Logf("response body: %q", rsp.Body.String()) + + testutil.RequireSecurityHeadersWithFormPostPageCSPs(t, rsp) + + require.NotNil(t, test.wantAuthcodeExchangeCall, "wantAuthcodeExchangeCall is required for testing purposes") + + test.wantAuthcodeExchangeCall.args.Ctx = reqContext + test.idps.RequireExactlyOneCallToExchangeAuthcodeAndValidateTokens(t, + test.wantAuthcodeExchangeCall.performedByUpstreamName, + idpdiscoveryv1alpha1.IDPTypeGitHub, + test.wantAuthcodeExchangeCall.args, + ) + + require.Equal(t, http.StatusSeeOther, rsp.Code) + testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), "") + require.Empty(t, rsp.Body.String()) + + require.Len(t, rsp.Header().Values("Location"), 1) + require.NotEmpty(t, test.wantRedirectLocationRegexp, "wantRedirectLocationRegexp is required for testing purposes") + oidctestutil.RequireAuthCodeRegexpMatch( + t, + rsp.Header().Get("Location"), + test.wantRedirectLocationRegexp, + kubeClient, + secrets, + oauthStore, + test.wantDownstreamGrantedScopes, + test.wantDownstreamIDTokenSubject, + test.wantDownstreamIDTokenUsername, + test.wantDownstreamIDTokenGroups, + test.wantDownstreamRequestedScopes, + test.wantDownstreamPKCEChallenge, + test.wantDownstreamPKCEChallengeMethod, + test.wantDownstreamNonce, + test.wantDownstreamClientID, + downstreamRedirectURI, + test.wantDownstreamCustomSessionData, + test.wantDownstreamAdditionalClaims, + ) + }) + } +} + +func happyGitHubUpstream() *oidctestutil.TestUpstreamGitHubIdentityProviderBuilder { + return oidctestutil.NewTestUpstreamGitHubIdentityProviderBuilder(). + WithName(githubIDPName). + WithResourceUID(githubIDPResourceUID). + WithClientID("some-client-id"). + WithScopes([]string{"these", "scopes", "appear", "unused"}) +} diff --git a/internal/federationdomain/endpoints/callback/callback_handler_test.go b/internal/federationdomain/endpoints/callback/callback_handler_test.go index d0aa3a6b8..019a32223 100644 --- a/internal/federationdomain/endpoints/callback/callback_handler_test.go +++ b/internal/federationdomain/endpoints/callback/callback_handler_test.go @@ -20,6 +20,7 @@ import ( "k8s.io/client-go/kubernetes/fake" configv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1" + idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1" supervisorfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake" "go.pinniped.dev/internal/federationdomain/endpoints/jwks" "go.pinniped.dev/internal/federationdomain/oidc" @@ -92,7 +93,6 @@ var ( happyDownstreamRequestParamsQueryForDynamicClient = shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"client_id": downstreamDynamicClientID}, ) - happyDownstreamRequestParamsForDynamicClient = happyDownstreamRequestParamsQueryForDynamicClient.Encode() happyDownstreamCustomSessionData = &psession.CustomSessionData{ Username: oidcUpstreamUsername, @@ -107,6 +107,7 @@ var ( UpstreamSubject: oidcUpstreamSubject, }, } + happyDownstreamCustomSessionDataWithUsernameAndGroups = func(wantDownstreamUsername, wantUpstreamUsername string, wantUpstreamGroups []string) *psession.CustomSessionData { copyOfCustomSession := *happyDownstreamCustomSessionData copyOfOIDC := *(happyDownstreamCustomSessionData.OIDC) @@ -129,6 +130,14 @@ var ( UpstreamSubject: oidcUpstreamSubject, }, } + + addFullyCapableDynamicClientAndSecretToKubeResources = func(t *testing.T, supervisorClient *supervisorfake.Clientset, kubeClient *fake.Clientset) { + oidcClient, secret := testutil.FullyCapableOIDCClientAndStorageSecret(t, + "some-namespace", downstreamDynamicClientID, downstreamDynamicClientUID, downstreamRedirectURI, nil, + []string{testutil.HashedPassword1AtGoMinCost}, oidcclientvalidator.Validate) + require.NoError(t, supervisorClient.Tracker().Add(oidcClient)) + require.NoError(t, kubeClient.Tracker().Add(secret)) + } ) func TestCallbackEndpoint(t *testing.T) { @@ -153,13 +162,13 @@ func TestCallbackEndpoint(t *testing.T) { happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) happyState := happyUpstreamStateParam().Build(t, happyStateCodec) - happyStateForDynamicClient := happyUpstreamStateParamForDynamicClient().Build(t, happyStateCodec) + happyStateForDynamicClient := happyUpstreamStateParam().WithAuthorizeRequestParams(happyDownstreamRequestParamsQueryForDynamicClient.Encode()).Build(t, happyStateCodec) encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyDownstreamCSRF) require.NoError(t, err) happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue - happyExchangeAndValidateTokensArgs := &oidctestutil.ExchangeAuthcodeAndValidateTokenArgs{ + happyExchangeAndValidateTokensArgs := &oidctestutil.ExchangeAuthcodeArgs{ Authcode: happyUpstreamAuthcode, PKCECodeVerifier: oidcpkce.Code(happyDownstreamPKCE), ExpectedIDTokenNonce: nonce.Nonce(happyDownstreamNonce), @@ -169,14 +178,6 @@ func TestCallbackEndpoint(t *testing.T) { // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it happyDownstreamRedirectLocationRegexp := downstreamRedirectURI + `\?code=([^&]+)&scope=openid\+username\+groups&state=` + happyDownstreamState - addFullyCapableDynamicClientAndSecretToKubeResources := func(t *testing.T, supervisorClient *supervisorfake.Clientset, kubeClient *fake.Clientset) { - oidcClient, secret := testutil.FullyCapableOIDCClientAndStorageSecret(t, - "some-namespace", downstreamDynamicClientID, downstreamDynamicClientUID, downstreamRedirectURI, nil, - []string{testutil.HashedPassword1AtGoMinCost}, oidcclientvalidator.Validate) - require.NoError(t, supervisorClient.Tracker().Add(oidcClient)) - require.NoError(t, kubeClient.Tracker().Add(secret)) - } - prefixUsernameAndGroupsPipeline := transformtestutil.NewPrefixingPipeline(t, transformationUsernamePrefix, transformationGroupsPrefix) rejectAuthPipeline := transformtestutil.NewRejectAllAuthPipeline(t) @@ -753,7 +754,7 @@ func TestCallbackEndpoint(t *testing.T) { kubeResources: addFullyCapableDynamicClientAndSecretToKubeResources, method: http.MethodGet, path: newRequestPath().WithState( - happyUpstreamStateParamForDynamicClient(). + happyUpstreamStateParam(). WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQueryForDynamicClient, map[string]string{"scope": "openid groups offline_access"}).Encode()). Build(t, happyStateCodec), @@ -783,7 +784,7 @@ func TestCallbackEndpoint(t *testing.T) { kubeResources: addFullyCapableDynamicClientAndSecretToKubeResources, method: http.MethodGet, path: newRequestPath().WithState( - happyUpstreamStateParamForDynamicClient(). + happyUpstreamStateParam(). WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQueryForDynamicClient, map[string]string{"scope": "openid username offline_access"}).Encode()). Build(t, happyStateCodec), @@ -1540,7 +1541,7 @@ func TestCallbackEndpoint(t *testing.T) { } // Configure fosite the same way that the production code would. - // Inject this into our test subject at the last second so we get a fresh storage for every test. + // Inject this into our test subject at the last second, so we get a fresh storage for every test. timeoutsConfiguration := oidc.DefaultOIDCTimeoutsConfiguration() // Use lower minimum required bcrypt cost than we would use in production to keep unit the tests fast. oauthStore := storage.NewKubeStorage(secrets, oidcClientsClient, timeoutsConfiguration, bcrypt.MinCost) @@ -1565,7 +1566,9 @@ func TestCallbackEndpoint(t *testing.T) { if test.wantAuthcodeExchangeCall != nil { test.wantAuthcodeExchangeCall.args.Ctx = reqContext test.idps.RequireExactlyOneCallToExchangeAuthcodeAndValidateTokens(t, - test.wantAuthcodeExchangeCall.performedByUpstreamName, test.wantAuthcodeExchangeCall.args, + test.wantAuthcodeExchangeCall.performedByUpstreamName, + idpdiscoveryv1alpha1.IDPTypeOIDC, + test.wantAuthcodeExchangeCall.args, ) } else { test.idps.RequireExactlyZeroCallsToExchangeAuthcodeAndValidateTokens(t) @@ -1636,7 +1639,7 @@ func TestCallbackEndpoint(t *testing.T) { type expectedAuthcodeExchange struct { performedByUpstreamName string - args *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs + args *oidctestutil.ExchangeAuthcodeArgs } type requestPath struct { @@ -1696,12 +1699,6 @@ func happyUpstreamStateParam() *oidctestutil.UpstreamStateParamBuilder { } } -func happyUpstreamStateParamForDynamicClient() *oidctestutil.UpstreamStateParamBuilder { - p := happyUpstreamStateParam() - p.P = happyDownstreamRequestParamsForDynamicClient - return p -} - func happyUpstream() *oidctestutil.TestUpstreamOIDCIdentityProviderBuilder { return oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName(happyUpstreamIDPName). diff --git a/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider.go b/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider.go index 015c5fa6c..4efb1c048 100644 --- a/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider.go +++ b/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider.go @@ -91,13 +91,28 @@ func (p *FederationDomainResolvedGitHubIdentityProvider) Login( } func (p *FederationDomainResolvedGitHubIdentityProvider) LoginFromCallback( - _ context.Context, - _ string, - _ pkce.Code, - _ nonce.Nonce, - _ string, + ctx context.Context, + authCode string, + _ pkce.Code, // GitHub does not support PKCE, see https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps + _ nonce.Nonce, // GitHub does not support OIDC, therefore there is no ID token that could contain the "nonce". + redirectURI string, ) (*resolvedprovider.Identity, *resolvedprovider.IdentityLoginExtras, error) { - return nil, nil, errors.New("function LoginFromCallback not yet implemented for GitHub IDP") + token, _ := p.Provider.ExchangeAuthcode( + ctx, + authCode, + redirectURI, + ) + + return &resolvedprovider.Identity{ + UpstreamUsername: "some-github-login", + UpstreamGroups: []string{"org1/team1", "org2/team2"}, + DownstreamSubject: "https://github.com?idpName=upstream-github-idp-name&sub=some-github-login", + IDPSpecificSessionData: &psession.GitHubSessionData{ + UpstreamAccessToken: token, + }, + }, + &resolvedprovider.IdentityLoginExtras{}, + nil } func (p *FederationDomainResolvedGitHubIdentityProvider) UpstreamRefresh( diff --git a/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go b/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go index bf6d98a52..d54b4f57e 100644 --- a/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go +++ b/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go @@ -15,6 +15,7 @@ import ( "go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/testutil/transformtestutil" "go.pinniped.dev/internal/upstreamgithub" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) func TestFederationDomainResolvedGitHubIdentityProvider(t *testing.T) { @@ -58,11 +59,11 @@ func TestFederationDomainResolvedGitHubIdentityProvider(t *testing.T) { originalCustomSession := &psession.CustomSessionData{ Username: "fake-username", UpstreamUsername: "fake-upstream-username", - GitHub: &psession.GitHubSessionData{UpstreamAccessToken: "fake-upstream-access-token"}, + GitHub: &psession.GitHubSessionData{UpstreamAccessToken: &oidctypes.Token{AccessToken: &oidctypes.AccessToken{Token: "fake-upstream-access-token"}}}, } clonedCustomSession := subject.CloneIDPSpecificSessionDataFromSession(originalCustomSession) require.Equal(t, - &psession.GitHubSessionData{UpstreamAccessToken: "fake-upstream-access-token"}, + &psession.GitHubSessionData{UpstreamAccessToken: &oidctypes.Token{AccessToken: &oidctypes.AccessToken{Token: "fake-upstream-access-token"}}}, clonedCustomSession, ) require.NotSame(t, originalCustomSession, clonedCustomSession) @@ -71,11 +72,11 @@ func TestFederationDomainResolvedGitHubIdentityProvider(t *testing.T) { Username: "fake-username2", UpstreamUsername: "fake-upstream-username2", } - subject.ApplyIDPSpecificSessionDataToSession(customSessionToBeMutated, &psession.GitHubSessionData{UpstreamAccessToken: "fake-upstream-access-token2"}) + subject.ApplyIDPSpecificSessionDataToSession(customSessionToBeMutated, &psession.GitHubSessionData{UpstreamAccessToken: &oidctypes.Token{AccessToken: &oidctypes.AccessToken{Token: "OTHER-upstream-access-token"}}}) require.Equal(t, &psession.CustomSessionData{ Username: "fake-username2", UpstreamUsername: "fake-upstream-username2", - GitHub: &psession.GitHubSessionData{UpstreamAccessToken: "fake-upstream-access-token2"}, + GitHub: &psession.GitHubSessionData{UpstreamAccessToken: &oidctypes.Token{AccessToken: &oidctypes.AccessToken{Token: "OTHER-upstream-access-token"}}}, }, customSessionToBeMutated) redirectURL, err := subject.UpstreamAuthorizeRedirectURL( diff --git a/internal/federationdomain/upstreamprovider/upsteam_provider.go b/internal/federationdomain/upstreamprovider/upsteam_provider.go index decee99aa..be3229be3 100644 --- a/internal/federationdomain/upstreamprovider/upsteam_provider.go +++ b/internal/federationdomain/upstreamprovider/upsteam_provider.go @@ -168,4 +168,12 @@ type UpstreamGithubIdentityProviderI interface { // Or maybe higher level interface like this? // ExchangeAuthcode(ctx, authcode, redirectURI) (AccessToken, error) // GetUser(ctx, accessToken) (User, error) // in this case User would include team and org info + + // ExchangeAuthcode performs an upstream GitHub authorization code exchange. + // Returns the raw access token. The access token expiry is not known. + ExchangeAuthcode( + ctx context.Context, + authcode string, + redirectURI string, + ) (string, error) } diff --git a/internal/testutil/oidctestutil/expected_upstream_state_param.go b/internal/testutil/oidctestutil/expected_upstream_state_param.go index f3a0bcb96..72ff6398d 100644 --- a/internal/testutil/oidctestutil/expected_upstream_state_param.go +++ b/internal/testutil/oidctestutil/expected_upstream_state_param.go @@ -8,6 +8,8 @@ import ( "github.com/gorilla/securecookie" "github.com/stretchr/testify/require" + + idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1" ) // ExpectedUpstreamStateParamFormat is a separate type from the production code to ensure that the state @@ -52,8 +54,13 @@ func (b *UpstreamStateParamBuilder) WithPKCE(pkce string) *UpstreamStateParamBui return b } -func (b *UpstreamStateParamBuilder) WithUpstreamIDPType(upstreamIDPType string) *UpstreamStateParamBuilder { - b.T = upstreamIDPType +func (b *UpstreamStateParamBuilder) WithUpstreamIDPType(upstreamIDPType idpdiscoveryv1alpha1.IDPType) *UpstreamStateParamBuilder { + b.T = string(upstreamIDPType) + return b +} + +func (b *UpstreamStateParamBuilder) WithUpstreamIDPName(upstreamIDPName string) *UpstreamStateParamBuilder { + b.U = upstreamIDPName return b } diff --git a/internal/testutil/oidctestutil/testgithubprovider.go b/internal/testutil/oidctestutil/testgithubprovider.go index b28a4487f..20b97ad93 100644 --- a/internal/testutil/oidctestutil/testgithubprovider.go +++ b/internal/testutil/oidctestutil/testgithubprovider.go @@ -4,6 +4,8 @@ package oidctestutil import ( + "context" + "k8s.io/apimachinery/pkg/types" "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1" @@ -22,6 +24,10 @@ type TestUpstreamGitHubIdentityProviderBuilder struct { groupNameAttribute v1alpha1.GitHubGroupNameAttribute allowedOrganizations []string authorizationURL string + + // Assertions stuff + authcodeExchangeErr error + accessToken string } func (u *TestUpstreamGitHubIdentityProviderBuilder) WithName(value string) *TestUpstreamGitHubIdentityProviderBuilder { @@ -69,6 +75,16 @@ func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAuthorizationURL(value s return u } +func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAccessToken(token string) *TestUpstreamGitHubIdentityProviderBuilder { + u.accessToken = token + return u +} + +func (u *TestUpstreamGitHubIdentityProviderBuilder) WithEmptyAccessToken() *TestUpstreamGitHubIdentityProviderBuilder { + u.accessToken = "" + return u +} + func (u *TestUpstreamGitHubIdentityProviderBuilder) Build() *TestUpstreamGitHubIdentityProvider { if u.displayNameForFederationDomain == "" { // default it to the CR name @@ -89,6 +105,13 @@ func (u *TestUpstreamGitHubIdentityProviderBuilder) Build() *TestUpstreamGitHubI GroupNameAttribute: u.groupNameAttribute, AllowedOrganizations: u.allowedOrganizations, AuthorizationURL: u.authorizationURL, + + ExchangeAuthcodeFunc: func(ctx context.Context, authcode string) (string, error) { + if u.authcodeExchangeErr != nil { + return "", u.authcodeExchangeErr + } + return u.accessToken, nil + }, } } @@ -107,6 +130,16 @@ type TestUpstreamGitHubIdentityProvider struct { GroupNameAttribute v1alpha1.GitHubGroupNameAttribute AllowedOrganizations []string AuthorizationURL string + + authcodeExchangeErr error + + ExchangeAuthcodeFunc func( + ctx context.Context, + authcode string, + ) (string, error) + + exchangeAuthcodeCallCount int + exchangeAuthcodeArgs []*ExchangeAuthcodeArgs } var _ upstreamprovider.UpstreamGithubIdentityProviderI = &TestUpstreamGitHubIdentityProvider{} @@ -142,3 +175,31 @@ func (u *TestUpstreamGitHubIdentityProvider) GetAllowedOrganizations() []string func (u *TestUpstreamGitHubIdentityProvider) GetAuthorizationURL() string { return u.AuthorizationURL } + +func (u *TestUpstreamGitHubIdentityProvider) ExchangeAuthcode( + ctx context.Context, + authcode string, + redirectURI string, +) (string, error) { + if u.exchangeAuthcodeArgs == nil { + u.exchangeAuthcodeArgs = make([]*ExchangeAuthcodeArgs, 0) + } + u.exchangeAuthcodeCallCount++ + u.exchangeAuthcodeArgs = append(u.exchangeAuthcodeArgs, &ExchangeAuthcodeArgs{ + Ctx: ctx, + Authcode: authcode, + RedirectURI: redirectURI, + }) + return u.ExchangeAuthcodeFunc(ctx, authcode) +} + +func (u *TestUpstreamGitHubIdentityProvider) ExchangeAuthcodeCallCount() int { + return u.exchangeAuthcodeCallCount +} + +func (u *TestUpstreamGitHubIdentityProvider) ExchangeAuthcodeArgs(call int) *ExchangeAuthcodeArgs { + if u.exchangeAuthcodeArgs == nil { + u.exchangeAuthcodeArgs = make([]*ExchangeAuthcodeArgs, 0) + } + return u.exchangeAuthcodeArgs[call] +} diff --git a/internal/testutil/oidctestutil/testoidcprovider.go b/internal/testutil/oidctestutil/testoidcprovider.go index 489f6304c..9530db42e 100644 --- a/internal/testutil/oidctestutil/testoidcprovider.go +++ b/internal/testutil/oidctestutil/testoidcprovider.go @@ -18,9 +18,9 @@ import ( oidcpkce "go.pinniped.dev/pkg/oidcclient/pkce" ) -// ExchangeAuthcodeAndValidateTokenArgs is used to spy on calls to +// ExchangeAuthcodeArgs is used to spy on calls to // TestUpstreamOIDCIdentityProvider.ExchangeAuthcodeAndValidateTokensFunc(). -type ExchangeAuthcodeAndValidateTokenArgs struct { +type ExchangeAuthcodeArgs struct { Ctx context.Context Authcode string PKCECodeVerifier oidcpkce.Code @@ -101,7 +101,7 @@ type TestUpstreamOIDCIdentityProvider struct { // Fields for tracking actual calls make to mock functions. exchangeAuthcodeAndValidateTokensCallCount int - exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs + exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeArgs passwordCredentialsGrantAndValidateTokensCallCount int passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs performRefreshCallCount int @@ -180,10 +180,10 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( redirectURI string, ) (*oidctypes.Token, error) { if u.exchangeAuthcodeAndValidateTokensArgs == nil { - u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) + u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeArgs, 0) } u.exchangeAuthcodeAndValidateTokensCallCount++ - u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeAndValidateTokenArgs{ + u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeArgs{ Ctx: ctx, Authcode: authcode, PKCECodeVerifier: pkceCodeVerifier, @@ -197,9 +197,9 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensCall return u.exchangeAuthcodeAndValidateTokensCallCount } -func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeAndValidateTokenArgs { +func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeArgs { if u.exchangeAuthcodeAndValidateTokensArgs == nil { - u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) + u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeArgs, 0) } return u.exchangeAuthcodeAndValidateTokensArgs[call] } diff --git a/internal/testutil/testidplister/testidplister.go b/internal/testutil/testidplister/testidplister.go index 13b94b29f..2823a5188 100644 --- a/internal/testutil/testidplister/testidplister.go +++ b/internal/testutil/testidplister/testidplister.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" + idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1" "go.pinniped.dev/internal/federationdomain/dynamicupstreamprovider" "go.pinniped.dev/internal/federationdomain/resolvedprovider" "go.pinniped.dev/internal/federationdomain/resolvedprovider/resolvedgithub" @@ -266,38 +267,56 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPasswordCredentialsG func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToExchangeAuthcodeAndValidateTokens( t *testing.T, expectedPerformedByUpstreamName string, - expectedArgs *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs, + expectedPerformedByUpstreamType idpdiscoveryv1alpha1.IDPType, + expectedArgs *oidctestutil.ExchangeAuthcodeArgs, ) { t.Helper() - var actualArgs *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs + var actualArgs *oidctestutil.ExchangeAuthcodeArgs var actualNameOfUpstreamWhichMadeCall string - actualCallCountAcrossAllOIDCUpstreams := 0 + var actualTypeOfUpstreamWhichMadeCall idpdiscoveryv1alpha1.IDPType + actualCallCountAcrossAllOIDCAndGitHubUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { callCountOnThisUpstream := upstreamOIDC.ExchangeAuthcodeAndValidateTokensCallCount() - actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream + actualCallCountAcrossAllOIDCAndGitHubUpstreams += callCountOnThisUpstream if callCountOnThisUpstream == 1 { actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name + actualTypeOfUpstreamWhichMadeCall = idpdiscoveryv1alpha1.IDPTypeOIDC actualArgs = upstreamOIDC.ExchangeAuthcodeAndValidateTokensArgs(0) } } - require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, - "should have been exactly one call to ExchangeAuthcodeAndValidateTokens() by all OIDC upstreams", + for _, upstreamGitHub := range b.upstreamGitHubIdentityProviders { + callCountOnThisUpstream := upstreamGitHub.ExchangeAuthcodeCallCount() + actualCallCountAcrossAllOIDCAndGitHubUpstreams += callCountOnThisUpstream + if callCountOnThisUpstream == 1 { + actualNameOfUpstreamWhichMadeCall = upstreamGitHub.Name + actualTypeOfUpstreamWhichMadeCall = idpdiscoveryv1alpha1.IDPTypeGitHub + actualArgs = upstreamGitHub.ExchangeAuthcodeArgs(0) + } + } + require.Equal(t, 1, actualCallCountAcrossAllOIDCAndGitHubUpstreams, + "expected exactly one call to (OIDC) ExchangeAuthcodeAndValidateTokensCallCount() or (GitHub) ExchangeAuthcodeCallCount()", ) require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, - "ExchangeAuthcodeAndValidateTokens() was called on the wrong OIDC upstream", + "(OIDC) ExchangeAuthcodeAndValidateTokensCallCount() or (GitHub) ExchangeAuthcodeCallCount() was called on the wrong upstream name", + ) + require.Equal(t, expectedPerformedByUpstreamType, actualTypeOfUpstreamWhichMadeCall, + "(OIDC) ExchangeAuthcodeAndValidateTokensCallCount() or (GitHub) ExchangeAuthcodeCallCount() was called on the wrong upstream type", ) require.Equal(t, expectedArgs, actualArgs) } func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndValidateTokens(t *testing.T) { t.Helper() - actualCallCountAcrossAllOIDCUpstreams := 0 + actualCallCount := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { - actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.ExchangeAuthcodeAndValidateTokensCallCount() + actualCallCount += upstreamOIDC.ExchangeAuthcodeAndValidateTokensCallCount() + } + for _, upstreamGitHub := range b.upstreamGitHubIdentityProviders { + actualCallCount += upstreamGitHub.ExchangeAuthcodeCallCount() } - require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, - "expected exactly zero calls to ExchangeAuthcodeAndValidateTokens()", + require.Equal(t, 0, actualCallCount, + "expected exactly zero calls to (OIDC) ExchangeAuthcodeAndValidateTokensCallCount() or (GitHub) ExchangeAuthcodeCallCount()", ) } diff --git a/internal/upstreamgithub/upstreamgithub.go b/internal/upstreamgithub/upstreamgithub.go index 52e1120b7..f76a02aa5 100644 --- a/internal/upstreamgithub/upstreamgithub.go +++ b/internal/upstreamgithub/upstreamgithub.go @@ -5,6 +5,7 @@ package upstreamgithub import ( + "context" "net/http" "golang.org/x/oauth2" @@ -87,6 +88,11 @@ func (p *Provider) GetAuthorizationURL() string { return p.c.OAuth2Config.Endpoint.AuthURL } +func (p *Provider) ExchangeAuthcode(_ context.Context, _ string, _ string) (string, error) { + //TODO implement me + panic("implement me") +} + // GetConfig returns the config. This is not part of the interface and is mostly just for testing. func (p *Provider) GetConfig() ProviderConfig { return p.c