mirror of
https://github.com/vmware-tanzu/pinniped.git
synced 2026-01-06 13:36:54 +00:00
Addressing PR feedback
store issuer and subject in storage for refresh Clean up some constants Signed-off-by: Margo Crawford <margaretc@vmware.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved.
|
||||
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package upstreamoidc implements an abstraction of upstream OIDC provider interactions.
|
||||
@@ -7,7 +7,6 @@ package upstreamoidc
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -238,53 +237,19 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
|
||||
}
|
||||
}
|
||||
|
||||
func ExtractUpstreamSubjectAndIssuerFromDownstream(downstreamSubject string) (string, string, error) {
|
||||
if !strings.Contains(downstreamSubject, "?sub=") {
|
||||
return "", "", errors.New("downstream subject did not contain original upstream subject")
|
||||
}
|
||||
split := strings.SplitN(downstreamSubject, "?sub=", 2)
|
||||
iss := split[0]
|
||||
sub := split[1]
|
||||
if iss == "" || sub == "" {
|
||||
return "", "", errors.New("downstream subject was malformed")
|
||||
}
|
||||
return split[0], split[1], nil
|
||||
}
|
||||
|
||||
// ValidateTokenAndMergeWithUserInfo will validate the ID token. It will also merge the claims from the userinfo endpoint response,
|
||||
// if the provider offers the userinfo endpoint.
|
||||
func (p *ProviderConfig) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) {
|
||||
var validatedClaims = make(map[string]interface{})
|
||||
|
||||
idTok, hasIDTok := tok.Extra("id_token").(string)
|
||||
var idTokenExpiry time.Time
|
||||
// if we require the id token, make sure we have it.
|
||||
// also, if it exists but wasn't required, still make sure it passes these checks.
|
||||
// nolint:nestif
|
||||
if hasIDTok || requireIDToken {
|
||||
if !hasIDTok {
|
||||
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
|
||||
}
|
||||
validated, err := p.Provider.Verifier(&coreosoidc.Config{ClientID: p.GetClientID()}).Verify(coreosoidc.ClientContext(ctx, p.Client), idTok)
|
||||
if err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
if validated.AccessTokenHash != "" {
|
||||
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
}
|
||||
if expectedIDTokenNonce != "" {
|
||||
if err := expectedIDTokenNonce.Validate(validated); err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
|
||||
}
|
||||
}
|
||||
if err := validated.Claims(&validatedClaims); err != nil {
|
||||
return nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal id token claims", err)
|
||||
}
|
||||
maybeLogClaims("claims from ID token", p.Name, validatedClaims)
|
||||
idTokenExpiry = validated.Expiry // keep track of the id token expiry if we have an id token. Otherwise, it'll just be the zero value.
|
||||
idTokenExpiry, idTok, err := p.validateIDToken(ctx, tok, expectedIDTokenNonce, validatedClaims, requireIDToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idTokenSubject, _ := validatedClaims[oidc.IDTokenSubjectClaim].(string)
|
||||
|
||||
if len(idTokenSubject) > 0 || !requireIDToken {
|
||||
@@ -310,10 +275,42 @@ func (p *ProviderConfig) ValidateTokenAndMergeWithUserInfo(ctx context.Context,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) validateIDToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, validatedClaims map[string]interface{}, requireIDToken bool) (time.Time, string, error) {
|
||||
idTok, hasIDTok := tok.Extra("id_token").(string)
|
||||
if !hasIDTok && !requireIDToken {
|
||||
return time.Time{}, "", nil // exit early
|
||||
}
|
||||
|
||||
var idTokenExpiry time.Time
|
||||
if !hasIDTok {
|
||||
return time.Time{}, "", httperr.New(http.StatusBadRequest, "received response missing ID token")
|
||||
}
|
||||
validated, err := p.Provider.Verifier(&coreosoidc.Config{ClientID: p.GetClientID()}).Verify(coreosoidc.ClientContext(ctx, p.Client), idTok)
|
||||
if err != nil {
|
||||
return time.Time{}, "", httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
if validated.AccessTokenHash != "" {
|
||||
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
|
||||
return time.Time{}, "", httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
}
|
||||
if expectedIDTokenNonce != "" {
|
||||
if err := expectedIDTokenNonce.Validate(validated); err != nil {
|
||||
return time.Time{}, "", httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
|
||||
}
|
||||
}
|
||||
if err := validated.Claims(&validatedClaims); err != nil {
|
||||
return time.Time{}, "", httperr.Wrap(http.StatusInternalServerError, "could not unmarshal id token claims", err)
|
||||
}
|
||||
maybeLogClaims("claims from ID token", p.Name, validatedClaims)
|
||||
idTokenExpiry = validated.Expiry // keep track of the id token expiry if we have an id token. Otherwise, it'll just be the zero value.
|
||||
return idTokenExpiry, idTok, nil
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}, requireIDToken bool) error {
|
||||
idTokenSubject, _ := claims[oidc.IDTokenSubjectClaim].(string)
|
||||
|
||||
userInfo, err := p.fetchUserInfo(ctx, tok)
|
||||
userInfo, err := p.maybeFetchUserInfo(ctx, tok)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -356,7 +353,7 @@ func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, t
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) fetchUserInfo(ctx context.Context, tok *oauth2.Token) (*coreosoidc.UserInfo, error) {
|
||||
func (p *ProviderConfig) maybeFetchUserInfo(ctx context.Context, tok *oauth2.Token) (*coreosoidc.UserInfo, error) {
|
||||
providerJSON := &struct {
|
||||
UserInfoURL string `json:"userinfo_endpoint"`
|
||||
}{}
|
||||
|
||||
@@ -910,63 +910,6 @@ func TestProviderConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ExtractUpstreamSubjectAndIssuerFromDownstream", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
downstreamSubject string
|
||||
wantUpstreamSubject string
|
||||
wantUpstreamIssuer string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
downstreamSubject: "https://some-issuer?sub=some-subject",
|
||||
wantUpstreamSubject: "some-subject",
|
||||
wantUpstreamIssuer: "https://some-issuer",
|
||||
},
|
||||
{
|
||||
name: "subject in a subject",
|
||||
downstreamSubject: "https://some-other-issuer?sub=https://some-issuer?sub=some-subject",
|
||||
wantUpstreamSubject: "https://some-issuer?sub=some-subject",
|
||||
wantUpstreamIssuer: "https://some-other-issuer",
|
||||
},
|
||||
{
|
||||
name: "sub is empty string",
|
||||
downstreamSubject: "https://some-issuer?sub=",
|
||||
wantErr: "downstream subject was malformed",
|
||||
},
|
||||
{
|
||||
name: "iss is empty string",
|
||||
downstreamSubject: "?sub=some-subject",
|
||||
wantErr: "downstream subject was malformed",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
downstreamSubject: "",
|
||||
wantErr: "downstream subject did not contain original upstream subject",
|
||||
},
|
||||
{
|
||||
name: "doesn't contain sub=",
|
||||
downstreamSubject: "something-invalid",
|
||||
wantErr: "downstream subject did not contain original upstream subject",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actualUpstreamIssuer, actualUpstreamSubject, err := ExtractUpstreamSubjectAndIssuerFromDownstream(tt.downstreamSubject)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
require.Equal(t, tt.wantErr, err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantUpstreamSubject, actualUpstreamSubject)
|
||||
require.Equal(t, tt.wantUpstreamIssuer, actualUpstreamIssuer)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ExchangeAuthcodeAndValidateTokens", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user