Make github org comparison case-insensitive, but return original case

Co-authored-by: Joshua Casey <joshuatcasey@gmail.com>
This commit is contained in:
Ryan Richard
2024-05-21 11:57:55 -07:00
committed by Joshua Casey
parent 8923704f3c
commit 8f8db3f542
14 changed files with 240 additions and 73 deletions

View File

@@ -38,6 +38,7 @@ import (
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/net/phttp"
"go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/setutil"
"go.pinniped.dev/internal/upstreamgithub"
)
@@ -317,7 +318,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro
RedirectURL: "", // this will be different for each FederationDomain, so we do not set it here
Scopes: []string{"read:user", "read:org"},
},
AllowedOrganizations: upstream.Spec.AllowAuthentication.Organizations.Allowed,
AllowedOrganizations: setutil.NewCaseInsensitiveSet(upstream.Spec.AllowAuthentication.Organizations.Allowed...),
HttpClient: httpClient,
},
)

View File

@@ -41,6 +41,7 @@ import (
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/net/phttp"
"go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/setutil"
"go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/testutil/tlsserver"
"go.pinniped.dev/internal/upstreamgithub"
@@ -406,7 +407,7 @@ func TestController(t *testing.T) {
RedirectURL: "", // not used
Scopes: []string{"read:user", "read:org"},
},
AllowedOrganizations: []string{"organization1", "org2"},
AllowedOrganizations: setutil.NewCaseInsensitiveSet("organization1", "org2"),
HttpClient: nil, // let the test runner populate this for us
},
},
@@ -462,7 +463,8 @@ func TestController(t *testing.T) {
RedirectURL: "", // not used
Scopes: []string{"read:user", "read:org"},
},
HttpClient: nil, // let the test runner populate this for us
AllowedOrganizations: setutil.NewCaseInsensitiveSet(),
HttpClient: nil, // let the test runner populate this for us
},
},
wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{
@@ -531,7 +533,8 @@ func TestController(t *testing.T) {
RedirectURL: "", // not used
Scopes: []string{"read:user", "read:org"},
},
HttpClient: nil, // let the test runner populate this for us
AllowedOrganizations: setutil.NewCaseInsensitiveSet(),
HttpClient: nil, // let the test runner populate this for us
},
},
wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{
@@ -598,7 +601,8 @@ func TestController(t *testing.T) {
RedirectURL: "", // not used
Scopes: []string{"read:user", "read:org"},
},
HttpClient: nil, // let the test runner populate this for us
AllowedOrganizations: setutil.NewCaseInsensitiveSet(),
HttpClient: nil, // let the test runner populate this for us
},
},
wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{
@@ -685,7 +689,7 @@ func TestController(t *testing.T) {
RedirectURL: "", // not used
Scopes: []string{"read:user", "read:org"},
},
AllowedOrganizations: []string{"organization1", "org2"},
AllowedOrganizations: setutil.NewCaseInsensitiveSet("organization1", "org2"),
HttpClient: nil, // let the test runner populate this for us
},
{
@@ -706,7 +710,7 @@ func TestController(t *testing.T) {
RedirectURL: "", // not used
Scopes: []string{"read:user", "read:org"},
},
AllowedOrganizations: []string{"organization1", "org2"},
AllowedOrganizations: setutil.NewCaseInsensitiveSet("organization1", "org2"),
HttpClient: nil, // let the test runner populate this for us
},
},

View File

