Files
pinniped/internal/testutil/oidctestutil/testgithubprovider.go
2024-06-11 10:16:18 -05:00

265 lines
8.9 KiB
Go

// Copyright 2020-2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package oidctestutil
import (
"context"
"k8s.io/apimachinery/pkg/types"
idpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/idtransform"
"go.pinniped.dev/internal/setutil"
)
// ExchangeAuthcodeArgs is used to spy on calls to
// TestUpstreamGitHubIdentityProvider.ExchangeAuthcodeFunc().
type ExchangeAuthcodeArgs struct {
Ctx context.Context
Authcode string
RedirectURI string
}
// GetUserArgs is used to spy on calls to
// TestUpstreamGitHubIdentityProvider.GetUserFunc().
type GetUserArgs struct {
Ctx context.Context
AccessToken string
IDPDisplayName string
}
type TestUpstreamGitHubIdentityProviderBuilder struct {
name string
resourceUID types.UID
clientID string
scopes []string
displayNameForFederationDomain string
transformsForFederationDomain *idtransform.TransformationPipeline
usernameAttribute idpv1alpha1.GitHubUsernameAttribute
groupNameAttribute idpv1alpha1.GitHubGroupNameAttribute
allowedOrganizations *setutil.CaseInsensitiveSet
authorizationURL string
authcodeExchangeErr error
accessToken string
getUserErr error
getUserUser *upstreamprovider.GitHubUser
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithName(value string) *TestUpstreamGitHubIdentityProviderBuilder {
u.name = value
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithResourceUID(value types.UID) *TestUpstreamGitHubIdentityProviderBuilder {
u.resourceUID = value
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithClientID(value string) *TestUpstreamGitHubIdentityProviderBuilder {
u.clientID = value
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithScopes(value []string) *TestUpstreamGitHubIdentityProviderBuilder {
u.scopes = value
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithDisplayNameForFederationDomain(value string) *TestUpstreamGitHubIdentityProviderBuilder {
u.displayNameForFederationDomain = value
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithUsernameAttribute(value idpv1alpha1.GitHubUsernameAttribute) *TestUpstreamGitHubIdentityProviderBuilder {
u.usernameAttribute = value
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithGroupNameAttribute(value idpv1alpha1.GitHubGroupNameAttribute) *TestUpstreamGitHubIdentityProviderBuilder {
u.groupNameAttribute = value
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAllowedOrganizations(value *setutil.CaseInsensitiveSet) *TestUpstreamGitHubIdentityProviderBuilder {
u.allowedOrganizations = value
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAuthorizationURL(value string) *TestUpstreamGitHubIdentityProviderBuilder {
u.authorizationURL = value
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAccessToken(token string) *TestUpstreamGitHubIdentityProviderBuilder {
u.accessToken = token
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAuthcodeExchangeError(err error) *TestUpstreamGitHubIdentityProviderBuilder {
u.authcodeExchangeErr = err
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithUser(user *upstreamprovider.GitHubUser) *TestUpstreamGitHubIdentityProviderBuilder {
u.getUserUser = user
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithGetUserError(err error) *TestUpstreamGitHubIdentityProviderBuilder {
u.getUserErr = err
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithTransformsForFederationDomain(transforms *idtransform.TransformationPipeline) *TestUpstreamGitHubIdentityProviderBuilder {
u.transformsForFederationDomain = transforms
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) Build() *TestUpstreamGitHubIdentityProvider {
if u.displayNameForFederationDomain == "" {
// default it to the CR name
u.displayNameForFederationDomain = u.name
}
if u.transformsForFederationDomain == nil {
// default to an empty pipeline
u.transformsForFederationDomain = idtransform.NewTransformationPipeline()
}
return &TestUpstreamGitHubIdentityProvider{
Name: u.name,
ClientID: u.clientID,
ResourceUID: u.resourceUID,
Scopes: u.scopes,
DisplayNameForFederationDomain: u.displayNameForFederationDomain,
TransformsForFederationDomain: u.transformsForFederationDomain,
UsernameAttribute: u.usernameAttribute,
GroupNameAttribute: u.groupNameAttribute,
AllowedOrganizations: u.allowedOrganizations,
AuthorizationURL: u.authorizationURL,
GetUserFunc: func(ctx context.Context, accessToken string) (*upstreamprovider.GitHubUser, error) {
if u.getUserErr != nil {
return nil, u.getUserErr
}
return u.getUserUser, nil
},
ExchangeAuthcodeFunc: func(ctx context.Context, authcode string) (string, error) {
if u.authcodeExchangeErr != nil {
return "", u.authcodeExchangeErr
}
return u.accessToken, nil
},
}
}
func NewTestUpstreamGitHubIdentityProviderBuilder() *TestUpstreamGitHubIdentityProviderBuilder {
return &TestUpstreamGitHubIdentityProviderBuilder{}
}
type TestUpstreamGitHubIdentityProvider struct {
Name string
ClientID string
ResourceUID types.UID
Scopes []string
DisplayNameForFederationDomain string
TransformsForFederationDomain *idtransform.TransformationPipeline
UsernameAttribute idpv1alpha1.GitHubUsernameAttribute
GroupNameAttribute idpv1alpha1.GitHubGroupNameAttribute
AllowedOrganizations *setutil.CaseInsensitiveSet
AuthorizationURL string
GetUserFunc func(ctx context.Context, accessToken string) (*upstreamprovider.GitHubUser, error)
ExchangeAuthcodeFunc func(ctx context.Context, authcode string) (string, error)
// Fields for tracking actual calls make to mock functions.
exchangeAuthcodeCallCount int
exchangeAuthcodeArgs []*ExchangeAuthcodeArgs
getUserCallCount int
getUserArgs []*GetUserArgs
}
var _ upstreamprovider.UpstreamGithubIdentityProviderI = &TestUpstreamGitHubIdentityProvider{}
func (u *TestUpstreamGitHubIdentityProvider) GetResourceUID() types.UID {
return u.ResourceUID
}
func (u *TestUpstreamGitHubIdentityProvider) GetResourceName() string {
return u.Name
}
func (u *TestUpstreamGitHubIdentityProvider) GetScopes() []string {
return u.Scopes
}
func (u *TestUpstreamGitHubIdentityProvider) GetClientID() string {
return u.ClientID
}
func (u *TestUpstreamGitHubIdentityProvider) GetUsernameAttribute() idpv1alpha1.GitHubUsernameAttribute {
return u.UsernameAttribute
}
func (u *TestUpstreamGitHubIdentityProvider) GetGroupNameAttribute() idpv1alpha1.GitHubGroupNameAttribute {
return u.GroupNameAttribute
}
func (u *TestUpstreamGitHubIdentityProvider) GetAllowedOrganizations() *setutil.CaseInsensitiveSet {
return u.AllowedOrganizations
}
func (u *TestUpstreamGitHubIdentityProvider) GetAuthorizationURL() string {
return u.AuthorizationURL
}
func (u *TestUpstreamGitHubIdentityProvider) ExchangeAuthcode(
ctx context.Context,
authcode string,
redirectURI string,
) (string, error) {
if u.exchangeAuthcodeArgs == nil {
u.exchangeAuthcodeArgs = make([]*ExchangeAuthcodeArgs, 0)
}
u.exchangeAuthcodeCallCount++
u.exchangeAuthcodeArgs = append(u.exchangeAuthcodeArgs, &ExchangeAuthcodeArgs{
Ctx: ctx,
Authcode: authcode,
RedirectURI: redirectURI,
})
return u.ExchangeAuthcodeFunc(ctx, authcode)
}
func (u *TestUpstreamGitHubIdentityProvider) ExchangeAuthcodeCallCount() int {
return u.exchangeAuthcodeCallCount
}
func (u *TestUpstreamGitHubIdentityProvider) ExchangeAuthcodeArgs(call int) *ExchangeAuthcodeArgs {
if u.exchangeAuthcodeArgs == nil {
u.exchangeAuthcodeArgs = make([]*ExchangeAuthcodeArgs, 0)
}
return u.exchangeAuthcodeArgs[call]
}
func (u *TestUpstreamGitHubIdentityProvider) GetUser(ctx context.Context, accessToken string, idpDisplayName string) (*upstreamprovider.GitHubUser, error) {
if u.getUserArgs == nil {
u.getUserArgs = make([]*GetUserArgs, 0)
}
u.getUserCallCount++
u.getUserArgs = append(u.getUserArgs, &GetUserArgs{
Ctx: ctx,
AccessToken: accessToken,
IDPDisplayName: idpDisplayName,
})
return u.GetUserFunc(ctx, accessToken)
}
func (u *TestUpstreamGitHubIdentityProvider) GetUserCallCount() int {
return u.getUserCallCount
}
func (u *TestUpstreamGitHubIdentityProvider) GetUserArgs(call int) *GetUserArgs {
if u.getUserArgs == nil {
u.getUserArgs = make([]*GetUserArgs, 0)
}
return u.getUserArgs[call]
}