diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 3598c35c2..6a6e6431e 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -258,7 +258,7 @@ func getGroupsFromUpstreamIDToken( ) ([]string, error) { groupsClaim := upstreamIDPConfig.GetGroupsClaim() if groupsClaim == "" { - return []string{}, nil + return nil, nil } groupsAsInterface, ok := idTokenClaims[groupsClaim] @@ -269,7 +269,7 @@ func getGroupsFromUpstreamIDToken( "configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(), "groupsClaim", groupsClaim, ) - return []string{}, nil // the upstream IDP may have omitted the claim if the user has no groups + return nil, nil // the upstream IDP may have omitted the claim if the user has no groups } groupsAsArray, okAsArray := extractGroups(groupsAsInterface) @@ -302,13 +302,15 @@ func extractGroups(groupsAsInterface interface{}) ([]string, bool) { return nil, false } - groupsAsStrings := make([]string, len(groupsAsInterfaceArray)) - for i, groupAsInterface := range groupsAsInterfaceArray { + var groupsAsStrings []string + for _, groupAsInterface := range groupsAsInterfaceArray { groupAsString, okAsString := groupAsInterface.(string) if !okAsString { return nil, false } - groupsAsStrings[i] = groupAsString + if groupAsString != "" { + groupsAsStrings = append(groupsAsStrings, groupAsString) + } } return groupsAsStrings, true @@ -323,6 +325,9 @@ func makeDownstreamSession(subject string, username string, groups []string) *op AuthTime: now, }, } + if groups == nil { + groups = []string{} + } openIDSession.Claims.Extra = map[string]interface{}{ oidc.DownstreamUsernameClaim: username, oidc.DownstreamGroupsClaim: groups, diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 335ca5ae9..43eccb155 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -102,12 +102,12 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e if err := validated.Claims(&validatedClaims); err != nil { return nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal id token claims", err) } - plog.All("claims from ID token", "providerName", p.Name, "claims", listClaims(validatedClaims)) + plog.All("claims from ID token", "providerName", p.Name, "claims", validatedClaims) if err := p.fetchUserInfo(ctx, tok, validatedClaims); err != nil { return nil, httperr.Wrap(http.StatusInternalServerError, "could not fetch user info claims", err) } - plog.All("claims from ID token and userinfo", "providerName", p.Name, "claims", listClaims(validatedClaims)) + plog.All("claims from ID token and userinfo", "providerName", p.Name, "claims", validatedClaims) return &oidctypes.Token{ AccessToken: &oidctypes.AccessToken{ @@ -162,13 +162,3 @@ func (p *ProviderConfig) fetchUserInfo(ctx context.Context, tok *oauth2.Token, c return nil } - -func listClaims(claims map[string]interface{}) []string { - list := make([]string, len(claims)) - i := 0 - for claim := range claims { - list[i] = claim - i++ - } - return list -}