Files
pinniped/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go
2024-05-22 21:21:45 -05:00

251 lines
9.2 KiB
Go

// Copyright 2020-2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package resolvedgithub
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
idpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1"
"go.pinniped.dev/internal/federationdomain/resolvedprovider"
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/psession"
"go.pinniped.dev/internal/setutil"
"go.pinniped.dev/internal/testutil/oidctestutil"
"go.pinniped.dev/internal/testutil/transformtestutil"
"go.pinniped.dev/internal/upstreamgithub"
)
// TestFederationDomainResolvedGitHubIdentityProvider exercises the simple accessors of
// FederationDomainResolvedGitHubIdentityProvider, its helpers for cloning and applying the
// GitHub-specific portion of a session, and the construction of the upstream authorize redirect URL.
func TestFederationDomainResolvedGitHubIdentityProvider(t *testing.T) {
	transforms := transformtestutil.NewRejectAllAuthPipeline(t)

	provider := upstreamgithub.New(upstreamgithub.ProviderConfig{
		Name:                 "fake-provider-config",
		ResourceUID:          "fake-resource-uid",
		APIBaseURL:           "https://fake-api-host.com",
		UsernameAttribute:    idpv1alpha1.GitHubUsernameID,
		GroupNameAttribute:   idpv1alpha1.GitHubUseTeamSlugForGroupName,
		AllowedOrganizations: setutil.NewCaseInsensitiveSet("org1", "org2"),
		HttpClient:           nil, // not needed yet for this test
		OAuth2Config: &oauth2.Config{
			ClientID:     "fake-client-id",
			ClientSecret: "fake-client-secret",
			Scopes:       []string{"read:user", "read:org"},
			Endpoint: oauth2.Endpoint{
				AuthURL:       "https://fake-authorization-url",
				DeviceAuthURL: "",
				TokenURL:      "https://fake-token-url",
				AuthStyle:     oauth2.AuthStyleInParams,
			},
		},
	})

	subject := FederationDomainResolvedGitHubIdentityProvider{
		DisplayName:         "fake-display-name",
		Provider:            provider,
		SessionProviderType: psession.ProviderTypeGitHub,
		Transforms:          transforms,
	}

	// The accessors should hand back exactly what the struct was built with, plus the fixed
	// GitHub discovery type and the browser authcode flow.
	require.Equal(t, "fake-display-name", subject.GetDisplayName())
	require.Equal(t, provider, subject.GetProvider())
	require.Equal(t, psession.ProviderTypeGitHub, subject.GetSessionProviderType())
	require.Equal(t, idpdiscoveryv1alpha1.IDPTypeGitHub, subject.GetIDPDiscoveryType())
	require.Equal(t, []idpdiscoveryv1alpha1.IDPFlow{idpdiscoveryv1alpha1.IDPFlowBrowserAuthcode}, subject.GetIDPDiscoveryFlows())
	require.Equal(t, transforms, subject.GetTransforms())

	// Cloning should yield only the GitHub-specific session data, equal in value to the original.
	originalCustomSession := &psession.CustomSessionData{
		Username:         "fake-username",
		UpstreamUsername: "fake-upstream-username",
		GitHub:           &psession.GitHubSessionData{UpstreamAccessToken: "fake-upstream-access-token"},
	}
	clonedCustomSession := subject.CloneIDPSpecificSessionDataFromSession(originalCustomSession)
	require.Equal(t,
		&psession.GitHubSessionData{UpstreamAccessToken: "fake-upstream-access-token"},
		clonedCustomSession,
	)
	// The clone must be a distinct object from the session's GitHub data. Compare against the
	// GitHub field rather than the enclosing session: pointers of different types are never
	// considered "the same", so comparing against the session itself could not fail.
	require.NotSame(t, originalCustomSession.GitHub, clonedCustomSession)

	// Applying GitHub-specific data should set the GitHub field and leave the other fields alone.
	customSessionToBeMutated := &psession.CustomSessionData{
		Username:         "fake-username2",
		UpstreamUsername: "fake-upstream-username2",
	}
	subject.ApplyIDPSpecificSessionDataToSession(customSessionToBeMutated, &psession.GitHubSessionData{UpstreamAccessToken: "OTHER-upstream-access-token"})
	require.Equal(t, &psession.CustomSessionData{
		Username:         "fake-username2",
		UpstreamUsername: "fake-upstream-username2",
		GitHub:           &psession.GitHubSessionData{UpstreamAccessToken: "OTHER-upstream-access-token"},
	}, customSessionToBeMutated)

	redirectURL, err := subject.UpstreamAuthorizeRedirectURL(
		&resolvedprovider.UpstreamAuthorizeRequestState{
			EncodedStateParam: "encodedStateParam12345",
			PKCE:              "pkce6789",
			Nonce:             "nonce1289",
		},
		"https://localhost/fake/path",
	)
	require.NoError(t, err)
	// GitHub neither requires nor documents the standard response_type=code param, and in manual testing
	// its authorize endpoint appeared to ignore it. The oauth2 package wants to add the param, so we let it.
	//
	// The expected redirect_uri is the base URL given above with "/callback" appended, and the expected
	// URL carries the state param but neither the PKCE nor the nonce value from the request state.
	require.Equal(t,
		"https://fake-authorization-url?"+
			"client_id=fake-client-id&"+
			"redirect_uri=https%3A%2F%2Flocalhost%2Ffake%2Fpath%2Fcallback&"+
			"response_type=code&"+
			"scope=read%3Auser+read%3Aorg&"+
			"state=encodedStateParam12345",
		redirectURL,
	)
}
// TestLoginFromCallback is a table-driven test of LoginFromCallback. Each case configures a fake
// upstream GitHub provider, invokes LoginFromCallback once, and then verifies which provider calls
// were recorded (and with what arguments) along with the returned identity, extras, and error.
func TestLoginFromCallback(t *testing.T) {
	// A context carrying a distinctive value, so the recorded call arguments can be compared
	// against this specific context.
	ctx := context.WithValue(context.Background(), "some-unique-key", "some-value") //nolint:staticcheck // okay to use string key for test

	// Inputs shared by every case in the table.
	const (
		fakeDisplayName = "fake-display-name"
		fakeAuthcode    = "fake-authcode"
		fakeRedirectURI = "https://fake-redirect-uri"
		fakeAccessToken = "fake-access-token"
	)

	// Builders for the call arguments expected to be recorded by the fake provider.
	expectedExchangeArgs := func() *oidctestutil.ExchangeAuthcodeArgs {
		return &oidctestutil.ExchangeAuthcodeArgs{
			Ctx:         ctx,
			Authcode:    fakeAuthcode,
			RedirectURI: fakeRedirectURI,
		}
	}
	expectedGetUserArgs := func() *oidctestutil.GetUserArgs {
		return &oidctestutil.GetUserArgs{
			Ctx:            ctx,
			AccessToken:    fakeAccessToken,
			IDPDisplayName: fakeDisplayName,
		}
	}

	testCases := []struct {
		name                     string
		provider                 *oidctestutil.TestUpstreamGitHubIdentityProvider
		idpDisplayName           string
		authcode                 string
		redirectURI              string
		wantExchangeAuthcodeCall bool
		wantExchangeAuthcodeArgs *oidctestutil.ExchangeAuthcodeArgs
		wantGetUserCall          bool
		wantGetUserArgs          *oidctestutil.GetUserArgs
		wantIdentity             *resolvedprovider.Identity
		wantExtras               *resolvedprovider.IdentityLoginExtras
		wantErr                  string
	}{
		{
			name: "happy path",
			provider: oidctestutil.NewTestUpstreamGitHubIdentityProviderBuilder().
				WithAccessToken(fakeAccessToken).
				WithUser(&upstreamprovider.GitHubUser{
					Username:          "fake-username",
					Groups:            []string{"fake-group1", "fake-group2"},
					DownstreamSubject: "https://fake-downstream-subject",
				}).
				Build(),
			idpDisplayName:           fakeDisplayName,
			authcode:                 fakeAuthcode,
			redirectURI:              fakeRedirectURI,
			wantExchangeAuthcodeCall: true,
			wantExchangeAuthcodeArgs: expectedExchangeArgs(),
			wantGetUserCall:          true,
			wantGetUserArgs:          expectedGetUserArgs(),
			wantIdentity: &resolvedprovider.Identity{
				UpstreamUsername:  "fake-username",
				UpstreamGroups:    []string{"fake-group1", "fake-group2"},
				DownstreamSubject: "https://fake-downstream-subject",
				IDPSpecificSessionData: &psession.GitHubSessionData{
					UpstreamAccessToken: fakeAccessToken,
				},
			},
			wantExtras: &resolvedprovider.IdentityLoginExtras{},
		},
		{
			name: "error while exchanging authcode",
			provider: oidctestutil.NewTestUpstreamGitHubIdentityProviderBuilder().
				WithAuthcodeExchangeError(errors.New("fake authcode exchange error")).
				Build(),
			idpDisplayName:           fakeDisplayName,
			authcode:                 fakeAuthcode,
			redirectURI:              fakeRedirectURI,
			wantExchangeAuthcodeCall: true,
			wantExchangeAuthcodeArgs: expectedExchangeArgs(),
			// The exchange fails, so no user lookup is expected to be recorded.
			wantGetUserCall: false,
			wantIdentity:    nil,
			wantExtras:      nil,
			wantErr:         "failed to exchange authcode using GitHub API: fake authcode exchange error: fake authcode exchange error",
		},
		{
			name: "error while getting user info",
			provider: oidctestutil.NewTestUpstreamGitHubIdentityProviderBuilder().
				WithAccessToken(fakeAccessToken).
				WithGetUserError(errors.New("fake user info error")).
				Build(),
			idpDisplayName:           fakeDisplayName,
			authcode:                 fakeAuthcode,
			redirectURI:              fakeRedirectURI,
			wantExchangeAuthcodeCall: true,
			wantExchangeAuthcodeArgs: expectedExchangeArgs(),
			wantGetUserCall:          true,
			wantGetUserArgs:          expectedGetUserArgs(),
			wantIdentity:             nil,
			wantExtras:               nil,
			wantErr:                  "failed to get user info from GitHub API: fake user info error: fake user info error",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			idp := FederationDomainResolvedGitHubIdentityProvider{
				DisplayName:         tc.idpDisplayName,
				Provider:            tc.provider,
				SessionProviderType: psession.ProviderTypeGitHub,
				Transforms:          transformtestutil.NewRejectAllAuthPipeline(t),
			}

			// The PKCE and nonce arguments are given placeholder values.
			gotIdentity, gotExtras, err := idp.LoginFromCallback(
				ctx,
				tc.authcode,
				"pkce-will-be-ignored",
				"nonce-will-be-ignored",
				tc.redirectURI,
			)

			// First verify the authcode exchange call recorded by the fake provider.
			if !tc.wantExchangeAuthcodeCall {
				require.Zero(t, tc.provider.ExchangeAuthcodeCallCount())
			} else {
				require.Equal(t, 1, tc.provider.ExchangeAuthcodeCallCount())
				require.Equal(t, tc.wantExchangeAuthcodeArgs, tc.provider.ExchangeAuthcodeArgs(0))
			}

			// Then verify the user lookup call recorded by the fake provider.
			if !tc.wantGetUserCall {
				require.Zero(t, tc.provider.GetUserCallCount())
			} else {
				require.Equal(t, 1, tc.provider.GetUserCallCount())
				require.Equal(t, tc.wantGetUserArgs, tc.provider.GetUserArgs(0))
			}

			// An expected error must be an httperr.Responder with exactly the wanted message.
			if tc.wantErr != "" {
				responder, isResponder := err.(httperr.Responder)
				require.True(t, isResponder)
				require.EqualError(t, responder, tc.wantErr)
			} else {
				require.NoError(t, err)
			}

			require.Equal(t, tc.wantExtras, gotExtras)
			require.Equal(t, tc.wantIdentity, gotIdentity)
		})
	}
}