diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 6aae7cc9c..be568731b 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -354,12 +354,10 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er func (h *handlerState) needRFC8693TokenExchange(token *oidctypes.Token) bool { // Need a new ID token if there is a requested audience value and any of the following are true... return h.requestedAudience != "" && - // we don't have an ID token + // we don't have an ID token (maybe it expired or was otherwise removed from the session cache) (token.IDToken == nil || - // or, our current ID token has expired or is close to expiring - idTokenExpiredOrCloseToExpiring(token.IDToken) || // or, our current ID token has a different audience - (h.requestedAudience != token.IDToken.Claims["aud"])) + h.requestedAudience != token.IDToken.Claims["aud"]) } func (h *handlerState) tokenValidForNearFuture(token *oidctypes.Token) (bool, string) { diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index f08298f03..a908c94c3 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -75,6 +75,8 @@ func newClientForServer(server *httptest.Server) *http.Client { } func TestLogin(t *testing.T) { //nolint:gocyclo + fakeUniqueTime := time.Now().Add(6 * time.Minute).Add(6 * time.Second) + distantFutureTime := time.Date(2065, 10, 12, 13, 14, 15, 16, time.UTC) testCodeChallenge := testutil.SHA256("test-pkce") @@ -1985,7 +1987,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo }, }, { - name: "with requested audience, session cache hit with valid access token, ID token already has the requested audience, but ID token is expired", + name: "with requested audience, session cache hit with valid access token, ID token already has the requested audience, but ID token is expired, causes a refresh and uses refreshed ID token", issuer: successServer.URL, clientID: "test-client-id", opt: func(t *testing.T) Option { @@ -1995,7 +1997,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo IDToken: &oidctypes.IDToken{ Token: testToken.IDToken.Token, Expiry: metav1.NewTime(time.Now().Add(9 * time.Minute)), // less than Now() + minIDTokenValidity - Claims: map[string]interface{}{"aud": "request-this-test-audience"}, + Claims: map[string]interface{}{"aud": "test-custom-request-audience"}, }, RefreshToken: testToken.RefreshToken, }} @@ -2006,26 +2008,56 @@ func TestLogin(t *testing.T) { //nolint:gocyclo Scopes: []string{"test-scope"}, RedirectURI: "http://localhost:0/callback", }}, cache.sawGetKeys) - require.Empty(t, cache.sawPutTokens) + require.Len(t, cache.sawPutTokens, 1) + // want to have cached the refreshed ID token + require.Equal(t, &oidctypes.IDToken{ + Token: testToken.IDToken.Token, + Expiry: metav1.NewTime(fakeUniqueTime), + Claims: map[string]interface{}{"aud": "test-custom-request-audience"}, + }, cache.sawPutTokens[0].IDToken) }) require.NoError(t, WithClient(newClientForServer(successServer))(h)) require.NoError(t, WithSessionCache(cache)(h)) - require.NoError(t, WithRequestAudience("request-this-test-audience")(h)) + require.NoError(t, WithRequestAudience("test-custom-request-audience")(h)) - h.validateIDToken = func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) { - require.Equal(t, "request-this-test-audience", audience) - require.Equal(t, "test-id-token-with-requested-audience", token) - return &oidc.IDToken{Expiry: testExchangedToken.IDToken.Expiry.Time}, nil + h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false). + Return(&oidctypes.Token{ + AccessToken: testToken.AccessToken, + IDToken: &oidctypes.IDToken{ + Token: testToken.IDToken.Token, + Expiry: metav1.NewTime(fakeUniqueTime), // less than Now() + minIDTokenValidity but does not matter because this is a freshly refreshed ID token + Claims: map[string]interface{}{"aud": "test-custom-request-audience"}, + }, + RefreshToken: testToken.RefreshToken, + }, nil) + mock.EXPECT(). + PerformRefresh(gomock.Any(), testToken.RefreshToken.Token). + DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + // Call the real production code to perform a refresh. + return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken) + }) + return mock } return nil } }, wantLogs: []string{ - `"level"=4 "msg"="Pinniped: Found unexpired cached token." "type"="access_token"`, - `"level"=4 "msg"="Pinniped: Performing RFC8693 token exchange" "requestedAudience"="request-this-test-audience"`, `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`, + `"level"=4 "msg"="Pinniped: Refreshing cached tokens."`, + }, + // want to have returned the refreshed tokens + wantToken: &oidctypes.Token{ + AccessToken: testToken.AccessToken, + IDToken: &oidctypes.IDToken{ + Token: testToken.IDToken.Token, + Expiry: metav1.NewTime(fakeUniqueTime), + Claims: map[string]interface{}{"aud": "test-custom-request-audience"}, + }, + RefreshToken: testToken.RefreshToken, }, - wantToken: &testExchangedToken, }, { name: "with requested audience, session cache hit with valid access token, but no ID token",