From 8f8db3f54278bf54d6bccfc6ed54fde0f2797a84 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Tue, 21 May 2024 11:57:55 -0700 Subject: [PATCH] Make github org comparison case-insensitive, but return original case Co-authored-by: Joshua Casey --- .../github_upstream_watcher.go | 3 +- .../github_upstream_watcher_test.go | 16 ++-- .../resolved_github_provider_test.go | 3 +- .../upstreamprovider/upsteam_provider.go | 3 +- internal/githubclient/githubclient.go | 15 ++-- internal/githubclient/githubclient_test.go | 36 ++++----- .../mockgithubclient/mockgithubclient.go | 8 +- internal/setutil/setutil.go | 43 +++++++++++ internal/setutil/setutil_test.go | 34 +++++++++ internal/sliceutil/sliceutil.go | 13 ++++ internal/sliceutil/sliceutil_test.go | 74 +++++++++++++++++++ .../oidctestutil/testgithubprovider.go | 9 ++- internal/upstreamgithub/upstreamgithub.go | 12 ++- .../upstreamgithub/upstreamgithub_test.go | 44 +++++------ 14 files changed, 240 insertions(+), 73 deletions(-) create mode 100644 internal/setutil/setutil.go create mode 100644 internal/setutil/setutil_test.go create mode 100644 internal/sliceutil/sliceutil.go create mode 100644 internal/sliceutil/sliceutil_test.go diff --git a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go index c6f65ff98..fc0c98d83 100644 --- a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go +++ b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go @@ -38,6 +38,7 @@ import ( "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/net/phttp" "go.pinniped.dev/internal/plog" + "go.pinniped.dev/internal/setutil" "go.pinniped.dev/internal/upstreamgithub" ) @@ -317,7 +318,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro RedirectURL: "", // this will be different for each FederationDomain, so we do not set it here Scopes: []string{"read:user", "read:org"}, }, - AllowedOrganizations: upstream.Spec.AllowAuthentication.Organizations.Allowed, + AllowedOrganizations: setutil.NewCaseInsensitiveSet(upstream.Spec.AllowAuthentication.Organizations.Allowed...), HttpClient: httpClient, }, ) diff --git a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go index 162b0fd7d..7be242074 100644 --- a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go +++ b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go @@ -41,6 +41,7 @@ import ( "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/net/phttp" "go.pinniped.dev/internal/plog" + "go.pinniped.dev/internal/setutil" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/tlsserver" "go.pinniped.dev/internal/upstreamgithub" @@ -406,7 +407,7 @@ func TestController(t *testing.T) { RedirectURL: "", // not used Scopes: []string{"read:user", "read:org"}, }, - AllowedOrganizations: []string{"organization1", "org2"}, + AllowedOrganizations: setutil.NewCaseInsensitiveSet("organization1", "org2"), HttpClient: nil, // let the test runner populate this for us }, }, @@ -462,7 +463,8 @@ func TestController(t *testing.T) { RedirectURL: "", // not used Scopes: []string{"read:user", "read:org"}, }, - HttpClient: nil, // let the test runner populate this for us + AllowedOrganizations: setutil.NewCaseInsensitiveSet(), + HttpClient: nil, // let the test runner populate this for us }, }, wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{ @@ -531,7 +533,8 @@ func TestController(t *testing.T) { RedirectURL: "", // not used Scopes: []string{"read:user", "read:org"}, }, - HttpClient: nil, // let the test runner populate this for us + AllowedOrganizations: setutil.NewCaseInsensitiveSet(), + HttpClient: nil, // let the test runner populate this for us }, }, wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{ @@ -598,7 +601,8 @@ func TestController(t *testing.T) { RedirectURL: "", // not used Scopes: []string{"read:user", "read:org"}, }, - HttpClient: nil, // let the test runner populate this for us + AllowedOrganizations: setutil.NewCaseInsensitiveSet(), + HttpClient: nil, // let the test runner populate this for us }, }, wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{ @@ -685,7 +689,7 @@ func TestController(t *testing.T) { RedirectURL: "", // not used Scopes: []string{"read:user", "read:org"}, }, - AllowedOrganizations: []string{"organization1", "org2"}, + AllowedOrganizations: setutil.NewCaseInsensitiveSet("organization1", "org2"), HttpClient: nil, // let the test runner populate this for us }, { @@ -706,7 +710,7 @@ func TestController(t *testing.T) { RedirectURL: "", // not used Scopes: []string{"read:user", "read:org"}, }, - AllowedOrganizations: []string{"organization1", "org2"}, + AllowedOrganizations: setutil.NewCaseInsensitiveSet("organization1", "org2"), HttpClient: nil, // let the test runner populate this for us }, }, diff --git a/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go b/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go index 7941c104f..5a2b63b6e 100644 --- a/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go +++ b/internal/federationdomain/resolvedprovider/resolvedgithub/resolved_github_provider_test.go @@ -17,6 +17,7 @@ import ( "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/psession" + "go.pinniped.dev/internal/setutil" "go.pinniped.dev/internal/testutil/oidctestutil" "go.pinniped.dev/internal/testutil/transformtestutil" "go.pinniped.dev/internal/upstreamgithub" @@ -31,7 +32,7 @@ func TestFederationDomainResolvedGitHubIdentityProvider(t *testing.T) { APIBaseURL: "https://fake-api-host.com", UsernameAttribute: idpv1alpha1.GitHubUsernameID, GroupNameAttribute: idpv1alpha1.GitHubUseTeamSlugForGroupName, - AllowedOrganizations: []string{"org1", "org2"}, + AllowedOrganizations: setutil.NewCaseInsensitiveSet("org1", "org2"), HttpClient: nil, // not needed yet for this test OAuth2Config: &oauth2.Config{ ClientID: "fake-client-id", diff --git a/internal/federationdomain/upstreamprovider/upsteam_provider.go b/internal/federationdomain/upstreamprovider/upsteam_provider.go index c05785417..faf6af015 100644 --- a/internal/federationdomain/upstreamprovider/upsteam_provider.go +++ b/internal/federationdomain/upstreamprovider/upsteam_provider.go @@ -12,6 +12,7 @@ import ( "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1" "go.pinniped.dev/internal/authenticators" + "go.pinniped.dev/internal/setutil" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" @@ -157,7 +158,7 @@ type UpstreamGithubIdentityProviderI interface { // and only teams from the listed organizations should be represented as groups for the downstream token. // If this list is empty, then any user can log in regardless of org membership, and any observable // teams memberships should be represented as groups for the downstream token. - GetAllowedOrganizations() []string + GetAllowedOrganizations() *setutil.CaseInsensitiveSet // GetAuthorizationURL returns the authorization URL for the configured GitHub. This will look like: // https:///login/oauth/authorize diff --git a/internal/githubclient/githubclient.go b/internal/githubclient/githubclient.go index c1889dcaa..36a7099e0 100644 --- a/internal/githubclient/githubclient.go +++ b/internal/githubclient/githubclient.go @@ -16,6 +16,7 @@ import ( "k8s.io/apimachinery/pkg/util/sets" "go.pinniped.dev/internal/plog" + "go.pinniped.dev/internal/setutil" ) const ( @@ -36,8 +37,8 @@ type TeamInfo struct { type GitHubInterface interface { GetUserInfo(ctx context.Context) (*UserInfo, error) - GetOrgMembership(ctx context.Context) (sets.Set[string], error) - GetTeamMembership(ctx context.Context, allowedOrganizations sets.Set[string]) ([]TeamInfo, error) + GetOrgMembership(ctx context.Context) ([]string, error) + GetTeamMembership(ctx context.Context, allowedOrganizations *setutil.CaseInsensitiveSet) ([]TeamInfo, error) } type githubClient struct { @@ -107,7 +108,7 @@ func (g *githubClient) GetUserInfo(ctx context.Context) (*UserInfo, error) { } // GetOrgMembership returns an array of the "Login" attributes for all organizations to which the authenticated user belongs. -func (g *githubClient) GetOrgMembership(ctx context.Context) (sets.Set[string], error) { +func (g *githubClient) GetOrgMembership(ctx context.Context) ([]string, error) { const errorPrefix = "error fetching organizations for authenticated user" organizationLogins := sets.New[string]() @@ -135,11 +136,11 @@ func (g *githubClient) GetOrgMembership(ctx context.Context) (sets.Set[string], } plog.Trace("calculated response from GitHub org membership endpoint", "orgs", organizationLogins.UnsortedList()) - return organizationLogins, nil + return organizationLogins.UnsortedList(), nil } -func isOrgAllowed(allowedOrganizations sets.Set[string], login string) bool { - return len(allowedOrganizations) == 0 || allowedOrganizations.Has(login) +func isOrgAllowed(allowedOrganizations *setutil.CaseInsensitiveSet, login string) bool { + return allowedOrganizations.Empty() || allowedOrganizations.ContainsIgnoringCase(login) } func buildAndValidateParentTeam(githubTeam *github.Team, organizationLogin string) (*TeamInfo, error) { @@ -176,7 +177,7 @@ func buildTeam(githubTeam *github.Team, organizationLogin string) (*TeamInfo, er // GetTeamMembership returns a description of each team to which the authenticated user belongs. // If allowedOrganizations is not empty, will filter the results to only those teams which belong to the allowed organizations. // Parent teams will also be returned. -func (g *githubClient) GetTeamMembership(ctx context.Context, allowedOrganizations sets.Set[string]) ([]TeamInfo, error) { +func (g *githubClient) GetTeamMembership(ctx context.Context, allowedOrganizations *setutil.CaseInsensitiveSet) ([]TeamInfo, error) { const errorPrefix = "error fetching team membership for authenticated user" teamInfos := sets.New[TeamInfo]() diff --git a/internal/githubclient/githubclient_test.go b/internal/githubclient/githubclient_test.go index 0ea8e7523..ae3b00eba 100644 --- a/internal/githubclient/githubclient_test.go +++ b/internal/githubclient/githubclient_test.go @@ -12,10 +12,10 @@ import ( "github.com/google/go-github/v62/github" "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/require" - "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/util/cert" "go.pinniped.dev/internal/net/phttp" + "go.pinniped.dev/internal/setutil" "go.pinniped.dev/internal/testutil/tlsserver" ) @@ -380,8 +380,7 @@ func TestGetOrgMembership(t *testing.T) { } require.NotNil(t, actual) - require.Equal(t, len(actual), len(test.wantOrgs)) - require.True(t, actual.HasAll(test.wantOrgs...)) + require.ElementsMatch(t, test.wantOrgs, actual) }) } } @@ -394,7 +393,7 @@ func TestGetTeamMembership(t *testing.T) { httpClient *http.Client token string ctx context.Context - allowedOrganizations []string + allowedOrganizations *setutil.CaseInsensitiveSet wantErr string wantTeams []TeamInfo }{ @@ -436,7 +435,7 @@ func TestGetTeamMembership(t *testing.T) { ), ), token: "some-token", - allowedOrganizations: []string{"alpha", "beta"}, + allowedOrganizations: setutil.NewCaseInsensitiveSet("alpha", "beta"), wantTeams: []TeamInfo{ { Name: "orgAlpha-team1-name", @@ -461,7 +460,7 @@ func TestGetTeamMembership(t *testing.T) { }, }, { - name: "filters by allowedOrganizations", + name: "filters by allowedOrganizations in a case-insensitive way, but preserves case as returned by GitHub API in the result", httpClient: mock.NewMockedHTTPClient( mock.WithRequestMatch( mock.GetUserTeams, @@ -470,38 +469,38 @@ func TestGetTeamMembership(t *testing.T) { Name: github.String("team1-name"), Slug: github.String("team1-slug"), Organization: &github.Organization{ - Login: github.String("alpha"), + Login: github.String("alPhA"), }, }, { Name: github.String("team2-name"), Slug: github.String("team2-slug"), Organization: &github.Organization{ - Login: github.String("beta"), + Login: github.String("bEtA"), }, }, { Name: github.String("team3-name"), Slug: github.String("team3-slug"), Organization: &github.Organization{ - Login: github.String("gamma"), + Login: github.String("gAmmA"), }, }, }, ), ), token: "some-token", - allowedOrganizations: []string{"alpha", "gamma"}, + allowedOrganizations: setutil.NewCaseInsensitiveSet("ALPHA", "gamma"), wantTeams: []TeamInfo{ { Name: "team1-name", Slug: "team1-slug", - Org: "alpha", + Org: "alPhA", }, { Name: "team3-name", Slug: "team3-slug", - Org: "gamma", + Org: "gAmmA", }, }, }, @@ -623,11 +622,8 @@ func TestGetTeamMembership(t *testing.T) { }, ), ), - token: "some-token", - allowedOrganizations: []string{ - "org-with-nested-teams", - "beta", - }, + token: "some-token", + allowedOrganizations: setutil.NewCaseInsensitiveSet("org-with-nested-teams", "beta"), wantTeams: []TeamInfo{ { Name: "team-name-without-parent", @@ -677,7 +673,7 @@ func TestGetTeamMembership(t *testing.T) { ), ), token: "some-token", - allowedOrganizations: []string{"page1-org-name", "page2-org-name"}, + allowedOrganizations: setutil.NewCaseInsensitiveSet("page1-org-name", "page2-org-name"), wantTeams: []TeamInfo{ { Name: "page1-team-name", @@ -770,7 +766,7 @@ func TestGetTeamMembership(t *testing.T) { ), ), token: "does-this-token-work", - allowedOrganizations: []string{"org-login"}, + allowedOrganizations: setutil.NewCaseInsensitiveSet("org-login"), wantTeams: []TeamInfo{ { Name: "team1-name", @@ -821,7 +817,7 @@ func TestGetTeamMembership(t *testing.T) { ctx = test.ctx } - actual, err := githubClient.GetTeamMembership(ctx, sets.New[string](test.allowedOrganizations...)) + actual, err := githubClient.GetTeamMembership(ctx, test.allowedOrganizations) if test.wantErr != "" { rt, ok := test.httpClient.Transport.(*mock.EnforceHostRoundTripper) require.True(t, ok) diff --git a/internal/mocks/mockgithubclient/mockgithubclient.go b/internal/mocks/mockgithubclient/mockgithubclient.go index 988031522..d259daadf 100644 --- a/internal/mocks/mockgithubclient/mockgithubclient.go +++ b/internal/mocks/mockgithubclient/mockgithubclient.go @@ -18,8 +18,8 @@ import ( reflect "reflect" githubclient "go.pinniped.dev/internal/githubclient" + setutil "go.pinniped.dev/internal/setutil" gomock "go.uber.org/mock/gomock" - sets "k8s.io/apimachinery/pkg/util/sets" ) // MockGitHubInterface is a mock of GitHubInterface interface. @@ -46,10 +46,10 @@ func (m *MockGitHubInterface) EXPECT() *MockGitHubInterfaceMockRecorder { } // GetOrgMembership mocks base method. -func (m *MockGitHubInterface) GetOrgMembership(arg0 context.Context) (sets.Set[string], error) { +func (m *MockGitHubInterface) GetOrgMembership(arg0 context.Context) ([]string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetOrgMembership", arg0) - ret0, _ := ret[0].(sets.Set[string]) + ret0, _ := ret[0].([]string) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -61,7 +61,7 @@ func (mr *MockGitHubInterfaceMockRecorder) GetOrgMembership(arg0 any) *gomock.Ca } // GetTeamMembership mocks base method. -func (m *MockGitHubInterface) GetTeamMembership(arg0 context.Context, arg1 sets.Set[string]) ([]githubclient.TeamInfo, error) { +func (m *MockGitHubInterface) GetTeamMembership(arg0 context.Context, arg1 *setutil.CaseInsensitiveSet) ([]githubclient.TeamInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetTeamMembership", arg0, arg1) ret0, _ := ret[0].([]githubclient.TeamInfo) diff --git a/internal/setutil/setutil.go b/internal/setutil/setutil.go new file mode 100644 index 000000000..714686c8d --- /dev/null +++ b/internal/setutil/setutil.go @@ -0,0 +1,43 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package setutil + +import ( + "strings" + + "k8s.io/apimachinery/pkg/util/sets" + + "go.pinniped.dev/internal/sliceutil" +) + +type CaseInsensitiveSet struct { + lowercasedContents sets.Set[string] +} + +func NewCaseInsensitiveSet(items ...string) *CaseInsensitiveSet { + return &CaseInsensitiveSet{ + lowercasedContents: sets.New(sliceutil.Map(items, strings.ToLower)...), + } +} + +func (s *CaseInsensitiveSet) HasAnyIgnoringCase(items []string) bool { + if s == nil { + return false + } + return s.lowercasedContents.HasAny(sliceutil.Map(items, strings.ToLower)...) +} + +func (s *CaseInsensitiveSet) ContainsIgnoringCase(item string) bool { + if s == nil { + return false + } + return s.lowercasedContents.Has(strings.ToLower(item)) +} + +func (s *CaseInsensitiveSet) Empty() bool { + if s == nil { + return true + } + return s.lowercasedContents.Len() == 0 +} diff --git a/internal/setutil/setutil_test.go b/internal/setutil/setutil_test.go new file mode 100644 index 000000000..0fb61ede8 --- /dev/null +++ b/internal/setutil/setutil_test.go @@ -0,0 +1,34 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package setutil + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCaseInsensitiveSet(t *testing.T) { + var nilSet *CaseInsensitiveSet + require.True(t, nilSet.Empty()) + require.False(t, nilSet.HasAnyIgnoringCase([]string{"a", "b"})) + require.False(t, nilSet.HasAnyIgnoringCase(nil)) + require.False(t, nilSet.ContainsIgnoringCase("a")) + require.False(t, nilSet.ContainsIgnoringCase("a")) + + emptySet := NewCaseInsensitiveSet() + require.True(t, emptySet.Empty()) + require.False(t, emptySet.HasAnyIgnoringCase([]string{"a", "b"})) + require.False(t, emptySet.HasAnyIgnoringCase(nil)) + require.False(t, emptySet.ContainsIgnoringCase("a")) + require.False(t, emptySet.ContainsIgnoringCase("a")) + + set := NewCaseInsensitiveSet("A", "B", "c") + require.False(t, set.Empty()) + require.False(t, set.HasAnyIgnoringCase([]string{"x", "y"})) + require.True(t, set.HasAnyIgnoringCase([]string{"a", "x"})) + require.False(t, set.HasAnyIgnoringCase(nil)) + require.False(t, set.ContainsIgnoringCase("x")) + require.True(t, set.ContainsIgnoringCase("a")) +} diff --git a/internal/sliceutil/sliceutil.go b/internal/sliceutil/sliceutil.go new file mode 100644 index 000000000..af243ab4a --- /dev/null +++ b/internal/sliceutil/sliceutil.go @@ -0,0 +1,13 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package sliceutil + +// Map transforms a slice from an input type I to an output type O using a transform func. +func Map[I, O any](in []I, transform func(I) O) []O { + out := make([]O, len(in)) + for i := range in { + out[i] = transform(in[i]) + } + return out +} diff --git a/internal/sliceutil/sliceutil_test.go b/internal/sliceutil/sliceutil_test.go new file mode 100644 index 000000000..389603a12 --- /dev/null +++ b/internal/sliceutil/sliceutil_test.go @@ -0,0 +1,74 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package sliceutil + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMap(t *testing.T) { + type testCase[I any, O any] struct { + name string + in []I + transformFunc func(I) O + want []O + } + + stringStringTests := []testCase[string, string]{ + { + name: "downcase func", + in: []string{"Aa", "bB", "CC"}, + transformFunc: strings.ToLower, + want: []string{"aa", "bb", "cc"}, + }, + { + name: "upcase func", + in: []string{"Aa", "bB", "CC"}, + transformFunc: strings.ToUpper, + want: []string{"AA", "BB", "CC"}, + }, + { + name: "when in is nil, then out is an empty slice", + in: nil, + transformFunc: strings.ToUpper, + want: []string{}, + }, + } + for _, tt := range stringStringTests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + actual := Map(tt.in, tt.transformFunc) + require.Equal(t, tt.want, actual) + }) + } + + stringIntTests := []testCase[string, int]{ + { + name: "len func", + in: []string{"Aa", "bBb", "CCcC"}, + transformFunc: func(s string) int { + return len(s) + }, + want: []int{2, 3, 4}, + }, + { + name: "index func", + in: []string{"Aab", "bB", "CC"}, + transformFunc: func(s string) int { + return strings.Index(s, "b") + }, + want: []int{2, 0, -1}, + }, + } + for _, tt := range stringIntTests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + actual := Map(tt.in, tt.transformFunc) + require.Equal(t, tt.want, actual) + }) + } +} diff --git a/internal/testutil/oidctestutil/testgithubprovider.go b/internal/testutil/oidctestutil/testgithubprovider.go index 1eea5896c..aa210e448 100644 --- a/internal/testutil/oidctestutil/testgithubprovider.go +++ b/internal/testutil/oidctestutil/testgithubprovider.go @@ -11,6 +11,7 @@ import ( "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1" "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/idtransform" + "go.pinniped.dev/internal/setutil" ) // ExchangeAuthcodeArgs is used to spy on calls to @@ -38,7 +39,7 @@ type TestUpstreamGitHubIdentityProviderBuilder struct { transformsForFederationDomain *idtransform.TransformationPipeline usernameAttribute v1alpha1.GitHubUsernameAttribute groupNameAttribute v1alpha1.GitHubGroupNameAttribute - allowedOrganizations []string + allowedOrganizations *setutil.CaseInsensitiveSet authorizationURL string authcodeExchangeErr error accessToken string @@ -81,7 +82,7 @@ func (u *TestUpstreamGitHubIdentityProviderBuilder) WithGroupNameAttribute(value return u } -func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAllowedOrganizations(value []string) *TestUpstreamGitHubIdentityProviderBuilder { +func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAllowedOrganizations(value *setutil.CaseInsensitiveSet) *TestUpstreamGitHubIdentityProviderBuilder { u.allowedOrganizations = value return u } @@ -159,7 +160,7 @@ type TestUpstreamGitHubIdentityProvider struct { TransformsForFederationDomain *idtransform.TransformationPipeline UsernameAttribute v1alpha1.GitHubUsernameAttribute GroupNameAttribute v1alpha1.GitHubGroupNameAttribute - AllowedOrganizations []string + AllowedOrganizations *setutil.CaseInsensitiveSet AuthorizationURL string GetUserFunc func(ctx context.Context, accessToken string) (*upstreamprovider.GitHubUser, error) ExchangeAuthcodeFunc func(ctx context.Context, authcode string) (string, error) @@ -197,7 +198,7 @@ func (u *TestUpstreamGitHubIdentityProvider) GetGroupNameAttribute() v1alpha1.Gi return u.GroupNameAttribute } -func (u *TestUpstreamGitHubIdentityProvider) GetAllowedOrganizations() []string { +func (u *TestUpstreamGitHubIdentityProvider) GetAllowedOrganizations() *setutil.CaseInsensitiveSet { return u.AllowedOrganizations } diff --git a/internal/upstreamgithub/upstreamgithub.go b/internal/upstreamgithub/upstreamgithub.go index 7321ac35e..d1085e466 100644 --- a/internal/upstreamgithub/upstreamgithub.go +++ b/internal/upstreamgithub/upstreamgithub.go @@ -13,12 +13,12 @@ import ( coreosoidc "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" "k8s.io/apimachinery/pkg/types" - "k8s.io/apimachinery/pkg/util/sets" supervisoridpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1" "go.pinniped.dev/internal/federationdomain/downstreamsubject" "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/githubclient" + "go.pinniped.dev/internal/setutil" ) // ProviderConfig holds the active configuration of an upstream GitHub provider. @@ -35,7 +35,7 @@ type ProviderConfig struct { GroupNameAttribute supervisoridpv1alpha1.GitHubGroupNameAttribute // AllowedOrganizations, when empty, means to allow users from all orgs. - AllowedOrganizations []string + AllowedOrganizations *setutil.CaseInsensitiveSet // HttpClient is a client that can be used to call the GitHub APIs and token endpoint. // This client should be configured with the user-provided CA bundle and a timeout. @@ -90,7 +90,7 @@ func (p *Provider) GetGroupNameAttribute() supervisoridpv1alpha1.GitHubGroupName return p.c.GroupNameAttribute } -func (p *Provider) GetAllowedOrganizations() []string { +func (p *Provider) GetAllowedOrganizations() *setutil.CaseInsensitiveSet { return p.c.AllowedOrganizations } @@ -146,13 +146,11 @@ func (p *Provider) GetUser(ctx context.Context, accessToken string, idpDisplayNa return nil, err } - allowedOrgs := sets.New[string](p.c.AllowedOrganizations...) - - if allowedOrgs.Len() > 0 && allowedOrgs.Intersection(orgMembership).Len() < 1 { + if !p.c.AllowedOrganizations.Empty() && !p.c.AllowedOrganizations.HasAnyIgnoringCase(orgMembership) { return nil, errors.New("user is not allowed to log in due to organization membership policy") } - teamMembership, err := githubClient.GetTeamMembership(ctx, allowedOrgs) + teamMembership, err := githubClient.GetTeamMembership(ctx, p.c.AllowedOrganizations) if err != nil { return nil, err } diff --git a/internal/upstreamgithub/upstreamgithub_test.go b/internal/upstreamgithub/upstreamgithub_test.go index 5b8ae3c6d..f9de917d3 100644 --- a/internal/upstreamgithub/upstreamgithub_test.go +++ b/internal/upstreamgithub/upstreamgithub_test.go @@ -17,13 +17,13 @@ import ( "golang.org/x/oauth2" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/rand" - "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/util/cert" supervisoridpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1" "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/githubclient" "go.pinniped.dev/internal/mocks/mockgithubclient" + "go.pinniped.dev/internal/setutil" "go.pinniped.dev/internal/testutil/tlsserver" ) @@ -45,7 +45,7 @@ func TestGitHubProvider(t *testing.T) { AuthStyle: oauth2.AuthStyleInParams, }, }, - AllowedOrganizations: []string{"fake-org", "fake-org2"}, + AllowedOrganizations: setutil.NewCaseInsensitiveSet("fake-org", "fake-org2"), HttpClient: &http.Client{ Timeout: 1234509, }, @@ -68,7 +68,7 @@ func TestGitHubProvider(t *testing.T) { AuthStyle: oauth2.AuthStyleInParams, }, }, - AllowedOrganizations: []string{"fake-org", "fake-org2"}, + AllowedOrganizations: setutil.NewCaseInsensitiveSet("fake-org", "fake-org2"), HttpClient: &http.Client{ Timeout: 1234509, }, @@ -80,7 +80,7 @@ func TestGitHubProvider(t *testing.T) { require.Equal(t, "fake-client-id", subject.GetClientID()) require.Equal(t, supervisoridpv1alpha1.GitHubUsernameAttribute("fake-username-attribute"), subject.GetUsernameAttribute()) require.Equal(t, supervisoridpv1alpha1.GitHubGroupNameAttribute("fake-group-name-attribute"), subject.GetGroupNameAttribute()) - require.Equal(t, []string{"fake-org", "fake-org2"}, subject.GetAllowedOrganizations()) + require.Equal(t, setutil.NewCaseInsensitiveSet("fake-org", "fake-org2"), subject.GetAllowedOrganizations()) require.Equal(t, "https://fake-authorization-url", subject.GetAuthorizationURL()) require.Equal(t, &http.Client{ Timeout: 1234509, @@ -193,9 +193,6 @@ func TestGetUser(t *testing.T) { const idpDisplayName = "idp display name 😀" const encodedIDPDisplayName = "idp+display+name+%F0%9F%98%80" - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - someContext := context.Background() someHttpClient := &http.Client{ @@ -223,7 +220,7 @@ func TestGetUser(t *testing.T) { ID: "some-github-id", }, nil) mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil) - mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return(nil, nil) + mockGitHubInterface.EXPECT().GetTeamMembership(someContext, gomock.Any()).Return(nil, nil) }, wantUser: &upstreamprovider.GitHubUser{ Username: "some-github-login:some-github-id", @@ -243,7 +240,7 @@ func TestGetUser(t *testing.T) { ID: "some-github-id", }, nil) mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil) - mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return(nil, nil) + mockGitHubInterface.EXPECT().GetTeamMembership(someContext, nil).Return(nil, nil) }, wantUser: &upstreamprovider.GitHubUser{ Username: "some-github-login", @@ -263,7 +260,7 @@ func TestGetUser(t *testing.T) { ID: "some-github-id", }, nil) mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil) - mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return(nil, nil) + mockGitHubInterface.EXPECT().GetTeamMembership(someContext, nil).Return(nil, nil) }, wantUser: &upstreamprovider.GitHubUser{ Username: "some-github-id", @@ -276,15 +273,15 @@ func TestGetUser(t *testing.T) { APIBaseURL: "https://some-url", HttpClient: someHttpClient, UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameLoginAndID, - AllowedOrganizations: []string{"allowed-org1", "allowed-org2"}, + AllowedOrganizations: setutil.NewCaseInsensitiveSet("ALLOWED-ORG1", "ALLOWED-ORG2"), }, buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) { mockGitHubInterface.EXPECT().GetUserInfo(someContext).Return(&githubclient.UserInfo{ Login: "some-github-login", ID: "some-github-id", }, nil) - mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("allowed-org2"), nil) - mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]("allowed-org1", "allowed-org2")).Return(nil, nil) + mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"allowed-org2"}, nil) + mockGitHubInterface.EXPECT().GetTeamMembership(someContext, setutil.NewCaseInsensitiveSet("ALLOWED-ORG1", "ALLOWED-ORG2")).Return(nil, nil) }, wantUser: &upstreamprovider.GitHubUser{ Username: "some-github-login:some-github-id", @@ -297,14 +294,14 @@ func TestGetUser(t *testing.T) { APIBaseURL: "https://some-url", HttpClient: someHttpClient, UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameID, - AllowedOrganizations: []string{"allowed-org"}, + AllowedOrganizations: setutil.NewCaseInsensitiveSet("allowed-org"), }, buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) { mockGitHubInterface.EXPECT().GetUserInfo(someContext).Return(&githubclient.UserInfo{ Login: "some-github-login", ID: "some-github-id", }, nil) - mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("disallowed-org"), nil) + mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"disallowed-org"}, nil) }, wantErr: "user is not allowed to log in due to organization membership policy", }, @@ -314,7 +311,7 @@ func TestGetUser(t *testing.T) { APIBaseURL: "https://some-url", HttpClient: someHttpClient, UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameLoginAndID, - AllowedOrganizations: []string{"allowed-org1", "allowed-org2"}, + AllowedOrganizations: setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2"), GroupNameAttribute: supervisoridpv1alpha1.GitHubUseTeamNameForGroupName, }, buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) { @@ -322,8 +319,8 @@ func TestGetUser(t *testing.T) { Login: "some-github-login", ID: "some-github-id", }, nil) - mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("allowed-org2"), nil) - mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{ + mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"allowed-org2"}, nil) + mockGitHubInterface.EXPECT().GetTeamMembership(someContext, setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{ { Name: "org1-team1-name", Slug: "org1-team1-slug", @@ -353,7 +350,7 @@ func TestGetUser(t *testing.T) { APIBaseURL: "https://some-url", HttpClient: someHttpClient, UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameLoginAndID, - AllowedOrganizations: []string{"allowed-org1", "allowed-org2"}, + AllowedOrganizations: setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2"), GroupNameAttribute: supervisoridpv1alpha1.GitHubUseTeamSlugForGroupName, }, buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) { @@ -361,8 +358,8 @@ func TestGetUser(t *testing.T) { Login: "some-github-login", ID: "some-github-id", }, nil) - mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("allowed-org2"), nil) - mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{ + mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"allowed-org2"}, nil) + mockGitHubInterface.EXPECT().GetTeamMembership(someContext, setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{ { Name: "org1-team1-name", Slug: "org1-team1-slug", @@ -462,7 +459,7 @@ func TestGetUser(t *testing.T) { ID: "some-github-id", }, nil) mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil) - mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return([]githubclient.TeamInfo{ + mockGitHubInterface.EXPECT().GetTeamMembership(someContext, nil).Return([]githubclient.TeamInfo{ { Name: "org1-team1-name", Slug: "org1-team1-slug", @@ -477,6 +474,9 @@ func TestGetUser(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + accessToken := "some-opaque-github-access-token" + rand.String(8) mockGitHubInterface := mockgithubclient.NewMockGitHubInterface(ctrl) if test.buildMockResponses != nil {