@@ -17,6 +17,7 @@ import (
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/psession"
"go.pinniped.dev/internal/setutil"
"go.pinniped.dev/internal/testutil/oidctestutil"
"go.pinniped.dev/internal/testutil/transformtestutil"
"go.pinniped.dev/internal/upstreamgithub"
@@ -31,7 +32,7 @@ func TestFederationDomainResolvedGitHubIdentityProvider(t *testing.T) {
APIBaseURL: "https://fake-api-host.com",
UsernameAttribute: idpv1alpha1.GitHubUsernameID,
GroupNameAttribute: idpv1alpha1.GitHubUseTeamSlugForGroupName,
AllowedOrganizations: []string{"org1", "org2"},
AllowedOrganizations: setutil.NewCaseInsensitiveSet("org1", "org2"),
HttpClient: nil, // not needed yet for this test
OAuth2Config: &oauth2.Config{
ClientID: "fake-client-id",

View File

@@ -12,6 +12,7 @@ import (
"go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
"go.pinniped.dev/internal/authenticators"
"go.pinniped.dev/internal/setutil"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce"
@@ -157,7 +158,7 @@ type UpstreamGithubIdentityProviderI interface {
// and only teams from the listed organizations should be represented as groups for the downstream token.
// If this list is empty, then any user can log in regardless of org membership, and any observable
// teams memberships should be represented as groups for the downstream token.
GetAllowedOrganizations() []string
GetAllowedOrganizations() *setutil.CaseInsensitiveSet
// GetAuthorizationURL returns the authorization URL for the configured GitHub. This will look like:
// https://<spec.githubAPI.host>/login/oauth/authorize

View File

@@ -16,6 +16,7 @@ import (
"k8s.io/apimachinery/pkg/util/sets"
"go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/setutil"
)
const (
@@ -36,8 +37,8 @@ type TeamInfo struct {
type GitHubInterface interface {
GetUserInfo(ctx context.Context) (*UserInfo, error)
GetOrgMembership(ctx context.Context) (sets.Set[string], error)
GetTeamMembership(ctx context.Context, allowedOrganizations sets.Set[string]) ([]TeamInfo, error)
GetOrgMembership(ctx context.Context) ([]string, error)
GetTeamMembership(ctx context.Context, allowedOrganizations *setutil.CaseInsensitiveSet) ([]TeamInfo, error)
}
type githubClient struct {
@@ -107,7 +108,7 @@ func (g *githubClient) GetUserInfo(ctx context.Context) (*UserInfo, error) {
}
// GetOrgMembership returns an array of the "Login" attributes for all organizations to which the authenticated user belongs.
func (g *githubClient) GetOrgMembership(ctx context.Context) (sets.Set[string], error) {
func (g *githubClient) GetOrgMembership(ctx context.Context) ([]string, error) {
const errorPrefix = "error fetching organizations for authenticated user"
organizationLogins := sets.New[string]()
@@ -135,11 +136,11 @@ func (g *githubClient) GetOrgMembership(ctx context.Context) (sets.Set[string],
}
plog.Trace("calculated response from GitHub org membership endpoint", "orgs", organizationLogins.UnsortedList())
return organizationLogins, nil
return organizationLogins.UnsortedList(), nil
}
func isOrgAllowed(allowedOrganizations sets.Set[string], login string) bool {
return len(allowedOrganizations) == 0 || allowedOrganizations.Has(login)
func isOrgAllowed(allowedOrganizations *setutil.CaseInsensitiveSet, login string) bool {
return allowedOrganizations.Empty() || allowedOrganizations.ContainsIgnoringCase(login)
}
func buildAndValidateParentTeam(githubTeam *github.Team, organizationLogin string) (*TeamInfo, error) {
@@ -176,7 +177,7 @@ func buildTeam(githubTeam *github.Team, organizationLogin string) (*TeamInfo, er
// GetTeamMembership returns a description of each team to which the authenticated user belongs.
// If allowedOrganizations is not empty, will filter the results to only those teams which belong to the allowed organizations.
// Parent teams will also be returned.
func (g *githubClient) GetTeamMembership(ctx context.Context, allowedOrganizations sets.Set[string]) ([]TeamInfo, error) {
func (g *githubClient) GetTeamMembership(ctx context.Context, allowedOrganizations *setutil.CaseInsensitiveSet) ([]TeamInfo, error) {
const errorPrefix = "error fetching team membership for authenticated user"
teamInfos := sets.New[TeamInfo]()

View File

@@ -12,10 +12,10 @@ import (
"github.com/google/go-github/v62/github"
"github.com/migueleliasweb/go-github-mock/src/mock"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/client-go/util/cert"
"go.pinniped.dev/internal/net/phttp"
"go.pinniped.dev/internal/setutil"
"go.pinniped.dev/internal/testutil/tlsserver"
)
@@ -380,8 +380,7 @@ func TestGetOrgMembership(t *testing.T) {
}
require.NotNil(t, actual)
require.Equal(t, len(actual), len(test.wantOrgs))
require.True(t, actual.HasAll(test.wantOrgs...))
require.ElementsMatch(t, test.wantOrgs, actual)
})
}
}
@@ -394,7 +393,7 @@ func TestGetTeamMembership(t *testing.T) {
httpClient *http.Client
token string
ctx context.Context
allowedOrganizations []string
allowedOrganizations *setutil.CaseInsensitiveSet
wantErr string
wantTeams []TeamInfo
}{
@@ -436,7 +435,7 @@ func TestGetTeamMembership(t *testing.T) {
),
),
token: "some-token",
allowedOrganizations: []string{"alpha", "beta"},
allowedOrganizations: setutil.NewCaseInsensitiveSet("alpha", "beta"),
wantTeams: []TeamInfo{
{
Name: "orgAlpha-team1-name",
@@ -461,7 +460,7 @@ func TestGetTeamMembership(t *testing.T) {
},
},
{
name: "filters by allowedOrganizations",
name: "filters by allowedOrganizations in a case-insensitive way, but preserves case as returned by GitHub API in the result",
httpClient: mock.NewMockedHTTPClient(
mock.WithRequestMatch(
mock.GetUserTeams,
@@ -470,38 +469,38 @@ func TestGetTeamMembership(t *testing.T) {
Name: github.String("team1-name"),
Slug: github.String("team1-slug"),
Organization: &github.Organization{
Login: github.String("alpha"),
Login: github.String("alPhA"),
},
},
{
Name: github.String("team2-name"),
Slug: github.String("team2-slug"),
Organization: &github.Organization{
Login: github.String("beta"),
Login: github.String("bEtA"),
},
},
{
Name: github.String("team3-name"),
Slug: github.String("team3-slug"),
Organization: &github.Organization{
Login: github.String("gamma"),
Login: github.String("gAmmA"),
},
},
},
),
),
token: "some-token",
allowedOrganizations: []string{"alpha", "gamma"},
allowedOrganizations: setutil.NewCaseInsensitiveSet("ALPHA", "gamma"),
wantTeams: []TeamInfo{
{
Name: "team1-name",
Slug: "team1-slug",
Org: "alpha",
Org: "alPhA",
},
{
Name: "team3-name",
Slug: "team3-slug",
Org: "gamma",
Org: "gAmmA",
},
},
},
@@ -623,11 +622,8 @@ func TestGetTeamMembership(t *testing.T) {
},
),
),
token: "some-token",
allowedOrganizations: []string{
"org-with-nested-teams",
"beta",
},
token: "some-token",
allowedOrganizations: setutil.NewCaseInsensitiveSet("org-with-nested-teams", "beta"),
wantTeams: []TeamInfo{
{
Name: "team-name-without-parent",
@@ -677,7 +673,7 @@ func TestGetTeamMembership(t *testing.T) {
),
),
token: "some-token",
allowedOrganizations: []string{"page1-org-name", "page2-org-name"},
allowedOrganizations: setutil.NewCaseInsensitiveSet("page1-org-name", "page2-org-name"),
wantTeams: []TeamInfo{
{
Name: "page1-team-name",
@@ -770,7 +766,7 @@ func TestGetTeamMembership(t *testing.T) {
),
),
token: "does-this-token-work",
allowedOrganizations: []string{"org-login"},
allowedOrganizations: setutil.NewCaseInsensitiveSet("org-login"),
wantTeams: []TeamInfo{
{
Name: "team1-name",
@@ -821,7 +817,7 @@ func TestGetTeamMembership(t *testing.T) {
ctx = test.ctx
}
actual, err := githubClient.GetTeamMembership(ctx, sets.New[string](test.allowedOrganizations...))
actual, err := githubClient.GetTeamMembership(ctx, test.allowedOrganizations)
if test.wantErr != "" {
rt, ok := test.httpClient.Transport.(*mock.EnforceHostRoundTripper)
require.True(t, ok)

View File

@@ -18,8 +18,8 @@ import (
reflect "reflect"
githubclient "go.pinniped.dev/internal/githubclient"
setutil "go.pinniped.dev/internal/setutil"
gomock "go.uber.org/mock/gomock"
sets "k8s.io/apimachinery/pkg/util/sets"
)
// MockGitHubInterface is a mock of GitHubInterface interface.
@@ -46,10 +46,10 @@ func (m *MockGitHubInterface) EXPECT() *MockGitHubInterfaceMockRecorder {
}
// GetOrgMembership mocks base method.
func (m *MockGitHubInterface) GetOrgMembership(arg0 context.Context) (sets.Set[string], error) {
func (m *MockGitHubInterface) GetOrgMembership(arg0 context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOrgMembership", arg0)
ret0, _ := ret[0].(sets.Set[string])
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -61,7 +61,7 @@ func (mr *MockGitHubInterfaceMockRecorder) GetOrgMembership(arg0 any) *gomock.Ca
}
// GetTeamMembership mocks base method.
func (m *MockGitHubInterface) GetTeamMembership(arg0 context.Context, arg1 sets.Set[string]) ([]githubclient.TeamInfo, error) {
func (m *MockGitHubInterface) GetTeamMembership(arg0 context.Context, arg1 *setutil.CaseInsensitiveSet) ([]githubclient.TeamInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTeamMembership", arg0, arg1)
ret0, _ := ret[0].([]githubclient.TeamInfo)

View File

@@ -0,0 +1,43 @@
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package setutil
import (
"strings"
"k8s.io/apimachinery/pkg/util/sets"
"go.pinniped.dev/internal/sliceutil"
)
type CaseInsensitiveSet struct {
lowercasedContents sets.Set[string]
}
func NewCaseInsensitiveSet(items ...string) *CaseInsensitiveSet {
return &CaseInsensitiveSet{
lowercasedContents: sets.New(sliceutil.Map(items, strings.ToLower)...),
}
}
func (s *CaseInsensitiveSet) HasAnyIgnoringCase(items []string) bool {
if s == nil {
return false
}
return s.lowercasedContents.HasAny(sliceutil.Map(items, strings.ToLower)...)
}
func (s *CaseInsensitiveSet) ContainsIgnoringCase(item string) bool {
if s == nil {
return false
}
return s.lowercasedContents.Has(strings.ToLower(item))
}
func (s *CaseInsensitiveSet) Empty() bool {
if s == nil {
return true
}
return s.lowercasedContents.Len() == 0
}

View File

@@ -0,0 +1,34 @@
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package setutil
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestCaseInsensitiveSet(t *testing.T) {
var nilSet *CaseInsensitiveSet
require.True(t, nilSet.Empty())
require.False(t, nilSet.HasAnyIgnoringCase([]string{"a", "b"}))
require.False(t, nilSet.HasAnyIgnoringCase(nil))
require.False(t, nilSet.ContainsIgnoringCase("a"))
require.False(t, nilSet.ContainsIgnoringCase("a"))
emptySet := NewCaseInsensitiveSet()
require.True(t, emptySet.Empty())
require.False(t, emptySet.HasAnyIgnoringCase([]string{"a", "b"}))
require.False(t, emptySet.HasAnyIgnoringCase(nil))
require.False(t, emptySet.ContainsIgnoringCase("a"))
require.False(t, emptySet.ContainsIgnoringCase("a"))
set := NewCaseInsensitiveSet("A", "B", "c")
require.False(t, set.Empty())
require.False(t, set.HasAnyIgnoringCase([]string{"x", "y"}))
require.True(t, set.HasAnyIgnoringCase([]string{"a", "x"}))
require.False(t, set.HasAnyIgnoringCase(nil))
require.False(t, set.ContainsIgnoringCase("x"))
require.True(t, set.ContainsIgnoringCase("a"))
}

View File

@@ -0,0 +1,13 @@
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package sliceutil
// Map transforms a slice from an input type I to an output type O using a transform func.
func Map[I, O any](in []I, transform func(I) O) []O {
out := make([]O, len(in))
for i := range in {
out[i] = transform(in[i])
}
return out
}

View File

@@ -0,0 +1,74 @@
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package sliceutil
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestMap(t *testing.T) {
type testCase[I any, O any] struct {
name string
in []I
transformFunc func(I) O
want []O
}
stringStringTests := []testCase[string, string]{
{
name: "downcase func",
in: []string{"Aa", "bB", "CC"},
transformFunc: strings.ToLower,
want: []string{"aa", "bb", "cc"},
},
{
name: "upcase func",
in: []string{"Aa", "bB", "CC"},
transformFunc: strings.ToUpper,
want: []string{"AA", "BB", "CC"},
},
{
name: "when in is nil, then out is an empty slice",
in: nil,
transformFunc: strings.ToUpper,
want: []string{},
},
}
for _, tt := range stringStringTests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
actual := Map(tt.in, tt.transformFunc)
require.Equal(t, tt.want, actual)
})
}
stringIntTests := []testCase[string, int]{
{
name: "len func",
in: []string{"Aa", "bBb", "CCcC"},
transformFunc: func(s string) int {
return len(s)
},
want: []int{2, 3, 4},
},
{
name: "index func",
in: []string{"Aab", "bB", "CC"},
transformFunc: func(s string) int {
return strings.Index(s, "b")
},
want: []int{2, 0, -1},
},
}
for _, tt := range stringIntTests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
actual := Map(tt.in, tt.transformFunc)
require.Equal(t, tt.want, actual)
})
}
}

View File

@@ -11,6 +11,7 @@ import (
"go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/idtransform"
"go.pinniped.dev/internal/setutil"
)
// ExchangeAuthcodeArgs is used to spy on calls to
@@ -38,7 +39,7 @@ type TestUpstreamGitHubIdentityProviderBuilder struct {
transformsForFederationDomain *idtransform.TransformationPipeline
usernameAttribute v1alpha1.GitHubUsernameAttribute
groupNameAttribute v1alpha1.GitHubGroupNameAttribute
allowedOrganizations []string
allowedOrganizations *setutil.CaseInsensitiveSet
authorizationURL string
authcodeExchangeErr error
accessToken string
@@ -81,7 +82,7 @@ func (u *TestUpstreamGitHubIdentityProviderBuilder) WithGroupNameAttribute(value
return u
}
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAllowedOrganizations(value []string) *TestUpstreamGitHubIdentityProviderBuilder {
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAllowedOrganizations(value *setutil.CaseInsensitiveSet) *TestUpstreamGitHubIdentityProviderBuilder {
u.allowedOrganizations = value
return u
}
@@ -159,7 +160,7 @@ type TestUpstreamGitHubIdentityProvider struct {
TransformsForFederationDomain *idtransform.TransformationPipeline
UsernameAttribute v1alpha1.GitHubUsernameAttribute
GroupNameAttribute v1alpha1.GitHubGroupNameAttribute
AllowedOrganizations []string
AllowedOrganizations *setutil.CaseInsensitiveSet
AuthorizationURL string
GetUserFunc func(ctx context.Context, accessToken string) (*upstreamprovider.GitHubUser, error)
ExchangeAuthcodeFunc func(ctx context.Context, authcode string) (string, error)
@@ -197,7 +198,7 @@ func (u *TestUpstreamGitHubIdentityProvider) GetGroupNameAttribute() v1alpha1.Gi
return u.GroupNameAttribute
}
func (u *TestUpstreamGitHubIdentityProvider) GetAllowedOrganizations() []string {
func (u *TestUpstreamGitHubIdentityProvider) GetAllowedOrganizations() *setutil.CaseInsensitiveSet {
return u.AllowedOrganizations
}

View File

@@ -13,12 +13,12 @@ import (
coreosoidc "github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets"
supervisoridpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
"go.pinniped.dev/internal/federationdomain/downstreamsubject"
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/githubclient"
"go.pinniped.dev/internal/setutil"
)
// ProviderConfig holds the active configuration of an upstream GitHub provider.
@@ -35,7 +35,7 @@ type ProviderConfig struct {
GroupNameAttribute supervisoridpv1alpha1.GitHubGroupNameAttribute
// AllowedOrganizations, when empty, means to allow users from all orgs.
AllowedOrganizations []string
AllowedOrganizations *setutil.CaseInsensitiveSet
// HttpClient is a client that can be used to call the GitHub APIs and token endpoint.
// This client should be configured with the user-provided CA bundle and a timeout.
@@ -90,7 +90,7 @@ func (p *Provider) GetGroupNameAttribute() supervisoridpv1alpha1.GitHubGroupName
return p.c.GroupNameAttribute
}
func (p *Provider) GetAllowedOrganizations() []string {
func (p *Provider) GetAllowedOrganizations() *setutil.CaseInsensitiveSet {
return p.c.AllowedOrganizations
}
@@ -146,13 +146,11 @@ func (p *Provider) GetUser(ctx context.Context, accessToken string, idpDisplayNa
return nil, err
}
allowedOrgs := sets.New[string](p.c.AllowedOrganizations...)
if allowedOrgs.Len() > 0 && allowedOrgs.Intersection(orgMembership).Len() < 1 {
if !p.c.AllowedOrganizations.Empty() && !p.c.AllowedOrganizations.HasAnyIgnoringCase(orgMembership) {
return nil, errors.New("user is not allowed to log in due to organization membership policy")
}
teamMembership, err := githubClient.GetTeamMembership(ctx, allowedOrgs)
teamMembership, err := githubClient.GetTeamMembership(ctx, p.c.AllowedOrganizations)
if err != nil {
return nil, err
}

View File

@@ -17,13 +17,13 @@ import (
"golang.org/x/oauth2"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/rand"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/client-go/util/cert"
supervisoridpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/githubclient"
"go.pinniped.dev/internal/mocks/mockgithubclient"
"go.pinniped.dev/internal/setutil"
"go.pinniped.dev/internal/testutil/tlsserver"
)
@@ -45,7 +45,7 @@ func TestGitHubProvider(t *testing.T) {
AuthStyle: oauth2.AuthStyleInParams,
},
},
AllowedOrganizations: []string{"fake-org", "fake-org2"},
AllowedOrganizations: setutil.NewCaseInsensitiveSet("fake-org", "fake-org2"),
HttpClient: &http.Client{
Timeout: 1234509,
},
@@ -68,7 +68,7 @@ func TestGitHubProvider(t *testing.T) {
AuthStyle: oauth2.AuthStyleInParams,
},
},
AllowedOrganizations: []string{"fake-org", "fake-org2"},
AllowedOrganizations: setutil.NewCaseInsensitiveSet("fake-org", "fake-org2"),
HttpClient: &http.Client{
Timeout: 1234509,
},
@@ -80,7 +80,7 @@ func TestGitHubProvider(t *testing.T) {
require.Equal(t, "fake-client-id", subject.GetClientID())
require.Equal(t, supervisoridpv1alpha1.GitHubUsernameAttribute("fake-username-attribute"), subject.GetUsernameAttribute())
require.Equal(t, supervisoridpv1alpha1.GitHubGroupNameAttribute("fake-group-name-attribute"), subject.GetGroupNameAttribute())
require.Equal(t, []string{"fake-org", "fake-org2"}, subject.GetAllowedOrganizations())
require.Equal(t, setutil.NewCaseInsensitiveSet("fake-org", "fake-org2"), subject.GetAllowedOrganizations())
require.Equal(t, "https://fake-authorization-url", subject.GetAuthorizationURL())
require.Equal(t, &http.Client{
Timeout: 1234509,
@@ -193,9 +193,6 @@ func TestGetUser(t *testing.T) {
const idpDisplayName = "idp display name 😀"
const encodedIDPDisplayName = "idp+display+name+%F0%9F%98%80"
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
someContext := context.Background()
someHttpClient := &http.Client{
@@ -223,7 +220,7 @@ func TestGetUser(t *testing.T) {
ID: "some-github-id",
}, nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return(nil, nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, gomock.Any()).Return(nil, nil)
},
wantUser: &upstreamprovider.GitHubUser{
Username: "some-github-login:some-github-id",
@@ -243,7 +240,7 @@ func TestGetUser(t *testing.T) {
ID: "some-github-id",
}, nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return(nil, nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, nil).Return(nil, nil)
},
wantUser: &upstreamprovider.GitHubUser{
Username: "some-github-login",
@@ -263,7 +260,7 @@ func TestGetUser(t *testing.T) {
ID: "some-github-id",
}, nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return(nil, nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, nil).Return(nil, nil)
},
wantUser: &upstreamprovider.GitHubUser{
Username: "some-github-id",
@@ -276,15 +273,15 @@ func TestGetUser(t *testing.T) {
APIBaseURL: "https://some-url",
HttpClient: someHttpClient,
UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameLoginAndID,
AllowedOrganizations: []string{"allowed-org1", "allowed-org2"},
AllowedOrganizations: setutil.NewCaseInsensitiveSet("ALLOWED-ORG1", "ALLOWED-ORG2"),
},
buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) {
mockGitHubInterface.EXPECT().GetUserInfo(someContext).Return(&githubclient.UserInfo{
Login: "some-github-login",
ID: "some-github-id",
}, nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("allowed-org2"), nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]("allowed-org1", "allowed-org2")).Return(nil, nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"allowed-org2"}, nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, setutil.NewCaseInsensitiveSet("ALLOWED-ORG1", "ALLOWED-ORG2")).Return(nil, nil)
},
wantUser: &upstreamprovider.GitHubUser{
Username: "some-github-login:some-github-id",
@@ -297,14 +294,14 @@ func TestGetUser(t *testing.T) {
APIBaseURL: "https://some-url",
HttpClient: someHttpClient,
UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameID,
AllowedOrganizations: []string{"allowed-org"},
AllowedOrganizations: setutil.NewCaseInsensitiveSet("allowed-org"),
},
buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) {
mockGitHubInterface.EXPECT().GetUserInfo(someContext).Return(&githubclient.UserInfo{
Login: "some-github-login",
ID: "some-github-id",
}, nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("disallowed-org"), nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"disallowed-org"}, nil)
},
wantErr: "user is not allowed to log in due to organization membership policy",
},
@@ -314,7 +311,7 @@ func TestGetUser(t *testing.T) {
APIBaseURL: "https://some-url",
HttpClient: someHttpClient,
UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameLoginAndID,
AllowedOrganizations: []string{"allowed-org1", "allowed-org2"},
AllowedOrganizations: setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2"),
GroupNameAttribute: supervisoridpv1alpha1.GitHubUseTeamNameForGroupName,
},
buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) {
@@ -322,8 +319,8 @@ func TestGetUser(t *testing.T) {
Login: "some-github-login",
ID: "some-github-id",
}, nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("allowed-org2"), nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"allowed-org2"}, nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{
{
Name: "org1-team1-name",
Slug: "org1-team1-slug",
@@ -353,7 +350,7 @@ func TestGetUser(t *testing.T) {
APIBaseURL: "https://some-url",
HttpClient: someHttpClient,
UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameLoginAndID,
AllowedOrganizations: []string{"allowed-org1", "allowed-org2"},
AllowedOrganizations: setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2"),
GroupNameAttribute: supervisoridpv1alpha1.GitHubUseTeamSlugForGroupName,
},
buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) {
@@ -361,8 +358,8 @@ func TestGetUser(t *testing.T) {
Login: "some-github-login",
ID: "some-github-id",
}, nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("allowed-org2"), nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"allowed-org2"}, nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{
{
Name: "org1-team1-name",
Slug: "org1-team1-slug",
@@ -462,7 +459,7 @@ func TestGetUser(t *testing.T) {
ID: "some-github-id",
}, nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return([]githubclient.TeamInfo{
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, nil).Return([]githubclient.TeamInfo{
{
Name: "org1-team1-name",
Slug: "org1-team1-slug",
@@ -477,6 +474,9 @@ func TestGetUser(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
accessToken := "some-opaque-github-access-token" + rand.String(8)
mockGitHubInterface := mockgithubclient.NewMockGitHubInterface(ctrl)
if test.buildMockResponses != nil {