diff --git a/pkg/auth/idp/oauth2/provider.go b/pkg/auth/idp/oauth2/provider.go index 3eecc904f..40b61c2f5 100644 --- a/pkg/auth/idp/oauth2/provider.go +++ b/pkg/auth/idp/oauth2/provider.go @@ -91,9 +91,10 @@ type Provider struct { // often available via site-specific packages, such as // google.Endpoint or github.Endpoint. // - Scopes specifies optional requested permissions. - ClientID string - oauth2Config Configuration - oidcProvider *oidc.Provider + ClientID string + oauth2Config Configuration + oidcProvider *oidc.Provider + provHTTPClient *http.Client } // derivedKey is the key used to compute the HMAC for signing the oauth state parameter @@ -103,8 +104,9 @@ var derivedKey = pbkdf2.Key([]byte(getPassphraseForIdpHmac()), []byte(getSaltFor // NewOauth2ProviderClient instantiates a new oauth2 client using the configured credentials // it returns a *Provider object that contains the necessary configuration to initiate an // oauth2 authentication flow -func NewOauth2ProviderClient(ctx context.Context, scopes []string) (*Provider, error) { - provider, err := oidc.NewProvider(ctx, GetIdpURL()) +func NewOauth2ProviderClient(ctx context.Context, scopes []string, httpClient *http.Client) (*Provider, error) { + customCtx := oidc.ClientContext(ctx, httpClient) + provider, err := oidc.NewProvider(customCtx, GetIdpURL()) if err != nil { return nil, err } @@ -122,6 +124,7 @@ func NewOauth2ProviderClient(ctx context.Context, scopes []string) (*Provider, e } client.oidcProvider = provider client.ClientID = GetIdpClientID() + client.provHTTPClient = httpClient return client, nil } @@ -172,10 +175,11 @@ func (client *Provider) VerifyIdentity(ctx context.Context, code, state string) }, nil } stsEndpoint := GetSTSEndpoint() - sts, err := credentials.NewSTSWebIdentity(stsEndpoint, getWebTokenExpiry) - if err != nil { - return nil, err - } + sts := credentials.New(&credentials.STSWebIdentity{ + Client: client.provHTTPClient, + STSEndpoint: stsEndpoint, + GetWebIDTokenExpiry: getWebTokenExpiry, + }) return sts, nil } diff --git a/restapi/user_login.go b/restapi/user_login.go index e0f6a0276..7366aa72e 100644 --- a/restapi/user_login.go +++ b/restapi/user_login.go @@ -187,7 +187,7 @@ func getLoginDetailsResponse() (*models.LoginDetails, *models.Error) { if oauth2.IsIdpEnabled() { loginStrategy = models.LoginDetailsLoginStrategyRedirect // initialize new oauth2 client - oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil) + oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil, GetConsoleSTSClient()) if err != nil { return nil, prepareError(err) } @@ -235,7 +235,7 @@ func getLoginOauth2AuthResponse(lr *models.LoginOauth2AuthRequest) (*models.Logi return loginResponse, nil } else if oauth2.IsIdpEnabled() { // initialize new oauth2 client - oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil) + oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil, GetConsoleSTSClient()) if err != nil { return nil, prepareError(err) }