mirror of
https://github.com/vmware-tanzu/pinniped.git
synced 2026-01-05 04:56:11 +00:00
265 lines
8.9 KiB
Go
265 lines
8.9 KiB
Go
// Copyright 2020-2024 the Pinniped contributors. All Rights Reserved.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package oidctestutil
|
|
|
|
import (
|
|
"context"
|
|
|
|
"k8s.io/apimachinery/pkg/types"
|
|
|
|
idpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
|
|
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
|
|
"go.pinniped.dev/internal/idtransform"
|
|
"go.pinniped.dev/internal/setutil"
|
|
)
|
|
|
|
// ExchangeAuthcodeArgs captures the arguments of a single call to
// TestUpstreamGitHubIdentityProvider.ExchangeAuthcode so that tests can
// inspect what the code under test passed in.
type ExchangeAuthcodeArgs struct {
	// Ctx is the context that was passed to the call.
	Ctx context.Context
	// Authcode is the authorization code that was passed to the call.
	Authcode string
	// RedirectURI is the redirect URI that was passed to the call.
	RedirectURI string
}
|
|
|
|
// GetUserArgs captures the arguments of a single call to
// TestUpstreamGitHubIdentityProvider.GetUser so that tests can inspect what
// the code under test passed in.
type GetUserArgs struct {
	// Ctx is the context that was passed to the call.
	Ctx context.Context
	// AccessToken is the access token that was passed to the call.
	AccessToken string
	// IDPDisplayName is the identity provider display name that was passed
	// to the call.
	IDPDisplayName string
}
|
|
|
|
type TestUpstreamGitHubIdentityProviderBuilder struct {
|
|
name string
|
|
resourceUID types.UID
|
|
clientID string
|
|
scopes []string
|
|
displayNameForFederationDomain string
|
|
transformsForFederationDomain *idtransform.TransformationPipeline
|
|
usernameAttribute idpv1alpha1.GitHubUsernameAttribute
|
|
groupNameAttribute idpv1alpha1.GitHubGroupNameAttribute
|
|
allowedOrganizations *setutil.CaseInsensitiveSet
|
|
authorizationURL string
|
|
authcodeExchangeErr error
|
|
accessToken string
|
|
getUserErr error
|
|
getUserUser *upstreamprovider.GitHubUser
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithName(value string) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.name = value
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithResourceUID(value types.UID) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.resourceUID = value
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithClientID(value string) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.clientID = value
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithScopes(value []string) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.scopes = value
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithDisplayNameForFederationDomain(value string) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.displayNameForFederationDomain = value
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithUsernameAttribute(value idpv1alpha1.GitHubUsernameAttribute) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.usernameAttribute = value
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithGroupNameAttribute(value idpv1alpha1.GitHubGroupNameAttribute) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.groupNameAttribute = value
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAllowedOrganizations(value *setutil.CaseInsensitiveSet) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.allowedOrganizations = value
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAuthorizationURL(value string) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.authorizationURL = value
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAccessToken(token string) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.accessToken = token
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAuthcodeExchangeError(err error) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.authcodeExchangeErr = err
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithUser(user *upstreamprovider.GitHubUser) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.getUserUser = user
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithGetUserError(err error) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.getUserErr = err
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithTransformsForFederationDomain(transforms *idtransform.TransformationPipeline) *TestUpstreamGitHubIdentityProviderBuilder {
|
|
u.transformsForFederationDomain = transforms
|
|
return u
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProviderBuilder) Build() *TestUpstreamGitHubIdentityProvider {
|
|
if u.displayNameForFederationDomain == "" {
|
|
// default it to the CR name
|
|
u.displayNameForFederationDomain = u.name
|
|
}
|
|
if u.transformsForFederationDomain == nil {
|
|
// default to an empty pipeline
|
|
u.transformsForFederationDomain = idtransform.NewTransformationPipeline()
|
|
}
|
|
return &TestUpstreamGitHubIdentityProvider{
|
|
Name: u.name,
|
|
ClientID: u.clientID,
|
|
ResourceUID: u.resourceUID,
|
|
Scopes: u.scopes,
|
|
DisplayNameForFederationDomain: u.displayNameForFederationDomain,
|
|
TransformsForFederationDomain: u.transformsForFederationDomain,
|
|
UsernameAttribute: u.usernameAttribute,
|
|
GroupNameAttribute: u.groupNameAttribute,
|
|
AllowedOrganizations: u.allowedOrganizations,
|
|
AuthorizationURL: u.authorizationURL,
|
|
GetUserFunc: func(ctx context.Context, accessToken string) (*upstreamprovider.GitHubUser, error) {
|
|
if u.getUserErr != nil {
|
|
return nil, u.getUserErr
|
|
}
|
|
return u.getUserUser, nil
|
|
},
|
|
ExchangeAuthcodeFunc: func(ctx context.Context, authcode string) (string, error) {
|
|
if u.authcodeExchangeErr != nil {
|
|
return "", u.authcodeExchangeErr
|
|
}
|
|
return u.accessToken, nil
|
|
},
|
|
}
|
|
}
|
|
|
|
func NewTestUpstreamGitHubIdentityProviderBuilder() *TestUpstreamGitHubIdentityProviderBuilder {
|
|
return &TestUpstreamGitHubIdentityProviderBuilder{}
|
|
}
|
|
|
|
type TestUpstreamGitHubIdentityProvider struct {
|
|
Name string
|
|
ClientID string
|
|
ResourceUID types.UID
|
|
Scopes []string
|
|
DisplayNameForFederationDomain string
|
|
TransformsForFederationDomain *idtransform.TransformationPipeline
|
|
UsernameAttribute idpv1alpha1.GitHubUsernameAttribute
|
|
GroupNameAttribute idpv1alpha1.GitHubGroupNameAttribute
|
|
AllowedOrganizations *setutil.CaseInsensitiveSet
|
|
AuthorizationURL string
|
|
GetUserFunc func(ctx context.Context, accessToken string) (*upstreamprovider.GitHubUser, error)
|
|
ExchangeAuthcodeFunc func(ctx context.Context, authcode string) (string, error)
|
|
|
|
// Fields for tracking actual calls make to mock functions.
|
|
exchangeAuthcodeCallCount int
|
|
exchangeAuthcodeArgs []*ExchangeAuthcodeArgs
|
|
getUserCallCount int
|
|
getUserArgs []*GetUserArgs
|
|
}
|
|
|
|
// Compile-time check that *TestUpstreamGitHubIdentityProvider implements
// upstreamprovider.UpstreamGithubIdentityProviderI.
var _ upstreamprovider.UpstreamGithubIdentityProviderI = (*TestUpstreamGitHubIdentityProvider)(nil)
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) GetResourceUID() types.UID {
|
|
return u.ResourceUID
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) GetResourceName() string {
|
|
return u.Name
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) GetScopes() []string {
|
|
return u.Scopes
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) GetClientID() string {
|
|
return u.ClientID
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) GetUsernameAttribute() idpv1alpha1.GitHubUsernameAttribute {
|
|
return u.UsernameAttribute
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) GetGroupNameAttribute() idpv1alpha1.GitHubGroupNameAttribute {
|
|
return u.GroupNameAttribute
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) GetAllowedOrganizations() *setutil.CaseInsensitiveSet {
|
|
return u.AllowedOrganizations
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) GetAuthorizationURL() string {
|
|
return u.AuthorizationURL
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) ExchangeAuthcode(
|
|
ctx context.Context,
|
|
authcode string,
|
|
redirectURI string,
|
|
) (string, error) {
|
|
if u.exchangeAuthcodeArgs == nil {
|
|
u.exchangeAuthcodeArgs = make([]*ExchangeAuthcodeArgs, 0)
|
|
}
|
|
u.exchangeAuthcodeCallCount++
|
|
u.exchangeAuthcodeArgs = append(u.exchangeAuthcodeArgs, &ExchangeAuthcodeArgs{
|
|
Ctx: ctx,
|
|
Authcode: authcode,
|
|
RedirectURI: redirectURI,
|
|
})
|
|
return u.ExchangeAuthcodeFunc(ctx, authcode)
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) ExchangeAuthcodeCallCount() int {
|
|
return u.exchangeAuthcodeCallCount
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) ExchangeAuthcodeArgs(call int) *ExchangeAuthcodeArgs {
|
|
if u.exchangeAuthcodeArgs == nil {
|
|
u.exchangeAuthcodeArgs = make([]*ExchangeAuthcodeArgs, 0)
|
|
}
|
|
return u.exchangeAuthcodeArgs[call]
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) GetUser(ctx context.Context, accessToken string, idpDisplayName string) (*upstreamprovider.GitHubUser, error) {
|
|
if u.getUserArgs == nil {
|
|
u.getUserArgs = make([]*GetUserArgs, 0)
|
|
}
|
|
u.getUserCallCount++
|
|
u.getUserArgs = append(u.getUserArgs, &GetUserArgs{
|
|
Ctx: ctx,
|
|
AccessToken: accessToken,
|
|
IDPDisplayName: idpDisplayName,
|
|
})
|
|
return u.GetUserFunc(ctx, accessToken)
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) GetUserCallCount() int {
|
|
return u.getUserCallCount
|
|
}
|
|
|
|
func (u *TestUpstreamGitHubIdentityProvider) GetUserArgs(call int) *GetUserArgs {
|
|
if u.getUserArgs == nil {
|
|
u.getUserArgs = make([]*GetUserArgs, 0)
|
|
}
|
|
return u.getUserArgs[call]
|
|
}
|