Inline using phttp.Default

This commit is contained in:
Joshua Casey
2024-04-23 18:06:26 -05:00
parent d0bbfca831
commit c8b90df6f1
2 changed files with 30 additions and 48 deletions

View File

@@ -36,6 +36,7 @@ import (
"go.pinniped.dev/internal/crypto/ptls"
"go.pinniped.dev/internal/endpointaddr"
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/net/phttp"
"go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/upstreamgithub"
)
@@ -68,7 +69,6 @@ type gitHubWatcherController struct {
client supervisorclientset.Interface
gitHubIdentityProviderInformer idpinformers.GitHubIdentityProviderInformer
secretInformer corev1informers.SecretInformer
httpClientBuilder func(rootCAs *x509.CertPool) *http.Client
clock clock.Clock
}
@@ -82,7 +82,6 @@ func New(
log plog.Logger,
withInformer pinnipedcontroller.WithInformerOptionFunc,
clock clock.Clock,
httpClientBuilder func(rootCAs *x509.CertPool) *http.Client,
) controllerlib.Controller {
c := gitHubWatcherController{
namespace: namespace,
@@ -91,7 +90,6 @@ func New(
log: log.WithName(controllerName),
gitHubIdentityProviderInformer: gitHubIdentityProviderInformer,
secretInformer: secretInformer,
httpClientBuilder: httpClientBuilder,
clock: clock,
}
@@ -390,7 +388,7 @@ func (c *gitHubWatcherController) validateGitHubConnection(
Status: metav1.ConditionTrue,
Reason: upstreamwatchers.ReasonSuccess,
Message: fmt.Sprintf("spec.githubAPI.host (%q) is reachable and TLS verification succeeds", hostPort.Endpoint()),
}, fmt.Sprintf("https://%s", hostPort.Endpoint()), c.httpClientBuilder(certPool), conn.Close()
}, fmt.Sprintf("https://%s", hostPort.Endpoint()), phttp.Default(certPool), conn.Close()
}
// buildDialErrorMessage standardizes DNS error messages that appear differently on different platforms, so that tests and log grepping is uniform.

View File

