Addressing PR feedback

store issuer and subject in storage for refresh
Clean up some constants

Signed-off-by: Margo Crawford <margaretc@vmware.com>
This commit is contained in:
Margo Crawford
2022-01-07 15:04:58 -08:00
parent f2d2144932
commit 2958461970
21 changed files with 178 additions and 191 deletions

View File

@@ -1,4 +1,4 @@
// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved.
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package upstreamoidc implements an abstraction of upstream OIDC provider interactions.
@@ -7,7 +7,6 @@ package upstreamoidc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -238,53 +237,19 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
}
}
func ExtractUpstreamSubjectAndIssuerFromDownstream(downstreamSubject string) (string, string, error) {
if !strings.Contains(downstreamSubject, "?sub=") {
return "", "", errors.New("downstream subject did not contain original upstream subject")
}
split := strings.SplitN(downstreamSubject, "?sub=", 2)
iss := split[0]
sub := split[1]
if iss == "" || sub == "" {
return "", "", errors.New("downstream subject was malformed")
}
return split[0], split[1], nil
}
// ValidateTokenAndMergeWithUserInfo will validate the ID token. It will also merge the claims from the userinfo endpoint response,
// if the provider offers the userinfo endpoint.
func (p *ProviderConfig) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) {
var validatedClaims = make(map[string]interface{})
idTok, hasIDTok := tok.Extra("id_token").(string)
var idTokenExpiry time.Time
// if we require the id token, make sure we have it.
// also, if it exists but wasn't required, still make sure it passes these checks.
// nolint:nestif
if hasIDTok || requireIDToken {
if !hasIDTok {
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
}
validated, err := p.Provider.Verifier(&coreosoidc.Config{ClientID: p.GetClientID()}).Verify(coreosoidc.ClientContext(ctx, p.Client), idTok)
if err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
if validated.AccessTokenHash != "" {
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
}
if expectedIDTokenNonce != "" {
if err := expectedIDTokenNonce.Validate(validated); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
}
}
if err := validated.Claims(&validatedClaims); err != nil {
return nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal id token claims", err)
}
maybeLogClaims("claims from ID token", p.Name, validatedClaims)
idTokenExpiry = validated.Expiry // keep track of the id token expiry if we have an id token. Otherwise, it'll just be the zero value.
idTokenExpiry, idTok, err := p.validateIDToken(ctx, tok, expectedIDTokenNonce, validatedClaims, requireIDToken)
if err != nil {
return nil, err
}
idTokenSubject, _ := validatedClaims[oidc.IDTokenSubjectClaim].(string)
if len(idTokenSubject) > 0 || !requireIDToken {
@@ -310,10 +275,42 @@ func (p *ProviderConfig) ValidateTokenAndMergeWithUserInfo(ctx context.Context,
}, nil
}
func (p *ProviderConfig) validateIDToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, validatedClaims map[string]interface{}, requireIDToken bool) (time.Time, string, error) {
idTok, hasIDTok := tok.Extra("id_token").(string)
if !hasIDTok && !requireIDToken {
return time.Time{}, "", nil // exit early
}
var idTokenExpiry time.Time
if !hasIDTok {
return time.Time{}, "", httperr.New(http.StatusBadRequest, "received response missing ID token")
}
validated, err := p.Provider.Verifier(&coreosoidc.Config{ClientID: p.GetClientID()}).Verify(coreosoidc.ClientContext(ctx, p.Client), idTok)
if err != nil {
return time.Time{}, "", httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
if validated.AccessTokenHash != "" {
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
return time.Time{}, "", httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
}
if expectedIDTokenNonce != "" {
if err := expectedIDTokenNonce.Validate(validated); err != nil {
return time.Time{}, "", httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
}
}
if err := validated.Claims(&validatedClaims); err != nil {
return time.Time{}, "", httperr.Wrap(http.StatusInternalServerError, "could not unmarshal id token claims", err)
}
maybeLogClaims("claims from ID token", p.Name, validatedClaims)
idTokenExpiry = validated.Expiry // keep track of the id token expiry if we have an id token. Otherwise, it'll just be the zero value.
return idTokenExpiry, idTok, nil
}
func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}, requireIDToken bool) error {
idTokenSubject, _ := claims[oidc.IDTokenSubjectClaim].(string)
userInfo, err := p.fetchUserInfo(ctx, tok)
userInfo, err := p.maybeFetchUserInfo(ctx, tok)
if err != nil {
return err
}
@@ -356,7 +353,7 @@ func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, t
return nil
}
func (p *ProviderConfig) fetchUserInfo(ctx context.Context, tok *oauth2.Token) (*coreosoidc.UserInfo, error) {
func (p *ProviderConfig) maybeFetchUserInfo(ctx context.Context, tok *oauth2.Token) (*coreosoidc.UserInfo, error) {
providerJSON := &struct {
UserInfoURL string `json:"userinfo_endpoint"`
}{}

View File

@@ -910,63 +910,6 @@ func TestProviderConfig(t *testing.T) {
}
})
t.Run("ExtractUpstreamSubjectAndIssuerFromDownstream", func(t *testing.T) {
tests := []struct {
name string
downstreamSubject string
wantUpstreamSubject string
wantUpstreamIssuer string
wantErr string
}{
{
name: "happy path",
downstreamSubject: "https://some-issuer?sub=some-subject",
wantUpstreamSubject: "some-subject",
wantUpstreamIssuer: "https://some-issuer",
},
{
name: "subject in a subject",
downstreamSubject: "https://some-other-issuer?sub=https://some-issuer?sub=some-subject",
wantUpstreamSubject: "https://some-issuer?sub=some-subject",
wantUpstreamIssuer: "https://some-other-issuer",
},
{
name: "sub is empty string",
downstreamSubject: "https://some-issuer?sub=",
wantErr: "downstream subject was malformed",
},
{
name: "iss is empty string",
downstreamSubject: "?sub=some-subject",
wantErr: "downstream subject was malformed",
},
{
name: "empty string",
downstreamSubject: "",
wantErr: "downstream subject did not contain original upstream subject",
},
{
name: "doesn't contain sub=",
downstreamSubject: "something-invalid",
wantErr: "downstream subject did not contain original upstream subject",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
actualUpstreamIssuer, actualUpstreamSubject, err := ExtractUpstreamSubjectAndIssuerFromDownstream(tt.downstreamSubject)
if tt.wantErr != "" {
require.Error(t, err)
require.Equal(t, tt.wantErr, err.Error())
} else {
require.NoError(t, err)
require.Equal(t, tt.wantUpstreamSubject, actualUpstreamSubject)
require.Equal(t, tt.wantUpstreamIssuer, actualUpstreamIssuer)
}
})
}
})
t.Run("ExchangeAuthcodeAndValidateTokens", func(t *testing.T) {
tests := []struct {
name string