mirror of
https://github.com/vmware-tanzu/pinniped.git
synced 2025-12-23 06:15:47 +00:00
Make github org comparison case-insensitive, but return original case
Co-authored-by: Joshua Casey <joshuatcasey@gmail.com>
This commit is contained in:
committed by
Joshua Casey
parent
8923704f3c
commit
8f8db3f542
@@ -38,6 +38,7 @@ import (
|
||||
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
|
||||
"go.pinniped.dev/internal/net/phttp"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
"go.pinniped.dev/internal/setutil"
|
||||
"go.pinniped.dev/internal/upstreamgithub"
|
||||
)
|
||||
|
||||
@@ -317,7 +318,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro
|
||||
RedirectURL: "", // this will be different for each FederationDomain, so we do not set it here
|
||||
Scopes: []string{"read:user", "read:org"},
|
||||
},
|
||||
AllowedOrganizations: upstream.Spec.AllowAuthentication.Organizations.Allowed,
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet(upstream.Spec.AllowAuthentication.Organizations.Allowed...),
|
||||
HttpClient: httpClient,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -41,6 +41,7 @@ import (
|
||||
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
|
||||
"go.pinniped.dev/internal/net/phttp"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
"go.pinniped.dev/internal/setutil"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
"go.pinniped.dev/internal/testutil/tlsserver"
|
||||
"go.pinniped.dev/internal/upstreamgithub"
|
||||
@@ -406,7 +407,7 @@ func TestController(t *testing.T) {
|
||||
RedirectURL: "", // not used
|
||||
Scopes: []string{"read:user", "read:org"},
|
||||
},
|
||||
AllowedOrganizations: []string{"organization1", "org2"},
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet("organization1", "org2"),
|
||||
HttpClient: nil, // let the test runner populate this for us
|
||||
},
|
||||
},
|
||||
@@ -462,7 +463,8 @@ func TestController(t *testing.T) {
|
||||
RedirectURL: "", // not used
|
||||
Scopes: []string{"read:user", "read:org"},
|
||||
},
|
||||
HttpClient: nil, // let the test runner populate this for us
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet(),
|
||||
HttpClient: nil, // let the test runner populate this for us
|
||||
},
|
||||
},
|
||||
wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{
|
||||
@@ -531,7 +533,8 @@ func TestController(t *testing.T) {
|
||||
RedirectURL: "", // not used
|
||||
Scopes: []string{"read:user", "read:org"},
|
||||
},
|
||||
HttpClient: nil, // let the test runner populate this for us
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet(),
|
||||
HttpClient: nil, // let the test runner populate this for us
|
||||
},
|
||||
},
|
||||
wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{
|
||||
@@ -598,7 +601,8 @@ func TestController(t *testing.T) {
|
||||
RedirectURL: "", // not used
|
||||
Scopes: []string{"read:user", "read:org"},
|
||||
},
|
||||
HttpClient: nil, // let the test runner populate this for us
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet(),
|
||||
HttpClient: nil, // let the test runner populate this for us
|
||||
},
|
||||
},
|
||||
wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{
|
||||
@@ -685,7 +689,7 @@ func TestController(t *testing.T) {
|
||||
RedirectURL: "", // not used
|
||||
Scopes: []string{"read:user", "read:org"},
|
||||
},
|
||||
AllowedOrganizations: []string{"organization1", "org2"},
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet("organization1", "org2"),
|
||||
HttpClient: nil, // let the test runner populate this for us
|
||||
},
|
||||
{
|
||||
@@ -706,7 +710,7 @@ func TestController(t *testing.T) {
|
||||
RedirectURL: "", // not used
|
||||
Scopes: []string{"read:user", "read:org"},
|
||||
},
|
||||
AllowedOrganizations: []string{"organization1", "org2"},
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet("organization1", "org2"),
|
||||
HttpClient: nil, // let the test runner populate this for us
|
||||
},
|
||||
},
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/psession"
|
||||
"go.pinniped.dev/internal/setutil"
|
||||
"go.pinniped.dev/internal/testutil/oidctestutil"
|
||||
"go.pinniped.dev/internal/testutil/transformtestutil"
|
||||
"go.pinniped.dev/internal/upstreamgithub"
|
||||
@@ -31,7 +32,7 @@ func TestFederationDomainResolvedGitHubIdentityProvider(t *testing.T) {
|
||||
APIBaseURL: "https://fake-api-host.com",
|
||||
UsernameAttribute: idpv1alpha1.GitHubUsernameID,
|
||||
GroupNameAttribute: idpv1alpha1.GitHubUseTeamSlugForGroupName,
|
||||
AllowedOrganizations: []string{"org1", "org2"},
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet("org1", "org2"),
|
||||
HttpClient: nil, // not needed yet for this test
|
||||
OAuth2Config: &oauth2.Config{
|
||||
ClientID: "fake-client-id",
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
|
||||
"go.pinniped.dev/internal/authenticators"
|
||||
"go.pinniped.dev/internal/setutil"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
@@ -157,7 +158,7 @@ type UpstreamGithubIdentityProviderI interface {
|
||||
// and only teams from the listed organizations should be represented as groups for the downstream token.
|
||||
// If this list is empty, then any user can log in regardless of org membership, and any observable
|
||||
// teams memberships should be represented as groups for the downstream token.
|
||||
GetAllowedOrganizations() []string
|
||||
GetAllowedOrganizations() *setutil.CaseInsensitiveSet
|
||||
|
||||
// GetAuthorizationURL returns the authorization URL for the configured GitHub. This will look like:
|
||||
// https://<spec.githubAPI.host>/login/oauth/authorize
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
|
||||
"go.pinniped.dev/internal/plog"
|
||||
"go.pinniped.dev/internal/setutil"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -36,8 +37,8 @@ type TeamInfo struct {
|
||||
|
||||
type GitHubInterface interface {
|
||||
GetUserInfo(ctx context.Context) (*UserInfo, error)
|
||||
GetOrgMembership(ctx context.Context) (sets.Set[string], error)
|
||||
GetTeamMembership(ctx context.Context, allowedOrganizations sets.Set[string]) ([]TeamInfo, error)
|
||||
GetOrgMembership(ctx context.Context) ([]string, error)
|
||||
GetTeamMembership(ctx context.Context, allowedOrganizations *setutil.CaseInsensitiveSet) ([]TeamInfo, error)
|
||||
}
|
||||
|
||||
type githubClient struct {
|
||||
@@ -107,7 +108,7 @@ func (g *githubClient) GetUserInfo(ctx context.Context) (*UserInfo, error) {
|
||||
}
|
||||
|
||||
// GetOrgMembership returns an array of the "Login" attributes for all organizations to which the authenticated user belongs.
|
||||
func (g *githubClient) GetOrgMembership(ctx context.Context) (sets.Set[string], error) {
|
||||
func (g *githubClient) GetOrgMembership(ctx context.Context) ([]string, error) {
|
||||
const errorPrefix = "error fetching organizations for authenticated user"
|
||||
|
||||
organizationLogins := sets.New[string]()
|
||||
@@ -135,11 +136,11 @@ func (g *githubClient) GetOrgMembership(ctx context.Context) (sets.Set[string],
|
||||
}
|
||||
|
||||
plog.Trace("calculated response from GitHub org membership endpoint", "orgs", organizationLogins.UnsortedList())
|
||||
return organizationLogins, nil
|
||||
return organizationLogins.UnsortedList(), nil
|
||||
}
|
||||
|
||||
func isOrgAllowed(allowedOrganizations sets.Set[string], login string) bool {
|
||||
return len(allowedOrganizations) == 0 || allowedOrganizations.Has(login)
|
||||
func isOrgAllowed(allowedOrganizations *setutil.CaseInsensitiveSet, login string) bool {
|
||||
return allowedOrganizations.Empty() || allowedOrganizations.ContainsIgnoringCase(login)
|
||||
}
|
||||
|
||||
func buildAndValidateParentTeam(githubTeam *github.Team, organizationLogin string) (*TeamInfo, error) {
|
||||
@@ -176,7 +177,7 @@ func buildTeam(githubTeam *github.Team, organizationLogin string) (*TeamInfo, er
|
||||
// GetTeamMembership returns a description of each team to which the authenticated user belongs.
|
||||
// If allowedOrganizations is not empty, will filter the results to only those teams which belong to the allowed organizations.
|
||||
// Parent teams will also be returned.
|
||||
func (g *githubClient) GetTeamMembership(ctx context.Context, allowedOrganizations sets.Set[string]) ([]TeamInfo, error) {
|
||||
func (g *githubClient) GetTeamMembership(ctx context.Context, allowedOrganizations *setutil.CaseInsensitiveSet) ([]TeamInfo, error) {
|
||||
const errorPrefix = "error fetching team membership for authenticated user"
|
||||
teamInfos := sets.New[TeamInfo]()
|
||||
|
||||
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
"github.com/google/go-github/v62/github"
|
||||
"github.com/migueleliasweb/go-github-mock/src/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
"k8s.io/client-go/util/cert"
|
||||
|
||||
"go.pinniped.dev/internal/net/phttp"
|
||||
"go.pinniped.dev/internal/setutil"
|
||||
"go.pinniped.dev/internal/testutil/tlsserver"
|
||||
)
|
||||
|
||||
@@ -380,8 +380,7 @@ func TestGetOrgMembership(t *testing.T) {
|
||||
}
|
||||
|
||||
require.NotNil(t, actual)
|
||||
require.Equal(t, len(actual), len(test.wantOrgs))
|
||||
require.True(t, actual.HasAll(test.wantOrgs...))
|
||||
require.ElementsMatch(t, test.wantOrgs, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -394,7 +393,7 @@ func TestGetTeamMembership(t *testing.T) {
|
||||
httpClient *http.Client
|
||||
token string
|
||||
ctx context.Context
|
||||
allowedOrganizations []string
|
||||
allowedOrganizations *setutil.CaseInsensitiveSet
|
||||
wantErr string
|
||||
wantTeams []TeamInfo
|
||||
}{
|
||||
@@ -436,7 +435,7 @@ func TestGetTeamMembership(t *testing.T) {
|
||||
),
|
||||
),
|
||||
token: "some-token",
|
||||
allowedOrganizations: []string{"alpha", "beta"},
|
||||
allowedOrganizations: setutil.NewCaseInsensitiveSet("alpha", "beta"),
|
||||
wantTeams: []TeamInfo{
|
||||
{
|
||||
Name: "orgAlpha-team1-name",
|
||||
@@ -461,7 +460,7 @@ func TestGetTeamMembership(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filters by allowedOrganizations",
|
||||
name: "filters by allowedOrganizations in a case-insensitive way, but preserves case as returned by GitHub API in the result",
|
||||
httpClient: mock.NewMockedHTTPClient(
|
||||
mock.WithRequestMatch(
|
||||
mock.GetUserTeams,
|
||||
@@ -470,38 +469,38 @@ func TestGetTeamMembership(t *testing.T) {
|
||||
Name: github.String("team1-name"),
|
||||
Slug: github.String("team1-slug"),
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("alpha"),
|
||||
Login: github.String("alPhA"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: github.String("team2-name"),
|
||||
Slug: github.String("team2-slug"),
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("beta"),
|
||||
Login: github.String("bEtA"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: github.String("team3-name"),
|
||||
Slug: github.String("team3-slug"),
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("gamma"),
|
||||
Login: github.String("gAmmA"),
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
),
|
||||
token: "some-token",
|
||||
allowedOrganizations: []string{"alpha", "gamma"},
|
||||
allowedOrganizations: setutil.NewCaseInsensitiveSet("ALPHA", "gamma"),
|
||||
wantTeams: []TeamInfo{
|
||||
{
|
||||
Name: "team1-name",
|
||||
Slug: "team1-slug",
|
||||
Org: "alpha",
|
||||
Org: "alPhA",
|
||||
},
|
||||
{
|
||||
Name: "team3-name",
|
||||
Slug: "team3-slug",
|
||||
Org: "gamma",
|
||||
Org: "gAmmA",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -623,11 +622,8 @@ func TestGetTeamMembership(t *testing.T) {
|
||||
},
|
||||
),
|
||||
),
|
||||
token: "some-token",
|
||||
allowedOrganizations: []string{
|
||||
"org-with-nested-teams",
|
||||
"beta",
|
||||
},
|
||||
token: "some-token",
|
||||
allowedOrganizations: setutil.NewCaseInsensitiveSet("org-with-nested-teams", "beta"),
|
||||
wantTeams: []TeamInfo{
|
||||
{
|
||||
Name: "team-name-without-parent",
|
||||
@@ -677,7 +673,7 @@ func TestGetTeamMembership(t *testing.T) {
|
||||
),
|
||||
),
|
||||
token: "some-token",
|
||||
allowedOrganizations: []string{"page1-org-name", "page2-org-name"},
|
||||
allowedOrganizations: setutil.NewCaseInsensitiveSet("page1-org-name", "page2-org-name"),
|
||||
wantTeams: []TeamInfo{
|
||||
{
|
||||
Name: "page1-team-name",
|
||||
@@ -770,7 +766,7 @@ func TestGetTeamMembership(t *testing.T) {
|
||||
),
|
||||
),
|
||||
token: "does-this-token-work",
|
||||
allowedOrganizations: []string{"org-login"},
|
||||
allowedOrganizations: setutil.NewCaseInsensitiveSet("org-login"),
|
||||
wantTeams: []TeamInfo{
|
||||
{
|
||||
Name: "team1-name",
|
||||
@@ -821,7 +817,7 @@ func TestGetTeamMembership(t *testing.T) {
|
||||
ctx = test.ctx
|
||||
}
|
||||
|
||||
actual, err := githubClient.GetTeamMembership(ctx, sets.New[string](test.allowedOrganizations...))
|
||||
actual, err := githubClient.GetTeamMembership(ctx, test.allowedOrganizations)
|
||||
if test.wantErr != "" {
|
||||
rt, ok := test.httpClient.Transport.(*mock.EnforceHostRoundTripper)
|
||||
require.True(t, ok)
|
||||
|
||||
@@ -18,8 +18,8 @@ import (
|
||||
reflect "reflect"
|
||||
|
||||
githubclient "go.pinniped.dev/internal/githubclient"
|
||||
setutil "go.pinniped.dev/internal/setutil"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
sets "k8s.io/apimachinery/pkg/util/sets"
|
||||
)
|
||||
|
||||
// MockGitHubInterface is a mock of GitHubInterface interface.
|
||||
@@ -46,10 +46,10 @@ func (m *MockGitHubInterface) EXPECT() *MockGitHubInterfaceMockRecorder {
|
||||
}
|
||||
|
||||
// GetOrgMembership mocks base method.
|
||||
func (m *MockGitHubInterface) GetOrgMembership(arg0 context.Context) (sets.Set[string], error) {
|
||||
func (m *MockGitHubInterface) GetOrgMembership(arg0 context.Context) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOrgMembership", arg0)
|
||||
ret0, _ := ret[0].(sets.Set[string])
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func (mr *MockGitHubInterfaceMockRecorder) GetOrgMembership(arg0 any) *gomock.Ca
|
||||
}
|
||||
|
||||
// GetTeamMembership mocks base method.
|
||||
func (m *MockGitHubInterface) GetTeamMembership(arg0 context.Context, arg1 sets.Set[string]) ([]githubclient.TeamInfo, error) {
|
||||
func (m *MockGitHubInterface) GetTeamMembership(arg0 context.Context, arg1 *setutil.CaseInsensitiveSet) ([]githubclient.TeamInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTeamMembership", arg0, arg1)
|
||||
ret0, _ := ret[0].([]githubclient.TeamInfo)
|
||||
|
||||
43
internal/setutil/setutil.go
Normal file
43
internal/setutil/setutil.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package setutil
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
|
||||
"go.pinniped.dev/internal/sliceutil"
|
||||
)
|
||||
|
||||
type CaseInsensitiveSet struct {
|
||||
lowercasedContents sets.Set[string]
|
||||
}
|
||||
|
||||
func NewCaseInsensitiveSet(items ...string) *CaseInsensitiveSet {
|
||||
return &CaseInsensitiveSet{
|
||||
lowercasedContents: sets.New(sliceutil.Map(items, strings.ToLower)...),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CaseInsensitiveSet) HasAnyIgnoringCase(items []string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
return s.lowercasedContents.HasAny(sliceutil.Map(items, strings.ToLower)...)
|
||||
}
|
||||
|
||||
func (s *CaseInsensitiveSet) ContainsIgnoringCase(item string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
return s.lowercasedContents.Has(strings.ToLower(item))
|
||||
}
|
||||
|
||||
func (s *CaseInsensitiveSet) Empty() bool {
|
||||
if s == nil {
|
||||
return true
|
||||
}
|
||||
return s.lowercasedContents.Len() == 0
|
||||
}
|
||||
34
internal/setutil/setutil_test.go
Normal file
34
internal/setutil/setutil_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package setutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCaseInsensitiveSet(t *testing.T) {
|
||||
var nilSet *CaseInsensitiveSet
|
||||
require.True(t, nilSet.Empty())
|
||||
require.False(t, nilSet.HasAnyIgnoringCase([]string{"a", "b"}))
|
||||
require.False(t, nilSet.HasAnyIgnoringCase(nil))
|
||||
require.False(t, nilSet.ContainsIgnoringCase("a"))
|
||||
require.False(t, nilSet.ContainsIgnoringCase("a"))
|
||||
|
||||
emptySet := NewCaseInsensitiveSet()
|
||||
require.True(t, emptySet.Empty())
|
||||
require.False(t, emptySet.HasAnyIgnoringCase([]string{"a", "b"}))
|
||||
require.False(t, emptySet.HasAnyIgnoringCase(nil))
|
||||
require.False(t, emptySet.ContainsIgnoringCase("a"))
|
||||
require.False(t, emptySet.ContainsIgnoringCase("a"))
|
||||
|
||||
set := NewCaseInsensitiveSet("A", "B", "c")
|
||||
require.False(t, set.Empty())
|
||||
require.False(t, set.HasAnyIgnoringCase([]string{"x", "y"}))
|
||||
require.True(t, set.HasAnyIgnoringCase([]string{"a", "x"}))
|
||||
require.False(t, set.HasAnyIgnoringCase(nil))
|
||||
require.False(t, set.ContainsIgnoringCase("x"))
|
||||
require.True(t, set.ContainsIgnoringCase("a"))
|
||||
}
|
||||
13
internal/sliceutil/sliceutil.go
Normal file
13
internal/sliceutil/sliceutil.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package sliceutil
|
||||
|
||||
// Map transforms a slice from an input type I to an output type O using a transform func.
|
||||
func Map[I, O any](in []I, transform func(I) O) []O {
|
||||
out := make([]O, len(in))
|
||||
for i := range in {
|
||||
out[i] = transform(in[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
74
internal/sliceutil/sliceutil_test.go
Normal file
74
internal/sliceutil/sliceutil_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package sliceutil
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMap(t *testing.T) {
|
||||
type testCase[I any, O any] struct {
|
||||
name string
|
||||
in []I
|
||||
transformFunc func(I) O
|
||||
want []O
|
||||
}
|
||||
|
||||
stringStringTests := []testCase[string, string]{
|
||||
{
|
||||
name: "downcase func",
|
||||
in: []string{"Aa", "bB", "CC"},
|
||||
transformFunc: strings.ToLower,
|
||||
want: []string{"aa", "bb", "cc"},
|
||||
},
|
||||
{
|
||||
name: "upcase func",
|
||||
in: []string{"Aa", "bB", "CC"},
|
||||
transformFunc: strings.ToUpper,
|
||||
want: []string{"AA", "BB", "CC"},
|
||||
},
|
||||
{
|
||||
name: "when in is nil, then out is an empty slice",
|
||||
in: nil,
|
||||
transformFunc: strings.ToUpper,
|
||||
want: []string{},
|
||||
},
|
||||
}
|
||||
for _, tt := range stringStringTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
actual := Map(tt.in, tt.transformFunc)
|
||||
require.Equal(t, tt.want, actual)
|
||||
})
|
||||
}
|
||||
|
||||
stringIntTests := []testCase[string, int]{
|
||||
{
|
||||
name: "len func",
|
||||
in: []string{"Aa", "bBb", "CCcC"},
|
||||
transformFunc: func(s string) int {
|
||||
return len(s)
|
||||
},
|
||||
want: []int{2, 3, 4},
|
||||
},
|
||||
{
|
||||
name: "index func",
|
||||
in: []string{"Aab", "bB", "CC"},
|
||||
transformFunc: func(s string) int {
|
||||
return strings.Index(s, "b")
|
||||
},
|
||||
want: []int{2, 0, -1},
|
||||
},
|
||||
}
|
||||
for _, tt := range stringIntTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
actual := Map(tt.in, tt.transformFunc)
|
||||
require.Equal(t, tt.want, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
|
||||
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
|
||||
"go.pinniped.dev/internal/idtransform"
|
||||
"go.pinniped.dev/internal/setutil"
|
||||
)
|
||||
|
||||
// ExchangeAuthcodeArgs is used to spy on calls to
|
||||
@@ -38,7 +39,7 @@ type TestUpstreamGitHubIdentityProviderBuilder struct {
|
||||
transformsForFederationDomain *idtransform.TransformationPipeline
|
||||
usernameAttribute v1alpha1.GitHubUsernameAttribute
|
||||
groupNameAttribute v1alpha1.GitHubGroupNameAttribute
|
||||
allowedOrganizations []string
|
||||
allowedOrganizations *setutil.CaseInsensitiveSet
|
||||
authorizationURL string
|
||||
authcodeExchangeErr error
|
||||
accessToken string
|
||||
@@ -81,7 +82,7 @@ func (u *TestUpstreamGitHubIdentityProviderBuilder) WithGroupNameAttribute(value
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAllowedOrganizations(value []string) *TestUpstreamGitHubIdentityProviderBuilder {
|
||||
func (u *TestUpstreamGitHubIdentityProviderBuilder) WithAllowedOrganizations(value *setutil.CaseInsensitiveSet) *TestUpstreamGitHubIdentityProviderBuilder {
|
||||
u.allowedOrganizations = value
|
||||
return u
|
||||
}
|
||||
@@ -159,7 +160,7 @@ type TestUpstreamGitHubIdentityProvider struct {
|
||||
TransformsForFederationDomain *idtransform.TransformationPipeline
|
||||
UsernameAttribute v1alpha1.GitHubUsernameAttribute
|
||||
GroupNameAttribute v1alpha1.GitHubGroupNameAttribute
|
||||
AllowedOrganizations []string
|
||||
AllowedOrganizations *setutil.CaseInsensitiveSet
|
||||
AuthorizationURL string
|
||||
GetUserFunc func(ctx context.Context, accessToken string) (*upstreamprovider.GitHubUser, error)
|
||||
ExchangeAuthcodeFunc func(ctx context.Context, authcode string) (string, error)
|
||||
@@ -197,7 +198,7 @@ func (u *TestUpstreamGitHubIdentityProvider) GetGroupNameAttribute() v1alpha1.Gi
|
||||
return u.GroupNameAttribute
|
||||
}
|
||||
|
||||
func (u *TestUpstreamGitHubIdentityProvider) GetAllowedOrganizations() []string {
|
||||
func (u *TestUpstreamGitHubIdentityProvider) GetAllowedOrganizations() *setutil.CaseInsensitiveSet {
|
||||
return u.AllowedOrganizations
|
||||
}
|
||||
|
||||
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
coreosoidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
|
||||
supervisoridpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
|
||||
"go.pinniped.dev/internal/federationdomain/downstreamsubject"
|
||||
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
|
||||
"go.pinniped.dev/internal/githubclient"
|
||||
"go.pinniped.dev/internal/setutil"
|
||||
)
|
||||
|
||||
// ProviderConfig holds the active configuration of an upstream GitHub provider.
|
||||
@@ -35,7 +35,7 @@ type ProviderConfig struct {
|
||||
GroupNameAttribute supervisoridpv1alpha1.GitHubGroupNameAttribute
|
||||
|
||||
// AllowedOrganizations, when empty, means to allow users from all orgs.
|
||||
AllowedOrganizations []string
|
||||
AllowedOrganizations *setutil.CaseInsensitiveSet
|
||||
|
||||
// HttpClient is a client that can be used to call the GitHub APIs and token endpoint.
|
||||
// This client should be configured with the user-provided CA bundle and a timeout.
|
||||
@@ -90,7 +90,7 @@ func (p *Provider) GetGroupNameAttribute() supervisoridpv1alpha1.GitHubGroupName
|
||||
return p.c.GroupNameAttribute
|
||||
}
|
||||
|
||||
func (p *Provider) GetAllowedOrganizations() []string {
|
||||
func (p *Provider) GetAllowedOrganizations() *setutil.CaseInsensitiveSet {
|
||||
return p.c.AllowedOrganizations
|
||||
}
|
||||
|
||||
@@ -146,13 +146,11 @@ func (p *Provider) GetUser(ctx context.Context, accessToken string, idpDisplayNa
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allowedOrgs := sets.New[string](p.c.AllowedOrganizations...)
|
||||
|
||||
if allowedOrgs.Len() > 0 && allowedOrgs.Intersection(orgMembership).Len() < 1 {
|
||||
if !p.c.AllowedOrganizations.Empty() && !p.c.AllowedOrganizations.HasAnyIgnoringCase(orgMembership) {
|
||||
return nil, errors.New("user is not allowed to log in due to organization membership policy")
|
||||
}
|
||||
|
||||
teamMembership, err := githubClient.GetTeamMembership(ctx, allowedOrgs)
|
||||
teamMembership, err := githubClient.GetTeamMembership(ctx, p.c.AllowedOrganizations)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -17,13 +17,13 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
"k8s.io/apimachinery/pkg/util/rand"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
"k8s.io/client-go/util/cert"
|
||||
|
||||
supervisoridpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
|
||||
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
|
||||
"go.pinniped.dev/internal/githubclient"
|
||||
"go.pinniped.dev/internal/mocks/mockgithubclient"
|
||||
"go.pinniped.dev/internal/setutil"
|
||||
"go.pinniped.dev/internal/testutil/tlsserver"
|
||||
)
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestGitHubProvider(t *testing.T) {
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
},
|
||||
AllowedOrganizations: []string{"fake-org", "fake-org2"},
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet("fake-org", "fake-org2"),
|
||||
HttpClient: &http.Client{
|
||||
Timeout: 1234509,
|
||||
},
|
||||
@@ -68,7 +68,7 @@ func TestGitHubProvider(t *testing.T) {
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
},
|
||||
AllowedOrganizations: []string{"fake-org", "fake-org2"},
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet("fake-org", "fake-org2"),
|
||||
HttpClient: &http.Client{
|
||||
Timeout: 1234509,
|
||||
},
|
||||
@@ -80,7 +80,7 @@ func TestGitHubProvider(t *testing.T) {
|
||||
require.Equal(t, "fake-client-id", subject.GetClientID())
|
||||
require.Equal(t, supervisoridpv1alpha1.GitHubUsernameAttribute("fake-username-attribute"), subject.GetUsernameAttribute())
|
||||
require.Equal(t, supervisoridpv1alpha1.GitHubGroupNameAttribute("fake-group-name-attribute"), subject.GetGroupNameAttribute())
|
||||
require.Equal(t, []string{"fake-org", "fake-org2"}, subject.GetAllowedOrganizations())
|
||||
require.Equal(t, setutil.NewCaseInsensitiveSet("fake-org", "fake-org2"), subject.GetAllowedOrganizations())
|
||||
require.Equal(t, "https://fake-authorization-url", subject.GetAuthorizationURL())
|
||||
require.Equal(t, &http.Client{
|
||||
Timeout: 1234509,
|
||||
@@ -193,9 +193,6 @@ func TestGetUser(t *testing.T) {
|
||||
const idpDisplayName = "idp display name 😀"
|
||||
const encodedIDPDisplayName = "idp+display+name+%F0%9F%98%80"
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
someContext := context.Background()
|
||||
|
||||
someHttpClient := &http.Client{
|
||||
@@ -223,7 +220,7 @@ func TestGetUser(t *testing.T) {
|
||||
ID: "some-github-id",
|
||||
}, nil)
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return(nil, nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, gomock.Any()).Return(nil, nil)
|
||||
},
|
||||
wantUser: &upstreamprovider.GitHubUser{
|
||||
Username: "some-github-login:some-github-id",
|
||||
@@ -243,7 +240,7 @@ func TestGetUser(t *testing.T) {
|
||||
ID: "some-github-id",
|
||||
}, nil)
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return(nil, nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, nil).Return(nil, nil)
|
||||
},
|
||||
wantUser: &upstreamprovider.GitHubUser{
|
||||
Username: "some-github-login",
|
||||
@@ -263,7 +260,7 @@ func TestGetUser(t *testing.T) {
|
||||
ID: "some-github-id",
|
||||
}, nil)
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return(nil, nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, nil).Return(nil, nil)
|
||||
},
|
||||
wantUser: &upstreamprovider.GitHubUser{
|
||||
Username: "some-github-id",
|
||||
@@ -276,15 +273,15 @@ func TestGetUser(t *testing.T) {
|
||||
APIBaseURL: "https://some-url",
|
||||
HttpClient: someHttpClient,
|
||||
UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameLoginAndID,
|
||||
AllowedOrganizations: []string{"allowed-org1", "allowed-org2"},
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet("ALLOWED-ORG1", "ALLOWED-ORG2"),
|
||||
},
|
||||
buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) {
|
||||
mockGitHubInterface.EXPECT().GetUserInfo(someContext).Return(&githubclient.UserInfo{
|
||||
Login: "some-github-login",
|
||||
ID: "some-github-id",
|
||||
}, nil)
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("allowed-org2"), nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]("allowed-org1", "allowed-org2")).Return(nil, nil)
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"allowed-org2"}, nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, setutil.NewCaseInsensitiveSet("ALLOWED-ORG1", "ALLOWED-ORG2")).Return(nil, nil)
|
||||
},
|
||||
wantUser: &upstreamprovider.GitHubUser{
|
||||
Username: "some-github-login:some-github-id",
|
||||
@@ -297,14 +294,14 @@ func TestGetUser(t *testing.T) {
|
||||
APIBaseURL: "https://some-url",
|
||||
HttpClient: someHttpClient,
|
||||
UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameID,
|
||||
AllowedOrganizations: []string{"allowed-org"},
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet("allowed-org"),
|
||||
},
|
||||
buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) {
|
||||
mockGitHubInterface.EXPECT().GetUserInfo(someContext).Return(&githubclient.UserInfo{
|
||||
Login: "some-github-login",
|
||||
ID: "some-github-id",
|
||||
}, nil)
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("disallowed-org"), nil)
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"disallowed-org"}, nil)
|
||||
},
|
||||
wantErr: "user is not allowed to log in due to organization membership policy",
|
||||
},
|
||||
@@ -314,7 +311,7 @@ func TestGetUser(t *testing.T) {
|
||||
APIBaseURL: "https://some-url",
|
||||
HttpClient: someHttpClient,
|
||||
UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameLoginAndID,
|
||||
AllowedOrganizations: []string{"allowed-org1", "allowed-org2"},
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2"),
|
||||
GroupNameAttribute: supervisoridpv1alpha1.GitHubUseTeamNameForGroupName,
|
||||
},
|
||||
buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) {
|
||||
@@ -322,8 +319,8 @@ func TestGetUser(t *testing.T) {
|
||||
Login: "some-github-login",
|
||||
ID: "some-github-id",
|
||||
}, nil)
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("allowed-org2"), nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"allowed-org2"}, nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{
|
||||
{
|
||||
Name: "org1-team1-name",
|
||||
Slug: "org1-team1-slug",
|
||||
@@ -353,7 +350,7 @@ func TestGetUser(t *testing.T) {
|
||||
APIBaseURL: "https://some-url",
|
||||
HttpClient: someHttpClient,
|
||||
UsernameAttribute: supervisoridpv1alpha1.GitHubUsernameLoginAndID,
|
||||
AllowedOrganizations: []string{"allowed-org1", "allowed-org2"},
|
||||
AllowedOrganizations: setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2"),
|
||||
GroupNameAttribute: supervisoridpv1alpha1.GitHubUseTeamSlugForGroupName,
|
||||
},
|
||||
buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) {
|
||||
@@ -361,8 +358,8 @@ func TestGetUser(t *testing.T) {
|
||||
Login: "some-github-login",
|
||||
ID: "some-github-id",
|
||||
}, nil)
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(sets.New[string]("allowed-org2"), nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"allowed-org2"}, nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, setutil.NewCaseInsensitiveSet("allowed-org1", "allowed-org2")).Return([]githubclient.TeamInfo{
|
||||
{
|
||||
Name: "org1-team1-name",
|
||||
Slug: "org1-team1-slug",
|
||||
@@ -462,7 +459,7 @@ func TestGetUser(t *testing.T) {
|
||||
ID: "some-github-id",
|
||||
}, nil)
|
||||
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil)
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, sets.New[string]()).Return([]githubclient.TeamInfo{
|
||||
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, nil).Return([]githubclient.TeamInfo{
|
||||
{
|
||||
Name: "org1-team1-name",
|
||||
Slug: "org1-team1-slug",
|
||||
@@ -477,6 +474,9 @@ func TestGetUser(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
accessToken := "some-opaque-github-access-token" + rand.String(8)
|
||||
mockGitHubInterface := mockgithubclient.NewMockGitHubInterface(ctrl)
|
||||
if test.buildMockResponses != nil {
|
||||
|
||||
Reference in New Issue
Block a user