Custom HTTP Client TLS transport for STSWebIdentity (#612)

Co-authored-by: Daniel Valdivia <hola@danielvaldivia.com>
This commit is contained in:
Lenin Alevski
2021-02-25 09:09:55 -08:00
committed by GitHub
parent 6ac95e40a4
commit 9c1f0c47b0
2 changed files with 15 additions and 11 deletions

View File

@@ -91,9 +91,10 @@ type Provider struct {
// often available via site-specific packages, such as // often available via site-specific packages, such as
// google.Endpoint or github.Endpoint. // google.Endpoint or github.Endpoint.
// - Scopes specifies optional requested permissions. // - Scopes specifies optional requested permissions.
ClientID string ClientID string
oauth2Config Configuration oauth2Config Configuration
oidcProvider *oidc.Provider oidcProvider *oidc.Provider
provHTTPClient *http.Client
} }
// derivedKey is the key used to compute the HMAC for signing the oauth state parameter // derivedKey is the key used to compute the HMAC for signing the oauth state parameter
@@ -103,8 +104,9 @@ var derivedKey = pbkdf2.Key([]byte(getPassphraseForIdpHmac()), []byte(getSaltFor
// NewOauth2ProviderClient instantiates a new oauth2 client using the configured credentials // NewOauth2ProviderClient instantiates a new oauth2 client using the configured credentials
// it returns a *Provider object that contains the necessary configuration to initiate an // it returns a *Provider object that contains the necessary configuration to initiate an
// oauth2 authentication flow // oauth2 authentication flow
func NewOauth2ProviderClient(ctx context.Context, scopes []string) (*Provider, error) { func NewOauth2ProviderClient(ctx context.Context, scopes []string, httpClient *http.Client) (*Provider, error) {
provider, err := oidc.NewProvider(ctx, GetIdpURL()) customCtx := oidc.ClientContext(ctx, httpClient)
provider, err := oidc.NewProvider(customCtx, GetIdpURL())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -122,6 +124,7 @@ func NewOauth2ProviderClient(ctx context.Context, scopes []string) (*Provider, e
} }
client.oidcProvider = provider client.oidcProvider = provider
client.ClientID = GetIdpClientID() client.ClientID = GetIdpClientID()
client.provHTTPClient = httpClient
return client, nil return client, nil
} }
@@ -172,10 +175,11 @@ func (client *Provider) VerifyIdentity(ctx context.Context, code, state string)
}, nil }, nil
} }
stsEndpoint := GetSTSEndpoint() stsEndpoint := GetSTSEndpoint()
sts, err := credentials.NewSTSWebIdentity(stsEndpoint, getWebTokenExpiry) sts := credentials.New(&credentials.STSWebIdentity{
if err != nil { Client: client.provHTTPClient,
return nil, err STSEndpoint: stsEndpoint,
} GetWebIDTokenExpiry: getWebTokenExpiry,
})
return sts, nil return sts, nil
} }

View File

@@ -187,7 +187,7 @@ func getLoginDetailsResponse() (*models.LoginDetails, *models.Error) {
if oauth2.IsIdpEnabled() { if oauth2.IsIdpEnabled() {
loginStrategy = models.LoginDetailsLoginStrategyRedirect loginStrategy = models.LoginDetailsLoginStrategyRedirect
// initialize new oauth2 client // initialize new oauth2 client
oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil) oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil, GetConsoleSTSClient())
if err != nil { if err != nil {
return nil, prepareError(err) return nil, prepareError(err)
} }
@@ -235,7 +235,7 @@ func getLoginOauth2AuthResponse(lr *models.LoginOauth2AuthRequest) (*models.Logi
return loginResponse, nil return loginResponse, nil
} else if oauth2.IsIdpEnabled() { } else if oauth2.IsIdpEnabled() {
// initialize new oauth2 client // initialize new oauth2 client
oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil) oauth2Client, err := oauth2.NewOauth2ProviderClient(ctx, nil, GetConsoleSTSClient())
if err != nil { if err != nil {
return nil, prepareError(err) return nil, prepareError(err)
} }