diff --git a/internal/controller/supervisorstorage/garbage_collector.go b/internal/controller/supervisorstorage/garbage_collector.go index 2ce06533d..5f388b804 100644 --- a/internal/controller/supervisorstorage/garbage_collector.go +++ b/internal/controller/supervisorstorage/garbage_collector.go @@ -243,7 +243,7 @@ func (c *garbageCollectorController) revokeUpstreamOIDCRefreshToken(ctx context. } // Revoke the upstream refresh token. This is a noop if the upstream provider does not offer a revocation endpoint. - err := foundOIDCIdentityProviderI.RevokeRefreshToken(ctx, customSessionData.OIDC.UpstreamRefreshToken) + err := foundOIDCIdentityProviderI.RevokeToken(ctx, customSessionData.OIDC.UpstreamRefreshToken, provider.RefreshTokenType) if err != nil { // This could be a network failure, a 503 result which we should retry // (see https://datatracker.ietf.org/doc/html/rfc7009#section-2.2.1), diff --git a/internal/controller/supervisorstorage/garbage_collector_test.go b/internal/controller/supervisorstorage/garbage_collector_test.go index 042c4d38f..d442c4749 100644 --- a/internal/controller/supervisorstorage/garbage_collector_test.go +++ b/internal/controller/supervisorstorage/garbage_collector_test.go @@ -366,18 +366,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) { happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName("upstream-oidc-provider-name"). WithResourceUID("upstream-oidc-provider-uid"). - WithRevokeRefreshTokenError(nil) + WithRevokeTokenError(nil) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) startInformersAndController(idpListerBuilder.Build()) r.NoError(controllerlib.TestSync(t, subject, *syncContext)) // The upstream refresh token is only revoked for the active authcode session. - idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t, + idpListerBuilder.RequireExactlyOneCallToRevokeToken(t, "upstream-oidc-provider-name", - &oidctestutil.RevokeRefreshTokenArgs{ - Ctx: syncContext.Context, - RefreshToken: "fake-upstream-refresh-token", + &oidctestutil.RevokeTokenArgs{ + Ctx: syncContext.Context, + Token: "fake-upstream-refresh-token", + TokenType: provider.RefreshTokenType, }, ) @@ -448,14 +449,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) { happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName("upstream-oidc-provider-name"). WithResourceUID("upstream-oidc-provider-uid"). - WithRevokeRefreshTokenError(nil) + WithRevokeTokenError(nil) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) startInformersAndController(idpListerBuilder.Build()) r.NoError(controllerlib.TestSync(t, subject, *syncContext)) // Nothing to revoke since we couldn't read the invalid secret. - idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t) + idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t) // The invalid authcode session secrets is still deleted because it is expired. r.ElementsMatch( @@ -524,14 +525,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) { happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName("upstream-oidc-provider-name"). WithResourceUID("upstream-oidc-provider-uid"). - WithRevokeRefreshTokenError(nil) + WithRevokeTokenError(nil) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) startInformersAndController(idpListerBuilder.Build()) r.NoError(controllerlib.TestSync(t, subject, *syncContext)) // Nothing to revoke since we couldn't find the upstream in the cache. - idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t) + idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t) // The authcode session secrets is still deleted because it is expired. r.ElementsMatch( @@ -600,14 +601,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) { happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName("upstream-oidc-provider-name"). WithResourceUID("upstream-oidc-provider-uid"). - WithRevokeRefreshTokenError(nil) + WithRevokeTokenError(nil) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) startInformersAndController(idpListerBuilder.Build()) r.NoError(controllerlib.TestSync(t, subject, *syncContext)) // Nothing to revoke since we couldn't find the upstream in the cache. - idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t) + idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t) // The authcode session secrets is still deleted because it is expired. r.ElementsMatch( @@ -677,18 +678,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) { happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName("upstream-oidc-provider-name"). WithResourceUID("upstream-oidc-provider-uid"). - WithRevokeRefreshTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail + WithRevokeTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) startInformersAndController(idpListerBuilder.Build()) r.NoError(controllerlib.TestSync(t, subject, *syncContext)) // Tried to revoke it, although this revocation will fail. - idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t, + idpListerBuilder.RequireExactlyOneCallToRevokeToken(t, "upstream-oidc-provider-name", - &oidctestutil.RevokeRefreshTokenArgs{ - Ctx: syncContext.Context, - RefreshToken: "fake-upstream-refresh-token", + &oidctestutil.RevokeTokenArgs{ + Ctx: syncContext.Context, + Token: "fake-upstream-refresh-token", + TokenType: provider.RefreshTokenType, }, ) @@ -749,18 +751,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) { happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName("upstream-oidc-provider-name"). WithResourceUID("upstream-oidc-provider-uid"). - WithRevokeRefreshTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail + WithRevokeTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) startInformersAndController(idpListerBuilder.Build()) r.NoError(controllerlib.TestSync(t, subject, *syncContext)) // Tried to revoke it, although this revocation will fail. - idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t, + idpListerBuilder.RequireExactlyOneCallToRevokeToken(t, "upstream-oidc-provider-name", - &oidctestutil.RevokeRefreshTokenArgs{ - Ctx: syncContext.Context, - RefreshToken: "fake-upstream-refresh-token", + &oidctestutil.RevokeTokenArgs{ + Ctx: syncContext.Context, + Token: "fake-upstream-refresh-token", + TokenType: provider.RefreshTokenType, }, ) @@ -875,18 +878,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) { happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName("upstream-oidc-provider-name"). WithResourceUID("upstream-oidc-provider-uid"). - WithRevokeRefreshTokenError(nil) + WithRevokeTokenError(nil) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) startInformersAndController(idpListerBuilder.Build()) r.NoError(controllerlib.TestSync(t, subject, *syncContext)) // The upstream refresh token is only revoked for the downstream session which had offline_access granted. - idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t, + idpListerBuilder.RequireExactlyOneCallToRevokeToken(t, "upstream-oidc-provider-name", - &oidctestutil.RevokeRefreshTokenArgs{ - Ctx: syncContext.Context, - RefreshToken: "fake-upstream-refresh-token", + &oidctestutil.RevokeTokenArgs{ + Ctx: syncContext.Context, + Token: "fake-upstream-refresh-token", + TokenType: provider.RefreshTokenType, }, ) @@ -958,18 +962,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) { happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName("upstream-oidc-provider-name"). WithResourceUID("upstream-oidc-provider-uid"). - WithRevokeRefreshTokenError(nil) + WithRevokeTokenError(nil) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) startInformersAndController(idpListerBuilder.Build()) r.NoError(controllerlib.TestSync(t, subject, *syncContext)) // The upstream refresh token is revoked. - idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t, + idpListerBuilder.RequireExactlyOneCallToRevokeToken(t, "upstream-oidc-provider-name", - &oidctestutil.RevokeRefreshTokenArgs{ - Ctx: syncContext.Context, - RefreshToken: "fake-upstream-refresh-token", + &oidctestutil.RevokeTokenArgs{ + Ctx: syncContext.Context, + Token: "fake-upstream-refresh-token", + TokenType: provider.RefreshTokenType, }, ) @@ -1015,7 +1020,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) { r.False(syncContext.Queue.(*testQueue).called) // Run sync again when not enough time has passed since the most recent run, so no delete - // operations should happen even though there is a expired secret now. + // operations should happen even though there is an expired secret now. fakeClock.Step(29 * time.Second) r.NoError(controllerlib.TestSync(t, subject, *syncContext)) require.Empty(t, kubeClient.Actions()) diff --git a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go index 046f18494..2d0dbed15 100644 --- a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go +++ b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go @@ -14,6 +14,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + provider "go.pinniped.dev/internal/oidc/provider" nonce "go.pinniped.dev/pkg/oidcclient/nonce" oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes" pkce "go.pinniped.dev/pkg/oidcclient/pkce" @@ -215,18 +216,18 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PerformRefresh(arg0, ar return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PerformRefresh", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PerformRefresh), arg0, arg1) } -// RevokeRefreshToken mocks base method. -func (m *MockUpstreamOIDCIdentityProviderI) RevokeRefreshToken(arg0 context.Context, arg1 string) error { +// RevokeToken mocks base method. +func (m *MockUpstreamOIDCIdentityProviderI) RevokeToken(arg0 context.Context, arg1 string, arg2 provider.RevocableTokenType) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RevokeRefreshToken", arg0, arg1) + ret := m.ctrl.Call(m, "RevokeToken", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } -// RevokeRefreshToken indicates an expected call of RevokeRefreshToken. -func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) RevokeRefreshToken(arg0, arg1 interface{}) *gomock.Call { +// RevokeToken indicates an expected call of RevokeToken. +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) RevokeToken(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeRefreshToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).RevokeRefreshToken), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).RevokeToken), arg0, arg1, arg2) } // ValidateToken mocks base method. diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index a054a4c84..e5be51ff8 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -17,6 +17,14 @@ import ( "go.pinniped.dev/pkg/oidcclient/pkce" ) +type RevocableTokenType string + +// These strings correspond to the token types defined by https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 +const ( + RefreshTokenType RevocableTokenType = "refresh_token" + AccessTokenType RevocableTokenType = "access_token" +) + type UpstreamOIDCIdentityProviderI interface { // GetName returns a name for this upstream provider, which will be used as a component of the path for the // callback endpoint hosted by the Supervisor. @@ -68,8 +76,8 @@ type UpstreamOIDCIdentityProviderI interface { // validate the ID token. PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) - // RevokeRefreshToken will attempt to revoke the given token, if the provider has a revocation endpoint. - RevokeRefreshToken(ctx context.Context, refreshToken string) error + // RevokeToken will attempt to revoke the given token, if the provider has a revocation endpoint. + RevokeToken(ctx context.Context, token string, tokenType RevocableTokenType) error // ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response // into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go index 03603bdbb..6d0cabcab 100644 --- a/internal/testutil/oidctestutil/oidctestutil.go +++ b/internal/testutil/oidctestutil/oidctestutil.go @@ -68,11 +68,12 @@ type PerformRefreshArgs struct { ExpectedSubject string } -// RevokeRefreshTokenArgs is used to spy on calls to -// TestUpstreamOIDCIdentityProvider.RevokeRefreshTokenArgsFunc(). -type RevokeRefreshTokenArgs struct { - Ctx context.Context - RefreshToken string +// RevokeTokenArgs is used to spy on calls to +// TestUpstreamOIDCIdentityProvider.RevokeTokenArgsFunc(). +type RevokeTokenArgs struct { + Ctx context.Context + Token string + TokenType provider.RevocableTokenType } // ValidateTokenArgs is used to spy on calls to @@ -166,7 +167,7 @@ type TestUpstreamOIDCIdentityProvider struct { PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error) - RevokeRefreshTokenFunc func(ctx context.Context, refreshToken string) error + RevokeTokenFunc func(ctx context.Context, refreshToken string, tokenType provider.RevocableTokenType) error ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) @@ -176,8 +177,8 @@ type TestUpstreamOIDCIdentityProvider struct { passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs performRefreshCallCount int performRefreshArgs []*PerformRefreshArgs - revokeRefreshTokenCallCount int - revokeRefreshTokenArgs []*RevokeRefreshTokenArgs + revokeTokenCallCount int + revokeTokenArgs []*RevokeTokenArgs validateTokenCallCount int validateTokenArgs []*ValidateTokenArgs } @@ -278,16 +279,17 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r return u.PerformRefreshFunc(ctx, refreshToken) } -func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshToken(ctx context.Context, refreshToken string) error { - if u.revokeRefreshTokenArgs == nil { - u.revokeRefreshTokenArgs = make([]*RevokeRefreshTokenArgs, 0) +func (u *TestUpstreamOIDCIdentityProvider) RevokeToken(ctx context.Context, token string, tokenType provider.RevocableTokenType) error { + if u.revokeTokenArgs == nil { + u.revokeTokenArgs = make([]*RevokeTokenArgs, 0) } - u.revokeRefreshTokenCallCount++ - u.revokeRefreshTokenArgs = append(u.revokeRefreshTokenArgs, &RevokeRefreshTokenArgs{ - Ctx: ctx, - RefreshToken: refreshToken, + u.revokeTokenCallCount++ + u.revokeTokenArgs = append(u.revokeTokenArgs, &RevokeTokenArgs{ + Ctx: ctx, + Token: token, + TokenType: tokenType, }) - return u.RevokeRefreshTokenFunc(ctx, refreshToken) + return u.RevokeTokenFunc(ctx, token, tokenType) } func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshCallCount() int { @@ -301,15 +303,15 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshArgs(call int) *Perform return u.performRefreshArgs[call] } -func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshTokenCallCount() int { +func (u *TestUpstreamOIDCIdentityProvider) RevokeTokenCallCount() int { return u.performRefreshCallCount } -func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshTokenArgs(call int) *RevokeRefreshTokenArgs { - if u.revokeRefreshTokenArgs == nil { - u.revokeRefreshTokenArgs = make([]*RevokeRefreshTokenArgs, 0) +func (u *TestUpstreamOIDCIdentityProvider) RevokeTokenArgs(call int) *RevokeTokenArgs { + if u.revokeTokenArgs == nil { + u.revokeTokenArgs = make([]*RevokeTokenArgs, 0) } - return u.revokeRefreshTokenArgs[call] + return u.revokeTokenArgs[call] } func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { @@ -552,40 +554,40 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *tes ) } -func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToRevokeRefreshToken( +func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToRevokeToken( t *testing.T, expectedPerformedByUpstreamName string, - expectedArgs *RevokeRefreshTokenArgs, + expectedArgs *RevokeTokenArgs, ) { t.Helper() - var actualArgs *RevokeRefreshTokenArgs + var actualArgs *RevokeTokenArgs var actualNameOfUpstreamWhichMadeCall string actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { - callCountOnThisUpstream := upstreamOIDC.revokeRefreshTokenCallCount + callCountOnThisUpstream := upstreamOIDC.revokeTokenCallCount actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream if callCountOnThisUpstream == 1 { actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name - actualArgs = upstreamOIDC.revokeRefreshTokenArgs[0] + actualArgs = upstreamOIDC.revokeTokenArgs[0] } } require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, - "should have been exactly one call to RevokeRefreshToken() by all OIDC upstreams", + "should have been exactly one call to RevokeToken() by all OIDC upstreams", ) require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, - "RevokeRefreshToken() was called on the wrong OIDC upstream", + "RevokeToken() was called on the wrong OIDC upstream", ) require.Equal(t, expectedArgs, actualArgs) } -func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToRevokeRefreshToken(t *testing.T) { +func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToRevokeToken(t *testing.T) { t.Helper() actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { - actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.revokeRefreshTokenCallCount + actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.revokeTokenCallCount } require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, - "expected exactly zero calls to RevokeRefreshToken()", + "expected exactly zero calls to RevokeToken()", ) } @@ -610,7 +612,7 @@ type TestUpstreamOIDCIdentityProviderBuilder struct { authcodeExchangeErr error passwordGrantErr error performRefreshErr error - revokeRefreshTokenErr error + revokeTokenErr error validateTokenErr error } @@ -727,8 +729,8 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenError(err err return u } -func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRevokeRefreshTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder { - u.revokeRefreshTokenErr = err +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRevokeTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder { + u.revokeTokenErr = err return u } @@ -761,8 +763,8 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdent } return u.refreshedTokens, nil }, - RevokeRefreshTokenFunc: func(ctx context.Context, refreshToken string) error { - return u.revokeRefreshTokenErr + RevokeTokenFunc: func(ctx context.Context, refreshToken string, tokenType provider.RevocableTokenType) error { + return u.revokeTokenErr }, ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { if u.validateTokenErr != nil { diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 437eb6efc..310a6df4e 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -137,32 +137,36 @@ func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string return p.Config.TokenSource(httpClientContext, &oauth2.Token{RefreshToken: refreshToken}).Token() } -// RevokeRefreshToken will attempt to revoke the given token, if the provider has a revocation endpoint. -func (p *ProviderConfig) RevokeRefreshToken(ctx context.Context, refreshToken string) error { +// RevokeToken will attempt to revoke the given token, if the provider has a revocation endpoint. +func (p *ProviderConfig) RevokeToken(ctx context.Context, token string, tokenType provider.RevocableTokenType) error { if p.RevocationURL == nil { - plog.Trace("RevokeRefreshToken() was called but upstream provider has no available revocation endpoint", "providerName", p.Name) + plog.Trace("RevokeToken() was called but upstream provider has no available revocation endpoint", + "providerName", p.Name, + "tokenType", tokenType, + ) return nil } // First try using client auth in the request params. - tryAnotherClientAuthMethod, err := p.tryRevokeRefreshToken(ctx, refreshToken, false) + tryAnotherClientAuthMethod, err := p.tryRevokeToken(ctx, token, tokenType, false) if tryAnotherClientAuthMethod { // Try again using basic auth this time. Overwrite the first client auth error, // which isn't useful anymore when retrying. - _, err = p.tryRevokeRefreshToken(ctx, refreshToken, true) + _, err = p.tryRevokeToken(ctx, token, tokenType, true) } return err } -// tryRevokeRefreshToken will call the revocation endpoint using either basic auth or by including +// tryRevokeToken will call the revocation endpoint using either basic auth or by including // client auth in the request params. It will return an error when the request failed. If the // request failed for a reason that might be due to bad client auth, then it will return true // for the tryAnotherClientAuthMethod return value, indicating that it might be worth trying // again using the other client auth method. // RFC 7009 defines how to make a revocation request and how to interpret the response. // See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 for details. -func (p *ProviderConfig) tryRevokeRefreshToken( +func (p *ProviderConfig) tryRevokeToken( ctx context.Context, - refreshToken string, + token string, + tokenType provider.RevocableTokenType, useBasicAuth bool, ) (tryAnotherClientAuthMethod bool, err error) { clientID := p.Config.ClientID @@ -171,8 +175,8 @@ func (p *ProviderConfig) tryRevokeRefreshToken( httpClient := p.Client params := url.Values{ - "token": []string{refreshToken}, - "token_type_hint": []string{"refresh_token"}, + "token": []string{token}, + "token_type_hint": []string{string(tokenType)}, } if !useBasicAuth { params["client_id"] = []string{clientID} @@ -200,11 +204,11 @@ func (p *ProviderConfig) tryRevokeRefreshToken( switch resp.StatusCode { case http.StatusOK: // Success! - plog.Trace("RevokeRefreshToken() got 200 OK response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth) + plog.Trace("RevokeToken() got 200 OK response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth) return false, nil case http.StatusBadRequest: // Bad request might be due to bad client auth method. Try to detect that. - plog.Trace("RevokeRefreshToken() got 400 Bad Request response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth) + plog.Trace("RevokeToken() got 400 Bad Request response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth) body, err := io.ReadAll(resp.Body) if err != nil { return false, @@ -227,11 +231,11 @@ func (p *ProviderConfig) tryRevokeRefreshToken( } // Got an "invalid_client" response, which might mean client auth failed, so it may be worth trying again // using another client auth method. See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 - plog.Trace("RevokeRefreshToken()'s 400 Bad Request response from provider's revocation endpoint was type 'invalid_client'", "providerName", p.Name, "usedBasicAuth", useBasicAuth) + plog.Trace("RevokeToken()'s 400 Bad Request response from provider's revocation endpoint was type 'invalid_client'", "providerName", p.Name, "usedBasicAuth", useBasicAuth) return true, err default: // Any other error is probably not due to failed client auth. - plog.Trace("RevokeRefreshToken() got unexpected error response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth, "statusCode", resp.StatusCode) + plog.Trace("RevokeToken() got unexpected error response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth, "statusCode", resp.StatusCode) return false, fmt.Errorf("server responded with status %d", resp.StatusCode) } } diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go index f89065620..ecef00a5e 100644 --- a/internal/upstreamoidc/upstreamoidc_test.go +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -24,6 +24,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/internal/mocks/mockkeyset" + "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" @@ -455,73 +456,114 @@ func TestProviderConfig(t *testing.T) { } }) - t.Run("RevokeRefreshToken", func(t *testing.T) { + t.Run("RevokeToken", func(t *testing.T) { tests := []struct { - name string - nilRevocationURL bool - statusCodes []int - returnErrBodies []string - wantErr string - wantNumRequests int + name string + tokenType provider.RevocableTokenType + nilRevocationURL bool + statusCodes []int + returnErrBodies []string + wantErr string + wantNumRequests int + wantTokenTypeHint string }{ { - name: "success without calling the server when there is no revocation URL set", + name: "success without calling the server when there is no revocation URL set for refresh token", + tokenType: provider.RefreshTokenType, nilRevocationURL: true, wantNumRequests: 0, }, { - name: "success when the server returns 200 OK on the first call", - statusCodes: []int{http.StatusOK}, - wantNumRequests: 1, + name: "success without calling the server when there is no revocation URL set for access token", + tokenType: provider.AccessTokenType, + nilRevocationURL: true, + wantNumRequests: 0, }, { - name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call", + name: "success when the server returns 200 OK on the first call for refresh token", + tokenType: provider.RefreshTokenType, + statusCodes: []int{http.StatusOK}, + wantNumRequests: 1, + wantTokenTypeHint: "refresh_token", + }, + { + name: "success when the server returns 200 OK on the first call for access token", + tokenType: provider.AccessTokenType, + statusCodes: []int{http.StatusOK}, + wantNumRequests: 1, + wantTokenTypeHint: "access_token", + }, + { + name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for refresh token", + tokenType: provider.RefreshTokenType, statusCodes: []int{http.StatusBadRequest, http.StatusOK}, // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure - returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`}, - wantNumRequests: 2, + returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`}, + wantNumRequests: 2, + wantTokenTypeHint: "refresh_token", }, { - name: "error when the server returns 400 Bad Request on the first call due to client auth, then any 400 error on second call", - statusCodes: []int{http.StatusBadRequest, http.StatusBadRequest}, - returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`}, - wantErr: `server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`, - wantNumRequests: 2, + name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for access token", + tokenType: provider.AccessTokenType, + statusCodes: []int{http.StatusBadRequest, http.StatusOK}, + // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure + returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`}, + wantNumRequests: 2, + wantTokenTypeHint: "access_token", }, { - name: "error when the server returns 400 Bad Request with bad JSON body on the first call", - statusCodes: []int{http.StatusBadRequest}, - returnErrBodies: []string{`invalid JSON body`}, - wantErr: `error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`, - wantNumRequests: 1, + name: "error when the server returns 400 Bad Request on the first call due to client auth, then any 400 error on second call", + tokenType: provider.RefreshTokenType, + statusCodes: []int{http.StatusBadRequest, http.StatusBadRequest}, + returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`}, + wantErr: `server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`, + wantNumRequests: 2, + wantTokenTypeHint: "refresh_token", }, { - name: "error when the server returns 400 Bad Request with empty body", - statusCodes: []int{http.StatusBadRequest}, - returnErrBodies: []string{``}, - wantErr: `error parsing response body "" on response with status code 400: unexpected end of JSON input`, - wantNumRequests: 1, + name: "error when the server returns 400 Bad Request with bad JSON body on the first call", + tokenType: provider.RefreshTokenType, + statusCodes: []int{http.StatusBadRequest}, + returnErrBodies: []string{`invalid JSON body`}, + wantErr: `error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`, + wantNumRequests: 1, + wantTokenTypeHint: "refresh_token", }, { - name: "error when the server returns 400 Bad Request on the first call due to client auth, then any other error on second call", - statusCodes: []int{http.StatusBadRequest, http.StatusForbidden}, - returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""}, - wantErr: "server responded with status 403", - wantNumRequests: 2, + name: "error when the server returns 400 Bad Request with empty body", + tokenType: provider.RefreshTokenType, + statusCodes: []int{http.StatusBadRequest}, + returnErrBodies: []string{``}, + wantErr: `error parsing response body "" on response with status code 400: unexpected end of JSON input`, + wantNumRequests: 1, + wantTokenTypeHint: "refresh_token", }, { - name: "error when server returns any other 400 error on first call", - statusCodes: []int{http.StatusBadRequest}, - returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`}, - wantErr: `server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`, - wantNumRequests: 1, + name: "error when the server returns 400 Bad Request on the first call due to client auth, then any other error on second call", + tokenType: provider.RefreshTokenType, + statusCodes: []int{http.StatusBadRequest, http.StatusForbidden}, + returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""}, + wantErr: "server responded with status 403", + wantNumRequests: 2, + wantTokenTypeHint: "refresh_token", }, { - name: "error when server returns any other error aside from 400 on first call", - statusCodes: []int{http.StatusForbidden}, - returnErrBodies: []string{""}, - wantErr: "server responded with status 403", - wantNumRequests: 1, + name: "error when server returns any other 400 error on first call", + tokenType: provider.RefreshTokenType, + statusCodes: []int{http.StatusBadRequest}, + returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`}, + wantErr: `server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`, + wantNumRequests: 1, + wantTokenTypeHint: "refresh_token", + }, + { + name: "error when server returns any other error aside from 400 on first call", + tokenType: provider.RefreshTokenType, + statusCodes: []int{http.StatusForbidden}, + returnErrBodies: []string{""}, + wantErr: "server responded with status 403", + wantNumRequests: 1, + wantTokenTypeHint: "refresh_token", }, } for _, tt := range tests { @@ -536,15 +578,15 @@ func TestProviderConfig(t *testing.T) { if numRequests == 1 { // First request should use client_id/client_secret params. require.Equal(t, 4, len(r.Form)) + require.Equal(t, "test-upstream-token", r.Form.Get("token")) + require.Equal(t, tt.wantTokenTypeHint, r.Form.Get("token_type_hint")) require.Equal(t, "test-client-id", r.Form.Get("client_id")) require.Equal(t, "test-client-secret", r.Form.Get("client_secret")) - require.Equal(t, "refresh_token", r.Form.Get("token_type_hint")) - require.Equal(t, "test-initial-refresh-token", r.Form.Get("token")) } else { // Second request, if there is one, should use basic auth. require.Equal(t, 2, len(r.Form)) - require.Equal(t, "refresh_token", r.Form.Get("token_type_hint")) - require.Equal(t, "test-initial-refresh-token", r.Form.Get("token")) + require.Equal(t, "test-upstream-token", r.Form.Get("token")) + require.Equal(t, tt.wantTokenTypeHint, r.Form.Get("token_type_hint")) username, password, hasBasicAuth := r.BasicAuth() require.True(t, hasBasicAuth, "request should have had basic auth but did not") require.Equal(t, "test-client-id", username) @@ -574,9 +616,10 @@ func TestProviderConfig(t *testing.T) { p.RevocationURL = nil } - err = p.RevokeRefreshToken( + err = p.RevokeToken( context.Background(), - "test-initial-refresh-token", + "test-upstream-token", + tt.tokenType, ) require.Equal(t, tt.wantNumRequests, numRequests,