user sees error msg when GitHub login is denied due to allowed orgs

Also renamed an interface function from GetName to GetResourceName.

Co-authored-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Joshua Casey
2024-05-31 16:24:11 -05:00
committed by Ryan Richard
parent e3d8c71f97
commit 58b4ecc0aa
22 changed files with 200 additions and 89 deletions

View File

@@ -1882,7 +1882,7 @@ func TestController(t *testing.T) {
// Do not expect any particular order in the cache
var actualProvider *upstreamgithub.Provider
for _, possibleIDP := range actualIDPList {
if possibleIDP.GetName() == tt.wantResultingCache[i].Name {
if possibleIDP.GetResourceName() == tt.wantResultingCache[i].Name {
var ok bool
actualProvider, ok = possibleIDP.(*upstreamgithub.Provider)
require.True(t, ok)
@@ -1890,7 +1890,7 @@ func TestController(t *testing.T) {
}
}
require.Equal(t, tt.wantResultingCache[i].Name, actualProvider.GetName())
require.Equal(t, tt.wantResultingCache[i].Name, actualProvider.GetResourceName())
require.Equal(t, tt.wantResultingCache[i].ResourceUID, actualProvider.GetResourceUID())
require.Equal(t, tt.wantResultingCache[i].OAuth2Config.ClientID, actualProvider.GetClientID())
require.Equal(t, tt.wantResultingCache[i].GroupNameAttribute, actualProvider.GetGroupNameAttribute())

View File

@@ -1448,7 +1448,7 @@ oidc: issuer did not match the issuer returned by provider, expected "` + testIs
require.Equal(t, len(tt.wantResultingCache), len(actualIDPList))
for i := range actualIDPList {
actualIDP := actualIDPList[i].(*upstreamoidc.ProviderConfig)
require.Equal(t, tt.wantResultingCache[i].GetName(), actualIDP.GetName())
require.Equal(t, tt.wantResultingCache[i].GetResourceName(), actualIDP.GetResourceName())
require.Equal(t, tt.wantResultingCache[i].GetClientID(), actualIDP.GetClientID())
require.Equal(t, tt.wantResultingCache[i].GetAuthorizationURL().String(), actualIDP.GetAuthorizationURL().String())
require.Equal(t, tt.wantResultingCache[i].GetUsernameClaim(), actualIDP.GetUsernameClaim())

View File

@@ -246,7 +246,7 @@ func (c *garbageCollectorController) tryRevokeUpstreamOIDCToken(ctx context.Cont
// Try to find the provider that was originally used to create the stored session.
var foundOIDCIdentityProviderI upstreamprovider.UpstreamOIDCIdentityProviderI
for _, p := range c.idpCache.GetOIDCIdentityProviders() {
if p.GetName() == customSessionData.ProviderName && p.GetResourceUID() == customSessionData.ProviderUID {
if p.GetResourceName() == customSessionData.ProviderName && p.GetResourceUID() == customSessionData.ProviderUID {
foundOIDCIdentityProviderI = p
break
}

View File

@@ -56,7 +56,7 @@ func NewPinnipedSession(
UpstreamUsername: c.UpstreamIdentity.UpstreamUsername,
UpstreamGroups: c.UpstreamIdentity.UpstreamGroups,
ProviderUID: idp.GetProvider().GetResourceUID(),
ProviderName: idp.GetProvider().GetName(),
ProviderName: idp.GetProvider().GetResourceName(),
ProviderType: idp.GetSessionProviderType(),
Warnings: c.UpstreamLoginExtras.Warnings,
}

View File

@@ -48,6 +48,9 @@ func NewHandler(
authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), reconstitutedAuthRequest)
if err != nil {
plog.Error("error using state downstream auth params", err,
"identityProviderDisplayName", idp.GetDisplayName(),
"identityProviderResourceName", idp.GetProvider().GetResourceName(),
"supervisorCallbackURL", redirectURI,
"fositeErr", oidc.FositeErrorForLog(err))
return httperr.New(http.StatusBadRequest, "error using state downstream auth params")
}
@@ -59,6 +62,10 @@ func NewHandler(
identity, loginExtras, err := idp.LoginFromCallback(r.Context(), authcode(r), state.PKCECode, state.Nonce, redirectURI)
if err != nil {
plog.InfoErr("unable to complete login from callback", err,
"identityProviderDisplayName", idp.GetDisplayName(),
"identityProviderResourceName", idp.GetProvider().GetResourceName(),
"supervisorCallbackURL", redirectURI)
return err
}
@@ -69,13 +76,20 @@ func NewHandler(
GrantedScopes: authorizeRequester.GetGrantedScopes(),
})
if err != nil {
plog.InfoErr("unable to create a Pinniped session", err,
"identityProviderDisplayName", idp.GetDisplayName(),
"identityProviderResourceName", idp.GetProvider().GetResourceName(),
"supervisorCallbackURL", redirectURI)
return httperr.Wrap(http.StatusUnprocessableEntity, err.Error(), err)
}
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, session)
if err != nil {
plog.WarningErr("error while generating and saving authcode", err,
"identityProviderDisplayName", idp.GetDisplayName(), "fositeErr", oidc.FositeErrorForLog(err))
"identityProviderDisplayName", idp.GetDisplayName(),
"identityProviderResourceName", idp.GetProvider().GetResourceName(),
"supervisorCallbackURL", redirectURI,
"fositeErr", oidc.FositeErrorForLog(err))
return httperr.Wrap(http.StatusInternalServerError, "error while generating and saving authcode", err)
}

View File

@@ -233,7 +233,7 @@ func findProviderByNameAndType(
idpLister federationdomainproviders.FederationDomainIdentityProvidersListerI,
) (resolvedprovider.FederationDomainResolvedIdentityProvider, error) {
for _, p := range idpLister.GetIdentityProviders() {
if p.GetSessionProviderType() == providerType && p.GetProvider().GetName() == providerResourceName {
if p.GetSessionProviderType() == providerType && p.GetProvider().GetResourceName() == providerResourceName {
if p.GetProvider().GetResourceUID() != mustHaveResourceUID {
return nil, errorsx.WithStack(errUpstreamRefreshError().WithHint(
"Provider from upstream session data has changed its resource UID since authentication."))

View File

@@ -17,7 +17,6 @@ import (
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/idtransform"
"go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/psession"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/pkce"
@@ -103,7 +102,6 @@ func (p *FederationDomainResolvedGitHubIdentityProvider) LoginFromCallback(
) (*resolvedprovider.Identity, *resolvedprovider.IdentityLoginExtras, error) {
accessToken, err := p.Provider.ExchangeAuthcode(ctx, authCode, redirectURI)
if err != nil {
plog.WarningErr("failed to exchange authcode using GitHub API", err, "upstreamName", p.Provider.GetName())
return nil, nil, httperr.Wrap(http.StatusBadGateway,
"failed to exchange authcode using GitHub API",
err,
@@ -111,8 +109,14 @@ func (p *FederationDomainResolvedGitHubIdentityProvider) LoginFromCallback(
}
user, err := p.Provider.GetUser(ctx, accessToken, p.GetDisplayName())
if err != nil {
plog.WarningErr("failed to get user info from GitHub API", err, "upstreamName", p.Provider.GetName())
if errors.As(err, &upstreamprovider.GitHubLoginDeniedError{}) {
// We specifically want errors of type GitHubLoginDeniedError to have a user-displayed message.
// Don't wrap the error since we include it in the sprintf here.
return nil, nil, httperr.Newf(http.StatusForbidden,
"login denied due to configuration on GitHubIdentityProvider with display name %q: %s",
p.GetDisplayName(), err)
} else if err != nil {
return nil, nil, httperr.Wrap(http.StatusUnprocessableEntity,
"failed to get user info from GitHub API",
err,
@@ -151,8 +155,7 @@ func (p *FederationDomainResolvedGitHubIdentityProvider) UpstreamRefresh(
// Get the user's GitHub identity and groups again using the cached access token.
refreshedUserInfo, err := p.Provider.GetUser(ctx, githubSessionData.UpstreamAccessToken, p.GetDisplayName())
if err != nil {
plog.WarningErr("failed to refresh user info from GitHub API", err, "upstreamName", p.Provider.GetName())
return nil, p.refreshErr(errors.New("failed to refresh user info from GitHub API"))
return nil, p.refreshErr(err)
}
if refreshedUserInfo.DownstreamSubject != identity.DownstreamSubject {
@@ -172,5 +175,5 @@ func (p *FederationDomainResolvedGitHubIdentityProvider) refreshErr(err error) *
return resolvedprovider.ErrUpstreamRefreshError().
WithHint("Upstream refresh failed.").
WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", p.Provider.GetName(), p.GetSessionProviderType())
WithDebugf("provider name: %q, provider type: %q", p.Provider.GetResourceName(), p.GetSessionProviderType())
}

View File

@@ -7,6 +7,7 @@ import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/ory/fosite"
@@ -125,7 +126,9 @@ func TestLoginFromCallback(t *testing.T) {
wantGetUserArgs *oidctestutil.GetUserArgs
wantIdentity *resolvedprovider.Identity
wantExtras *resolvedprovider.IdentityLoginExtras
wantErr string
wantErrMsg string
wantErrResponseMsg string
wantErrStatusCode int
}{
{
name: "happy path",
@@ -176,13 +179,15 @@ func TestLoginFromCallback(t *testing.T) {
Authcode: "fake-authcode",
RedirectURI: "https://fake-redirect-uri",
},
wantGetUserCall: false,
wantIdentity: nil,
wantExtras: nil,
wantErr: "failed to exchange authcode using GitHub API: fake authcode exchange error",
wantGetUserCall: false,
wantIdentity: nil,
wantExtras: nil,
wantErrMsg: "failed to exchange authcode using GitHub API: fake authcode exchange error",
wantErrResponseMsg: "Bad Gateway: failed to exchange authcode using GitHub API",
wantErrStatusCode: http.StatusBadGateway,
},
{
name: "error while getting user info",
name: "generic error while getting user info",
provider: oidctestutil.NewTestUpstreamGitHubIdentityProviderBuilder().
WithAccessToken("fake-access-token").
WithGetUserError(errors.New("fake user info error")).
@@ -202,9 +207,38 @@ func TestLoginFromCallback(t *testing.T) {
AccessToken: "fake-access-token",
IDPDisplayName: "fake-display-name",
},
wantIdentity: nil,
wantExtras: nil,
wantErr: "failed to get user info from GitHub API: fake user info error",
wantIdentity: nil,
wantExtras: nil,
wantErrMsg: "failed to get user info from GitHub API: fake user info error",
wantErrResponseMsg: "Unprocessable Entity: failed to get user info from GitHub API",
wantErrStatusCode: http.StatusUnprocessableEntity,
},
{
name: "loginDenied error while getting user info",
provider: oidctestutil.NewTestUpstreamGitHubIdentityProviderBuilder().
WithAccessToken("fake-access-token").
WithGetUserError(upstreamprovider.NewGitHubLoginDeniedError("some login denied error")).
Build(),
idpDisplayName: "fake-display-name",
authcode: "fake-authcode",
redirectURI: "https://fake-redirect-uri",
wantExchangeAuthcodeCall: true,
wantExchangeAuthcodeArgs: &oidctestutil.ExchangeAuthcodeArgs{
Ctx: uniqueCtx,
Authcode: "fake-authcode",
RedirectURI: "https://fake-redirect-uri",
},
wantGetUserCall: true,
wantGetUserArgs: &oidctestutil.GetUserArgs{
Ctx: uniqueCtx,
AccessToken: "fake-access-token",
IDPDisplayName: "fake-display-name",
},
wantIdentity: nil,
wantExtras: nil,
wantErrMsg: `login denied due to configuration on GitHubIdentityProvider with display name "fake-display-name": some login denied error`,
wantErrResponseMsg: `Forbidden: login denied due to configuration on GitHubIdentityProvider with display name "fake-display-name": some login denied error`,
wantErrStatusCode: http.StatusForbidden,
},
}
@@ -238,12 +272,17 @@ func TestLoginFromCallback(t *testing.T) {
require.Zero(t, test.provider.GetUserCallCount())
}
if test.wantErr == "" {
if test.wantErrResponseMsg == "" {
require.NoError(t, err)
} else {
errAsResponder, ok := err.(httperr.Responder)
require.True(t, ok)
require.EqualError(t, errAsResponder, test.wantErr)
require.Implements(t, (*httperr.Responder)(nil), err)
errAsResponder := err.(httperr.Responder)
rec := httptest.NewRecorder()
errAsResponder.Respond(rec)
require.Equal(t, test.wantErrStatusCode, rec.Code)
require.Equal(t, test.wantErrResponseMsg+"\n", rec.Body.String())
require.EqualError(t, errAsResponder, test.wantErrMsg)
}
require.Equal(t, test.wantExtras, loginExtras)
require.Equal(t, test.wantIdentity, identity)
@@ -297,7 +336,7 @@ func TestUpstreamRefresh(t *testing.T) {
name: "error while getting user info",
provider: oidctestutil.NewTestUpstreamGitHubIdentityProviderBuilder().
WithName("fake-provider-name").
WithGetUserError(errors.New("any error message")).
WithGetUserError(errors.New("fake github GetUser error message")).
Build(),
identity: &resolvedprovider.Identity{
UpstreamUsername: "initial-username",
@@ -313,7 +352,7 @@ func TestUpstreamRefresh(t *testing.T) {
IDPDisplayName: "fake-display-name",
},
wantRefreshedIdentity: nil,
wantWrappedErr: "failed to refresh user info from GitHub API",
wantWrappedErr: "fake github GetUser error message",
},
{
name: "wrong session data type, which should not really happen",

View File

@@ -128,7 +128,7 @@ func (p *FederationDomainResolvedLDAPIdentityProvider) Login(
) (*resolvedprovider.Identity, *resolvedprovider.IdentityLoginExtras, error) {
authenticateResponse, authenticated, err := p.Provider.AuthenticateUser(ctx, submittedUsername, submittedPassword)
if err != nil {
plog.WarningErr("unexpected error during upstream LDAP authentication", err, "upstreamName", p.Provider.GetName())
plog.WarningErr("unexpected error during upstream LDAP authentication", err, "upstreamName", p.Provider.GetResourceName())
return nil, nil, ErrUnexpectedUpstreamLDAPError.WithWrap(err)
}
if !authenticated {
@@ -211,7 +211,7 @@ func (p *FederationDomainResolvedLDAPIdentityProvider) UpstreamRefresh(
// This shouldn't really happen.
return nil, resolvedprovider.ErrUpstreamRefreshError().WithHintf(
"Unexpected provider type during refresh %q", p.GetSessionProviderType()).WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", p.Provider.GetName(), p.GetSessionProviderType())
WithDebugf("provider name: %q, provider type: %q", p.Provider.GetResourceName(), p.GetSessionProviderType())
}
if dn == "" {
@@ -219,7 +219,9 @@ func (p *FederationDomainResolvedLDAPIdentityProvider) UpstreamRefresh(
}
plog.Debug("attempting upstream refresh request",
"providerName", p.Provider.GetName(), "providerType", p.GetSessionProviderType(), "providerUID", p.Provider.GetResourceUID())
"identityProviderResourceName", p.Provider.GetResourceName(),
"identityProviderType", p.GetSessionProviderType(),
"identityProviderUID", p.Provider.GetResourceUID())
refreshedUntransformedGroups, err := p.Provider.PerformRefresh(ctx, upstreamprovider.LDAPRefreshAttributes{
Username: identity.UpstreamUsername,
@@ -231,7 +233,7 @@ func (p *FederationDomainResolvedLDAPIdentityProvider) UpstreamRefresh(
if err != nil {
return nil, resolvedprovider.ErrUpstreamRefreshError().WithHint(
"Upstream refresh failed.").WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", p.Provider.GetName(), p.GetSessionProviderType())
WithDebugf("provider name: %q, provider type: %q", p.Provider.GetResourceName(), p.GetSessionProviderType())
}
return &resolvedprovider.RefreshedIdentity{

View File

@@ -191,8 +191,7 @@ func (p *FederationDomainResolvedOIDCIdentityProvider) LoginFromCallback(
redirectURI,
)
if err != nil {
plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", p.Provider.GetName())
return nil, nil, httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
return nil, nil, httperr.Wrap(http.StatusBadGateway, "error exchanging and validating upstream tokens", err)
}
subject, upstreamUsername, upstreamGroups, err := getIdentityFromUpstreamIDToken(
@@ -241,7 +240,9 @@ func (p *FederationDomainResolvedOIDCIdentityProvider) UpstreamRefresh(
}
plog.Debug("attempting upstream refresh request",
"providerName", p.Provider.GetName(), "providerType", p.GetSessionProviderType(), "providerUID", p.Provider.GetResourceUID())
"identityProviderResourceName", p.Provider.GetResourceName(),
"identityProviderType", p.GetSessionProviderType(),
"identityProviderUID", p.Provider.GetResourceUID())
var tokens *oauth2.Token
if refreshTokenStored {
@@ -249,7 +250,7 @@ func (p *FederationDomainResolvedOIDCIdentityProvider) UpstreamRefresh(
if err != nil {
return nil, resolvedprovider.ErrUpstreamRefreshError().WithHint(
"Upstream refresh failed.",
).WithTrace(err).WithDebugf("provider name: %q, provider type: %q", p.Provider.GetName(), p.GetSessionProviderType())
).WithTrace(err).WithDebugf("provider name: %q, provider type: %q", p.Provider.GetResourceName(), p.GetSessionProviderType())
}
} else {
tokens = &oauth2.Token{AccessToken: sessionData.UpstreamAccessToken}
@@ -270,13 +271,13 @@ func (p *FederationDomainResolvedOIDCIdentityProvider) UpstreamRefresh(
if err != nil {
return nil, resolvedprovider.ErrUpstreamRefreshError().WithHintf(
"Upstream refresh returned an invalid ID token or UserInfo response.").WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", p.Provider.GetName(), p.GetSessionProviderType())
WithDebugf("provider name: %q, provider type: %q", p.Provider.GetResourceName(), p.GetSessionProviderType())
}
mergedClaims := validatedTokens.IDToken.Claims
// To the extent possible, check that the user's basic identity hasn't changed. We check that their downstream
// username has not changed separately below, as part of reapplying the transformations.
err = validateUpstreamSubjectAndIssuerUnchangedSinceInitialLogin(mergedClaims, sessionData, p.Provider.GetName(), p.GetSessionProviderType())
err = validateUpstreamSubjectAndIssuerUnchangedSinceInitialLogin(mergedClaims, sessionData, p.Provider.GetResourceName(), p.GetSessionProviderType())
if err != nil {
return nil, err
}
@@ -292,7 +293,7 @@ func (p *FederationDomainResolvedOIDCIdentityProvider) UpstreamRefresh(
if err != nil {
return nil, resolvedprovider.ErrUpstreamRefreshError().WithHintf(
"Upstream refresh error while extracting groups claim.").WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", p.Provider.GetName(), p.GetSessionProviderType())
WithDebugf("provider name: %q, provider type: %q", p.Provider.GetResourceName(), p.GetSessionProviderType())
}
// It's possible that a username wasn't returned by the upstream provider during refresh,
@@ -312,7 +313,9 @@ func (p *FederationDomainResolvedOIDCIdentityProvider) UpstreamRefresh(
// overwriting the old one.
if tokens.RefreshToken != "" {
plog.Debug("upstream refresh request returned a new refresh token",
"providerName", p.Provider.GetName(), "providerType", p.GetSessionProviderType(), "providerUID", p.Provider.GetResourceUID())
"identityProviderResourceName", p.Provider.GetResourceName(),
"identityProviderType", p.GetSessionProviderType(),
"identityProviderUID", p.Provider.GetResourceUID())
updatedSessionData.UpstreamRefreshToken = tokens.RefreshToken
}
@@ -370,11 +373,11 @@ func makeDownstreamOIDCSessionData(
oidcUpstream upstreamprovider.UpstreamOIDCIdentityProviderI,
token *oidctypes.Token,
) (*psession.OIDCSessionData, []string, error) {
upstreamSubject, err := extractStringClaimValue(oidc.IDTokenClaimSubject, oidcUpstream.GetName(), token.IDToken.Claims)
upstreamSubject, err := extractStringClaimValue(oidc.IDTokenClaimSubject, oidcUpstream.GetResourceName(), token.IDToken.Claims)
if err != nil {
return nil, nil, err
}
upstreamIssuer, err := extractStringClaimValue(oidc.IDTokenClaimIssuer, oidcUpstream.GetName(), token.IDToken.Claims)
upstreamIssuer, err := extractStringClaimValue(oidc.IDTokenClaimIssuer, oidcUpstream.GetResourceName(), token.IDToken.Claims)
if err != nil {
return nil, nil, err
}
@@ -387,7 +390,7 @@ func makeDownstreamOIDCSessionData(
const pleaseCheck = "please check configuration of OIDCIdentityProvider and the client in the " +
"upstream provider's API/UI and try to get a refresh token if possible"
logKV := []interface{}{
"upstreamName", oidcUpstream.GetName(),
"identityProviderResourceName", oidcUpstream.GetResourceName(),
"scopes", oidcUpstream.GetScopes(),
"additionalParams", oidcUpstream.GetAdditionalAuthcodeParams(),
}
@@ -452,7 +455,7 @@ func mapAdditionalClaimsFromUpstreamIDToken(
if !ok {
plog.Warning(
"additionalClaims mapping claim in upstream ID token missing",
"upstreamName", upstreamIDPConfig.GetName(),
"identityProviderResourceName", upstreamIDPConfig.GetResourceName(),
"claimName", upstreamClaimName,
)
} else {
@@ -469,11 +472,11 @@ func getDownstreamSubjectAndUpstreamUsernameFromUpstreamIDToken(
) (string, string, error) {
// The spec says the "sub" claim is only unique per issuer,
// so we will prepend the issuer string to make it globally unique.
upstreamIssuer, err := extractStringClaimValue(oidc.IDTokenClaimIssuer, upstreamIDPConfig.GetName(), idTokenClaims)
upstreamIssuer, err := extractStringClaimValue(oidc.IDTokenClaimIssuer, upstreamIDPConfig.GetResourceName(), idTokenClaims)
if err != nil {
return "", "", err
}
upstreamSubject, err := extractStringClaimValue(oidc.IDTokenClaimSubject, upstreamIDPConfig.GetName(), idTokenClaims)
upstreamSubject, err := extractStringClaimValue(oidc.IDTokenClaimSubject, upstreamIDPConfig.GetResourceName(), idTokenClaims)
if err != nil {
return "", "", err
}
@@ -492,7 +495,7 @@ func getDownstreamSubjectAndUpstreamUsernameFromUpstreamIDToken(
if !ok {
plog.Warning(
"username claim configured as \"email\" and upstream email_verified claim is not a boolean",
"upstreamName", upstreamIDPConfig.GetName(),
"identityProviderResourceName", upstreamIDPConfig.GetResourceName(),
"configuredUsernameClaim", usernameClaimName,
"emailVerifiedClaim", emailVerifiedAsInterface,
)
@@ -501,14 +504,14 @@ func getDownstreamSubjectAndUpstreamUsernameFromUpstreamIDToken(
if !emailVerified {
plog.Warning(
"username claim configured as \"email\" and upstream email_verified claim has false value",
"upstreamName", upstreamIDPConfig.GetName(),
"identityProviderResourceName", upstreamIDPConfig.GetResourceName(),
"configuredUsernameClaim", usernameClaimName,
)
return "", "", emailVerifiedClaimFalseErr
}
}
username, err := extractStringClaimValue(usernameClaimName, upstreamIDPConfig.GetName(), idTokenClaims)
username, err := extractStringClaimValue(usernameClaimName, upstreamIDPConfig.GetResourceName(), idTokenClaims)
if err != nil {
return "", "", err
}
@@ -571,7 +574,7 @@ func getGroupsFromUpstreamIDToken(
if !ok {
plog.Warning(
"no groups claim in upstream ID token",
"upstreamName", upstreamIDPConfig.GetName(),
"identityProviderResourceName", upstreamIDPConfig.GetResourceName(),
"configuredGroupsClaim", groupsClaimName,
)
return nil, nil // the upstream IDP may have omitted the claim if the user has no groups
@@ -581,7 +584,7 @@ func getGroupsFromUpstreamIDToken(
if !okAsArray {
plog.Warning(
"groups claim in upstream ID token has invalid format",
"upstreamName", upstreamIDPConfig.GetName(),
"identityProviderResourceName", upstreamIDPConfig.GetResourceName(),
"configuredGroupsClaim", groupsClaimName,
)
return nil, requiredClaimInvalidFormatErr

View File

@@ -39,11 +39,11 @@ type LDAPRefreshAttributes struct {
// UpstreamIdentityProviderI includes the interface functions that are common to all upstream identity provider types.
// These represent the identity provider resources, i.e. OIDCIdentityProvider, etc.
type UpstreamIdentityProviderI interface {
// GetName returns a name for this upstream provider. The controller watching the identity provider resources will
// GetResourceName returns a name for this upstream provider. The controller watching the identity provider resources will
// set this to be the Name of the CR from its metadata. Note that this is different from the DisplayName configured
// in each FederationDomain that uses this provider, so this name is for internal use only, not for interacting
// with clients. Clients should not expect to see this name or send this name.
GetName() string
GetResourceName() string
// GetResourceUID returns the Kubernetes resource ID
GetResourceUID() types.UID
@@ -134,6 +134,22 @@ type GitHubUser struct {
DownstreamSubject string // the whole downstream subject URI
}
// GitHubLoginDeniedError can be returned by UpstreamGithubIdentityProviderI GetUser() when a policy
// configured on GitHubIdentityProvider should prevent this user from completing authentication.
type GitHubLoginDeniedError struct {
message string
}
func NewGitHubLoginDeniedError(message string) GitHubLoginDeniedError {
return GitHubLoginDeniedError{message: message}
}
func (g GitHubLoginDeniedError) Error() string {
return g.message
}
var _ error = &GitHubLoginDeniedError{}
type UpstreamGithubIdentityProviderI interface {
UpstreamIdentityProviderI

View File

@@ -150,9 +150,9 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetGroupsClaim() *gomoc
}
// GetName mocks base method.
func (m *MockUpstreamOIDCIdentityProviderI) GetName() string {
func (m *MockUpstreamOIDCIdentityProviderI) GetResourceName() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetName")
ret := m.ctrl.Call(m, "GetResourceName")
ret0, _ := ret[0].(string)
return ret0
}
@@ -160,7 +160,7 @@ func (m *MockUpstreamOIDCIdentityProviderI) GetName() string {
// GetName indicates an expected call of GetName.
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetName() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetName", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetName))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResourceName", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetResourceName))
}
// GetResourceUID mocks base method.

View File

@@ -12,6 +12,7 @@
//
// info should be reserved for "nice to know" information. It should be possible to run a production
// pinniped server at the info log level with no performance degradation due to high log volume.
//
// debug should be used for information targeted at developers and to aid in support cases. Care must
// be taken at this level to not leak any secrets into the log stream. That is, even though debug may
// cause performance issues in production, it must not cause security issues in production.

View File

@@ -183,7 +183,7 @@ func (u *TestUpstreamGitHubIdentityProvider) GetResourceUID() types.UID {
return u.ResourceUID
}
func (u *TestUpstreamGitHubIdentityProvider) GetName() string {
func (u *TestUpstreamGitHubIdentityProvider) GetResourceName() string {
return u.Name
}

View File

@@ -117,7 +117,7 @@ func (u *TestUpstreamLDAPIdentityProvider) GetResourceUID() types.UID {
return u.ResourceUID
}
func (u *TestUpstreamLDAPIdentityProvider) GetName() string {
func (u *TestUpstreamLDAPIdentityProvider) GetResourceName() string {
return u.Name
}

View File

@@ -123,7 +123,7 @@ func (u *TestUpstreamOIDCIdentityProvider) GetAdditionalClaimMappings() map[stri
return u.AdditionalClaimMappings
}
func (u *TestUpstreamOIDCIdentityProvider) GetName() string {
func (u *TestUpstreamOIDCIdentityProvider) GetResourceName() string {
return u.Name
}

View File

@@ -6,7 +6,6 @@ package upstreamgithub
import (
"context"
"errors"
"fmt"
"net/http"
@@ -18,6 +17,7 @@ import (
"go.pinniped.dev/internal/federationdomain/downstreamsubject"
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/githubclient"
"go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/setutil"
)
@@ -66,7 +66,7 @@ func New(config ProviderConfig) *Provider {
}
}
func (p *Provider) GetName() string {
func (p *Provider) GetResourceName() string {
return p.c.Name
}
@@ -147,7 +147,20 @@ func (p *Provider) GetUser(ctx context.Context, accessToken string, idpDisplayNa
}
if !p.c.AllowedOrganizations.Empty() && !p.c.AllowedOrganizations.HasAnyIgnoringCase(orgMembership) {
return nil, errors.New("user is not allowed to log in due to organization membership policy")
plog.Warning("user is not allowed to log in due to organization membership policy", // do not log username to avoid PII
"userBelongsToOrganizations", orgMembership,
"configuredAllowedOrganizations", p.c.AllowedOrganizations,
"identityProviderDisplayName", idpDisplayName,
"identityProviderResourceName", p.GetResourceName())
plog.Trace("user is not allowed to log in due to organization membership policy", // okay to log PII at trace level
"githubLogin", userInfo.Login,
"githubID", userInfo.ID,
"calculatedUsername", githubUser.Username,
"userBelongsToOrganizations", orgMembership,
"configuredAllowedOrganizations", p.c.AllowedOrganizations,
"identityProviderDisplayName", idpDisplayName,
"identityProviderResourceName", p.GetResourceName())
return nil, upstreamprovider.NewGitHubLoginDeniedError("user is not allowed to log in due to organization membership policy")
}
teamMembership, err := githubClient.GetTeamMembership(ctx, p.c.AllowedOrganizations)

View File

@@ -74,7 +74,7 @@ func TestGitHubProvider(t *testing.T) {
},
}, subject.GetConfig())
require.Equal(t, "foo", subject.GetName())
require.Equal(t, "foo", subject.GetResourceName())
require.Equal(t, types.UID("resource-uid-12345"), subject.GetResourceUID())
require.Equal(t, "fake-client-id", subject.GetClientID())
require.Equal(t, "fake-client-id", subject.GetClientID())
@@ -205,7 +205,8 @@ func TestGetUser(t *testing.T) {
buildGitHubClientError error
buildMockResponses func(hubInterface *mockgithubclient.MockGitHubInterface)
wantUser *upstreamprovider.GitHubUser
wantErr string
wantErrMsg string
wantErr error
}{
{
name: "happy path with username=login:id",
@@ -303,7 +304,7 @@ func TestGetUser(t *testing.T) {
}, nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return([]string{"disallowed-org"}, nil)
},
wantErr: "user is not allowed to log in due to organization membership policy",
wantErr: upstreamprovider.NewGitHubLoginDeniedError("user is not allowed to log in due to organization membership policy"),
},
{
name: "happy path with groups=name",
@@ -390,7 +391,7 @@ func TestGetUser(t *testing.T) {
HttpClient: someHttpClient,
},
buildGitHubClientError: errors.New("error from building a github client"),
wantErr: "error from building a github client",
wantErrMsg: "error from building a github client",
},
{
name: "returns errors from githubClient.GetUserInfo()",
@@ -401,7 +402,7 @@ func TestGetUser(t *testing.T) {
buildMockResponses: func(mockGitHubInterface *mockgithubclient.MockGitHubInterface) {
mockGitHubInterface.EXPECT().GetUserInfo(someContext).Return(nil, errors.New("error from githubClient.GetUserInfo"))
},
wantErr: "error from githubClient.GetUserInfo",
wantErrMsg: "error from githubClient.GetUserInfo",
},
{
name: "returns errors from githubClient.GetOrgMembership()",
@@ -414,7 +415,7 @@ func TestGetUser(t *testing.T) {
mockGitHubInterface.EXPECT().GetUserInfo(someContext).Return(&githubclient.UserInfo{}, nil)
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, errors.New("error from githubClient.GetOrgMembership"))
},
wantErr: "error from githubClient.GetOrgMembership",
wantErrMsg: "error from githubClient.GetOrgMembership",
},
{
name: "returns errors from githubClient.GetTeamMembership()",
@@ -428,7 +429,7 @@ func TestGetUser(t *testing.T) {
mockGitHubInterface.EXPECT().GetOrgMembership(someContext).Return(nil, nil)
mockGitHubInterface.EXPECT().GetTeamMembership(someContext, gomock.Any()).Return(nil, errors.New("error from githubClient.GetTeamMembership"))
},
wantErr: "error from githubClient.GetTeamMembership",
wantErrMsg: "error from githubClient.GetTeamMembership",
},
{
name: "bad configuration: UsernameAttribute",
@@ -443,7 +444,7 @@ func TestGetUser(t *testing.T) {
ID: "some-github-id",
}, nil)
},
wantErr: "bad configuration: unknown GitHub username attribute: this-is-not-legal-value-from-the-enum",
wantErrMsg: "bad configuration: unknown GitHub username attribute: this-is-not-legal-value-from-the-enum",
},
{
name: "bad configuration: GroupNameAttribute",
@@ -467,7 +468,7 @@ func TestGetUser(t *testing.T) {
},
}, nil)
},
wantErr: "bad configuration: unknown GitHub group name attribute: this-is-not-legal-value-from-the-enum",
wantErrMsg: "bad configuration: unknown GitHub group name attribute: this-is-not-legal-value-from-the-enum",
},
}
for _, test := range tests {
@@ -493,13 +494,18 @@ func TestGetUser(t *testing.T) {
}
actualUser, actualErr := p.GetUser(context.Background(), accessToken, idpDisplayName)
if test.wantErr != "" {
require.EqualError(t, actualErr, test.wantErr)
switch {
case test.wantErrMsg != "":
require.EqualError(t, actualErr, test.wantErrMsg)
require.Nil(t, actualUser)
return
case test.wantErr != nil:
require.Equal(t, test.wantErr, actualErr)
require.Nil(t, actualUser)
default:
require.NoError(t, actualErr)
require.Equal(t, test.wantUser, actualUser)
}
require.NoError(t, actualErr)
require.Equal(t, test.wantUser, actualUser)
})
}
}

View File

@@ -187,7 +187,7 @@ func closeAndLogError(conn Conn, doingWhat string) {
}
func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes upstreamprovider.LDAPRefreshAttributes, idpDisplayName string) ([]string, error) {
t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()})
t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetResourceName()})
defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches
userDN := storedRefreshAttributes.DN
@@ -373,7 +373,7 @@ func (p *Provider) tlsConfig() (*tls.Config, error) {
}
// GetName returns a name for this upstream provider.
func (p *Provider) GetName() string {
func (p *Provider) GetResourceName() string {
return p.c.Name
}
@@ -435,7 +435,7 @@ func (p *Provider) AuthenticateUser(ctx context.Context, username, password stri
}
func (p *Provider) authenticateUserImpl(ctx context.Context, username string, bindFunc func(conn Conn, foundUserDN string) error) (*authenticators.Response, bool, error) {
t := trace.FromContext(ctx).Nest("slow ldap authenticate user attempt", trace.Field{Key: "providerName", Value: p.GetName()})
t := trace.FromContext(ctx).Nest("slow ldap authenticate user attempt", trace.Field{Key: "providerName", Value: p.GetResourceName()})
defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches
err := p.validateConfig()
@@ -528,7 +528,7 @@ func (p *Provider) validateConfig() error {
}
func (p *Provider) SearchForDefaultNamingContext(ctx context.Context) (string, error) {
t := trace.FromContext(ctx).Nest("slow ldap attempt when searching for default naming context", trace.Field{Key: "providerName", Value: p.GetName()})
t := trace.FromContext(ctx).Nest("slow ldap attempt when searching for default naming context", trace.Field{Key: "providerName", Value: p.GetResourceName()})
defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches
conn, err := p.dial(ctx)
@@ -564,7 +564,7 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(c
searchResult, err := conn.Search(p.userSearchRequest(username))
if err != nil {
plog.All(`error searching for user`,
"upstreamName", p.GetName(),
"upstreamName", p.GetResourceName(),
"username", username,
"err", err,
)
@@ -573,11 +573,11 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(c
if len(searchResult.Entries) == 0 {
if plog.Enabled(plog.LevelAll) {
plog.All("error finding user: user not found (if this username is valid, please check the user search configuration)",
"upstreamName", p.GetName(),
"upstreamName", p.GetResourceName(),
"username", username,
)
} else {
plog.Debug("error finding user: user not found (cowardly avoiding printing username because log level is not 'all')", "upstreamName", p.GetName())
plog.Debug("error finding user: user not found (cowardly avoiding printing username because log level is not 'all')", "upstreamName", p.GetResourceName())
}
return nil, nil
}
@@ -632,7 +632,7 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(c
err = bindFunc(conn, userEntry.DN)
if err != nil {
plog.DebugErr("error binding for user (if this is not the expected dn for this username, please check the user search configuration)",
err, "upstreamName", p.GetName(), "username", username, "dn", userEntry.DN)
err, "upstreamName", p.GetResourceName(), "username", username, "dn", userEntry.DN)
ldapErr := &ldap.Error{}
if errors.As(err, &ldapErr) && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials {
return nil, nil

View File

@@ -1,4 +1,4 @@
// Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// Copyright 2020-2024 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package upstreamoidc implements an abstraction of upstream OIDC provider interactions.
@@ -84,7 +84,7 @@ func (p *ProviderConfig) GetAdditionalClaimMappings() map[string]string {
return p.AdditionalClaimMappings
}
func (p *ProviderConfig) GetName() string {
func (p *ProviderConfig) GetResourceName() string {
return p.Name
}

View File

@@ -47,7 +47,7 @@ func TestProviderConfig(t *testing.T) {
rawClaims: []byte(`{"userinfo_endpoint": "https://example.com/userinfo"}`),
},
}
require.Equal(t, "test-name", p.GetName())
require.Equal(t, "test-name", p.GetResourceName())
require.Equal(t, "test-client-id", p.GetClientID())
require.Equal(t, "https://example.com", p.GetAuthorizationURL().String())
require.ElementsMatch(t, []string{"scope1", "scope2"}, p.GetScopes())

View File

@@ -2358,6 +2358,20 @@ func supervisorLoginGithubTestcases(
}
return testlib.CreateTestGitHubIdentityProvider(t, spec, idpv1alpha1.GitHubPhaseReady).Name
},
federationDomainIDPs: func(t *testing.T, idpName string) ([]configv1alpha1.FederationDomainIdentityProvider, string) {
displayName := "some-github-identity-provider-name"
return []configv1alpha1.FederationDomainIdentityProvider{
{
DisplayName: displayName,
ObjectRef: corev1.TypedLocalObjectReference{
APIGroup: ptr.To("idp.supervisor." + env.APIGroupSuffix),
Kind: "GitHubIdentityProvider",
Name: idpName,
},
},
},
displayName
},
requestAuthorization: func(t *testing.T, _, downstreamAuthorizeURL, downstreamCallbackURL, _, _ string, httpClient *http.Client) {
t.Helper()
browser := openBrowserAndNavigateToAuthorizeURL(t, downstreamAuthorizeURL, httpClient)
@@ -2370,7 +2384,7 @@ func supervisorLoginGithubTestcases(
// Get the text of the preformatted error message showing on the page.
textOfPreTag := browser.TextOfFirstMatch(t, "pre")
require.Equal(t,
"Unprocessable Entity: failed to get user info from GitHub API\n",
`Forbidden: login denied due to configuration on GitHubIdentityProvider with display name "some-github-identity-provider-name": user is not allowed to log in due to organization membership policy`+"\n",
textOfPreTag)
},
wantLocalhostCallbackToNeverHappen: true,