@@ -6,7 +6,6 @@ package githubupstreamwatcher
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"fmt"
"net"
@@ -22,10 +21,10 @@ import (
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
utilnet "k8s.io/apimachinery/pkg/util/net"
k8sinformers "k8s.io/client-go/informers"
kubernetesfake "k8s.io/client-go/kubernetes/fake"
coretesting "k8s.io/client-go/testing"
"k8s.io/client-go/util/cert"
"k8s.io/utils/clock"
clocktesting "k8s.io/utils/clock/testing"
"k8s.io/utils/ptr"
@@ -34,9 +33,9 @@ import (
pinnipedfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake"
pinnipedinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions"
"go.pinniped.dev/internal/certauthority"
pinnipedcontroller "go.pinniped.dev/internal/controller"
"go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers"
"go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/crypto/ptls"
"go.pinniped.dev/internal/federationdomain/dynamicupstreamprovider"
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
"go.pinniped.dev/internal/net/phttp"
@@ -365,7 +364,7 @@ func TestController(t *testing.T) {
AllowedOrganizations: []string{"organization1", "org2"},
OrganizationLoginPolicy: "OnlyUsersFromAllowedOrganizations",
AuthorizationURL: fmt.Sprintf("https://%s/login/oauth/authorize", *validFilledOutIDP.Spec.GitHubAPI.Host),
HttpClient: buildPretendHttpClient(t, goodServerCA),
HttpClient: nil, // let the test runner populate this for us
},
},
wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{
@@ -412,7 +411,7 @@ func TestController(t *testing.T) {
},
OrganizationLoginPolicy: "AllGitHubUsers",
AuthorizationURL: fmt.Sprintf("https://%s/login/oauth/authorize", goodServerDomain),
HttpClient: buildPretendHttpClient(t, goodServerCA),
HttpClient: nil, // let the test runner populate this for us
},
},
wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{
@@ -466,7 +465,7 @@ func TestController(t *testing.T) {
},
OrganizationLoginPolicy: "AllGitHubUsers",
AuthorizationURL: fmt.Sprintf("https://%s/login/oauth/authorize", goodServerIPv6Domain),
HttpClient: buildPretendHttpClient(t, goodServerCA),
HttpClient: nil, // let the test runner populate this for us
},
},
wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{
@@ -546,7 +545,7 @@ func TestController(t *testing.T) {
AllowedOrganizations: []string{"organization1", "org2"},
OrganizationLoginPolicy: "OnlyUsersFromAllowedOrganizations",
AuthorizationURL: fmt.Sprintf("https://%s/login/oauth/authorize", *validFilledOutIDP.Spec.GitHubAPI.Host),
HttpClient: buildPretendHttpClient(t, goodServerCA),
HttpClient: nil, // let the test runner populate this for us
},
{
Name: "other-idp-name",
@@ -561,7 +560,7 @@ func TestController(t *testing.T) {
AllowedOrganizations: []string{"organization1", "org2"},
OrganizationLoginPolicy: "OnlyUsersFromAllowedOrganizations",
AuthorizationURL: fmt.Sprintf("https://%s/login/oauth/authorize", *validFilledOutIDP.Spec.GitHubAPI.Host),
HttpClient: buildPretendHttpClient(t, goodServerCA),
HttpClient: nil, // let the test runner populate this for us
},
},
wantResultingUpstreams: []v1alpha1.GitHubIdentityProvider{
@@ -1609,7 +1608,6 @@ func TestController(t *testing.T) {
logger,
controllerlib.WithInformer,
frozenClock,
buildHttpClient,
)
ctx, cancel := context.WithCancel(context.Background())
@@ -1653,7 +1651,14 @@ func TestController(t *testing.T) {
require.Equal(t, tt.wantResultingCache[i].AllowedOrganizations, actualIDP.GetAllowedOrganizations())
require.Equal(t, tt.wantResultingCache[i].OrganizationLoginPolicy, actualIDP.GetOrganizationLoginPolicy())
require.Equal(t, tt.wantResultingCache[i].AuthorizationURL, actualIDP.GetAuthorizationURL())
compareTLSClientConfigWithinHttpClients(t, tt.wantResultingCache[i].HttpClient, actualIDP.GetHttpClient())
require.GreaterOrEqual(t, len(tt.githubIdentityProviders), i+1, "there must be at least as many input identity providers as items in the cache")
githubIDP, ok := tt.githubIdentityProviders[i].(*v1alpha1.GitHubIdentityProvider)
require.True(t, ok)
certPool, _, err := pinnipedcontroller.BuildCertPoolIDP(githubIDP.Spec.GitHubAPI.TLS)
require.NoError(t, err)
compareTLSClientConfigWithinHttpClients(t, phttp.Default(certPool), actualIDP.GetHttpClient())
require.Equal(t, tt.wantResultingCache[i].OAuth2Config, actualIDP.OAuth2Config)
}
@@ -1857,7 +1862,6 @@ func TestController_WithExistingConditions(t *testing.T) {
logger,
controllerlib.WithInformer,
frozenClock,
buildHttpClient,
)
ctx, cancel := context.WithCancel(context.Background())
@@ -1877,41 +1881,23 @@ func TestController_WithExistingConditions(t *testing.T) {
}
}
func buildPretendHttpClient(t *testing.T, ca []byte) *http.Client {
func compareTLSClientConfigWithinHttpClients(t *testing.T, expected *http.Client, actual *http.Client) {
t.Helper()
rootCAs, err := cert.NewPoolFromBytes(ca)
require.NotEmpty(t, expected)
require.NotEmpty(t, actual)
require.Equal(t, expected.Timeout, actual.Timeout)
expectedConfig, err := utilnet.TLSClientConfig(expected.Transport)
require.NoError(t, err)
return buildHttpClient(rootCAs)
}
func buildHttpClient(rootCAs *x509.CertPool) *http.Client {
baseRT := http.DefaultTransport.(*http.Transport).Clone()
baseRT.TLSClientConfig = ptls.Default(rootCAs)
actualConfig, err := utilnet.TLSClientConfig(actual.Transport)
require.NoError(t, err)
return &http.Client{
Transport: baseRT,
}
}
func compareTLSClientConfigWithinHttpClients(t *testing.T, c1 *http.Client, c2 *http.Client) {
t.Helper()
if c1 == nil {
require.Nil(t, c2)
return
}
t1, ok := c1.Transport.(*http.Transport)
require.True(t, ok)
require.NotNil(t, t1)
require.NotNil(t, t1.TLSClientConfig)
t2, ok := c2.Transport.(*http.Transport)
require.True(t, ok)
require.NotNil(t, t2)
require.NotNil(t, t2.TLSClientConfig)
require.Equal(t, t1.TLSClientConfig.ClientCAs, t2.TLSClientConfig.ClientCAs)
require.True(t, actualConfig.RootCAs.Equal(expectedConfig.RootCAs))
actualConfig.RootCAs = expectedConfig.RootCAs
require.Equal(t, expectedConfig, actualConfig)
}
func TestGitHubUpstreamWatcherControllerFilterSecret(t *testing.T) {
@@ -1983,7 +1969,6 @@ func TestGitHubUpstreamWatcherControllerFilterSecret(t *testing.T) {
logger,
observableInformers.WithInformer,
clock.RealClock{},
phttp.Default,
)
unrelated := &corev1.Secret{}
@@ -2053,7 +2038,6 @@ func TestGitHubUpstreamWatcherControllerFilterGitHubIDP(t *testing.T) {
logger,
observableInformers.WithInformer,
clock.RealClock{},
phttp.Default,
)
unrelated := &v1alpha1.GitHubIdentityProvider{}