mirror of
https://github.com/vmware-tanzu/pinniped.git
synced 2026-01-07 05:57:02 +00:00
Merge pull request #2056 from vmware-tanzu/jtc/tls-dial-should-have-timeout
GitHubIdentityProvider and WebhookAuthenticator should perform `tls.Dial` with a timeout
This commit is contained in:
@@ -6,7 +6,6 @@ package webhookcachefiller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@@ -77,6 +76,7 @@ func New(
|
||||
withInformer pinnipedcontroller.WithInformerOptionFunc,
|
||||
clock clock.Clock,
|
||||
log plog.Logger,
|
||||
dialer ptls.Dialer,
|
||||
) controllerlib.Controller {
|
||||
return controllerlib.New(
|
||||
controllerlib.Config{
|
||||
@@ -90,6 +90,7 @@ func New(
|
||||
configMapInformer: configMapInformer,
|
||||
clock: clock,
|
||||
log: log.WithName(controllerName),
|
||||
dialer: dialer,
|
||||
},
|
||||
},
|
||||
withInformer(
|
||||
@@ -125,6 +126,7 @@ type webhookCacheFillerController struct {
|
||||
client conciergeclientset.Interface
|
||||
clock clock.Clock
|
||||
log plog.Logger
|
||||
dialer ptls.Dialer
|
||||
}
|
||||
|
||||
// Sync implements controllerlib.Syncer.
|
||||
@@ -197,7 +199,7 @@ func (c *webhookCacheFillerController) syncIndividualWebhookAuthenticator(ctx co
|
||||
)
|
||||
} else {
|
||||
// Run all remaining validations.
|
||||
a, moreConditions, moreErrs := c.doExpensiveValidations(webhookAuthenticator, endpointHostPort, caBundle, okSoFar, logger)
|
||||
a, moreConditions, moreErrs := c.doExpensiveValidations(ctx, webhookAuthenticator, endpointHostPort, caBundle, okSoFar, logger)
|
||||
newWebhookAuthenticatorForCache = a
|
||||
conditions = append(conditions, moreConditions...)
|
||||
errs = append(errs, moreErrs...)
|
||||
@@ -236,6 +238,7 @@ func (c *webhookCacheFillerController) syncIndividualWebhookAuthenticator(ctx co
|
||||
}
|
||||
|
||||
func (c *webhookCacheFillerController) doExpensiveValidations(
|
||||
ctx context.Context,
|
||||
webhookAuthenticator *authenticationv1alpha1.WebhookAuthenticator,
|
||||
endpointHostPort *endpointaddr.HostPort,
|
||||
caBundle *tlsconfigutil.CABundle,
|
||||
@@ -246,7 +249,7 @@ func (c *webhookCacheFillerController) doExpensiveValidations(
|
||||
var conditions []*metav1.Condition
|
||||
var errs []error
|
||||
|
||||
conditions, tlsNegotiateErr := c.validateConnection(caBundle.CertPool(), endpointHostPort, conditions, okSoFar, logger)
|
||||
conditions, tlsNegotiateErr := c.validateConnection(ctx, caBundle.CertPool(), endpointHostPort, conditions, okSoFar, logger)
|
||||
errs = append(errs, tlsNegotiateErr)
|
||||
okSoFar = okSoFar && tlsNegotiateErr == nil
|
||||
|
||||
@@ -412,6 +415,7 @@ func successfulWebhookConnectionValidCondition() *metav1.Condition {
|
||||
}
|
||||
|
||||
func (c *webhookCacheFillerController) validateConnection(
|
||||
ctx context.Context,
|
||||
certPool *x509.CertPool,
|
||||
endpointHostPort *endpointaddr.HostPort,
|
||||
conditions []*metav1.Condition,
|
||||
@@ -428,11 +432,13 @@ func (c *webhookCacheFillerController) validateConnection(
|
||||
return conditions, nil
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", endpointHostPort.Endpoint(), ptls.Default(certPool))
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer dialCancel()
|
||||
err := c.dialer.IsReachableAndTLSValidationSucceeds(dialCtx, endpointHostPort.Endpoint(), certPool, logger)
|
||||
|
||||
if err != nil {
|
||||
errText := "cannot dial server"
|
||||
msg := fmt.Sprintf("%s: %s", errText, err.Error())
|
||||
msg := fmt.Sprintf("%s: %s", errText, err)
|
||||
conditions = append(conditions, &metav1.Condition{
|
||||
Type: typeWebhookConnectionValid,
|
||||
Status: metav1.ConditionFalse,
|
||||
@@ -442,13 +448,6 @@ func (c *webhookCacheFillerController) validateConnection(
|
||||
return conditions, fmt.Errorf("%s: %w", errText, err)
|
||||
}
|
||||
|
||||
// this error should never be significant
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
// no unit test for this failure
|
||||
logger.Error("error closing dialer", err)
|
||||
}
|
||||
|
||||
conditions = append(conditions, successfulWebhookConnectionValidCondition())
|
||||
return conditions, nil
|
||||
}
|
||||
|
||||
@@ -1934,7 +1934,8 @@ func TestController(t *testing.T) {
|
||||
kubeInformers.Core().V1().ConfigMaps(),
|
||||
controllerlib.WithInformer,
|
||||
frozenClock,
|
||||
logger)
|
||||
logger,
|
||||
ptls.NewDialer())
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
@@ -2177,7 +2178,8 @@ func TestControllerFilterSecret(t *testing.T) {
|
||||
configMapInformer,
|
||||
observableInformers.WithInformer,
|
||||
frozenClock,
|
||||
logger)
|
||||
logger,
|
||||
ptls.NewDialer())
|
||||
|
||||
unrelated := &corev1.Secret{}
|
||||
filter := observableInformers.GetFilterForInformer(secretInformer)
|
||||
@@ -2238,7 +2240,8 @@ func TestControllerFilterConfigMap(t *testing.T) {
|
||||
configMapInformer,
|
||||
observableInformers.WithInformer,
|
||||
frozenClock,
|
||||
logger)
|
||||
logger,
|
||||
ptls.NewDialer())
|
||||
|
||||
unrelated := &corev1.ConfigMap{}
|
||||
filter := observableInformers.GetFilterForInformer(configMapInformer)
|
||||
|
||||
@@ -6,7 +6,6 @@ package githubupstreamwatcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -117,7 +116,7 @@ type gitHubWatcherController struct {
|
||||
secretInformer corev1informers.SecretInformer
|
||||
configMapInformer corev1informers.ConfigMapInformer
|
||||
clock clock.Clock
|
||||
dialFunc func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
dialer ptls.Dialer
|
||||
validatedCache GitHubValidatedAPICacheI
|
||||
}
|
||||
|
||||
@@ -132,7 +131,7 @@ func New(
|
||||
log plog.Logger,
|
||||
withInformer pinnipedcontroller.WithInformerOptionFunc,
|
||||
clock clock.Clock,
|
||||
dialFunc func(network, addr string, config *tls.Config) (*tls.Conn, error),
|
||||
dialer ptls.Dialer,
|
||||
validatedCache *cache.Expiring,
|
||||
) controllerlib.Controller {
|
||||
c := gitHubWatcherController{
|
||||
@@ -144,7 +143,7 @@ func New(
|
||||
secretInformer: secretInformer,
|
||||
configMapInformer: configMapInformer,
|
||||
clock: clock,
|
||||
dialFunc: dialFunc,
|
||||
dialer: dialer,
|
||||
validatedCache: NewGitHubValidatedAPICache(validatedCache),
|
||||
}
|
||||
|
||||
@@ -190,7 +189,7 @@ func (c *gitHubWatcherController) Sync(ctx controllerlib.Context) error {
|
||||
var applicationErrors []error
|
||||
validatedUpstreams := make([]upstreamprovider.UpstreamGithubIdentityProviderI, 0, len(actualUpstreams))
|
||||
for _, upstream := range actualUpstreams {
|
||||
validatedUpstream, applicationErr := c.validateUpstreamAndUpdateConditions(ctx, upstream)
|
||||
validatedUpstream, applicationErr := c.validateUpstreamAndUpdateConditions(ctx.Context, upstream)
|
||||
if applicationErr != nil {
|
||||
applicationErrors = append(applicationErrors, applicationErr)
|
||||
} else if validatedUpstream != nil {
|
||||
@@ -298,7 +297,7 @@ func validateOrganizationsPolicy(organizationsSpec *idpv1alpha1.GitHubOrganizati
|
||||
}
|
||||
}
|
||||
|
||||
func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx controllerlib.Context, upstream *idpv1alpha1.GitHubIdentityProvider) (
|
||||
func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx context.Context, upstream *idpv1alpha1.GitHubIdentityProvider) (
|
||||
*upstreamgithub.Provider, // If validated, returns the config
|
||||
error, // This error will only refer to programmatic errors such as inability to perform a Dial or dereference a pointer, not configuration errors
|
||||
) {
|
||||
@@ -332,6 +331,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro
|
||||
conditions = append(conditions, tlsConfigCondition)
|
||||
|
||||
githubConnectionCondition, httpClient, githubConnectionErr := c.validateGitHubConnection(
|
||||
ctx,
|
||||
apiHostPort,
|
||||
upstream.Spec.GitHubAPI.Host,
|
||||
caBundle,
|
||||
@@ -350,7 +350,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro
|
||||
applicationErrors = append(applicationErrors, fmt.Errorf("expected %d conditions but found %d conditions", countExpectedConditions, len(conditions)))
|
||||
return nil, utilerrors.NewAggregate(applicationErrors)
|
||||
}
|
||||
hadErrorCondition, updateStatusErr := c.updateStatus(ctx.Context, upstream, conditions)
|
||||
hadErrorCondition, updateStatusErr := c.updateStatus(ctx, upstream, conditions)
|
||||
if updateStatusErr != nil {
|
||||
applicationErrors = append(applicationErrors, updateStatusErr)
|
||||
}
|
||||
@@ -454,6 +454,7 @@ func validateHost(specifiedHost *string) (*metav1.Condition, *endpointaddr.HostP
|
||||
}
|
||||
|
||||
func (c *gitHubWatcherController) validateGitHubConnection(
|
||||
ctx context.Context,
|
||||
apiHostPort *endpointaddr.HostPort,
|
||||
specifiedHost *string,
|
||||
caBundle *tlsconfigutil.CABundle,
|
||||
@@ -471,7 +472,10 @@ func (c *gitHubWatcherController) validateGitHubConnection(
|
||||
apiAddress := apiHostPort.Endpoint()
|
||||
|
||||
if !c.validatedCache.IsValid(apiAddress, caBundle.Hash()) {
|
||||
conn, tlsDialErr := c.dialFunc("tcp", apiAddress, ptls.Default(caBundle.CertPool()))
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer dialCancel()
|
||||
|
||||
tlsDialErr := c.dialer.IsReachableAndTLSValidationSucceeds(dialCtx, apiAddress, caBundle.CertPool(), c.log)
|
||||
if tlsDialErr != nil {
|
||||
return &metav1.Condition{
|
||||
Type: GitHubConnectionValid,
|
||||
@@ -481,8 +485,6 @@ func (c *gitHubWatcherController) validateGitHubConnection(
|
||||
apiAddress, *specifiedHost, buildDialErrorMessage(tlsDialErr)),
|
||||
}, nil, tlsDialErr
|
||||
}
|
||||
// Any error should be ignored. We have performed a successful Dial, so no need to requeue this Sync.
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
c.validatedCache.MarkAsValidated(apiAddress, caBundle.Hash())
|
||||
|
||||
@@ -6,7 +6,6 @@ package githubupstreamwatcher
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
@@ -41,6 +40,7 @@ import (
|
||||
"go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers"
|
||||
"go.pinniped.dev/internal/controller/tlsconfigutil"
|
||||
"go.pinniped.dev/internal/controllerlib"
|
||||
"go.pinniped.dev/internal/crypto/ptls"
|
||||
"go.pinniped.dev/internal/federationdomain/dynamicupstreamprovider"
|
||||
"go.pinniped.dev/internal/federationdomain/upstreamprovider"
|
||||
"go.pinniped.dev/internal/net/phttp"
|
||||
@@ -61,6 +61,32 @@ var (
|
||||
githubIDPKind = idpv1alpha1.SchemeGroupVersion.WithKind("GitHubIdentityProvider")
|
||||
)
|
||||
|
||||
type fakeGithubDialer struct {
|
||||
t *testing.T
|
||||
realAddress string
|
||||
realCertPool *x509.CertPool
|
||||
}
|
||||
|
||||
func (f fakeGithubDialer) IsReachableAndTLSValidationSucceeds(ctx context.Context, address string, _ *x509.CertPool, logger plog.Logger) error {
|
||||
require.Equal(f.t, "api.github.com:443", address)
|
||||
|
||||
return ptls.NewDialer().IsReachableAndTLSValidationSucceeds(ctx, f.realAddress, f.realCertPool, logger)
|
||||
}
|
||||
|
||||
var _ ptls.Dialer = (*fakeGithubDialer)(nil)
|
||||
|
||||
type allowNoDials struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (f allowNoDials) IsReachableAndTLSValidationSucceeds(_ context.Context, _ string, _ *x509.CertPool, _ plog.Logger) error {
|
||||
f.t.Errorf("this test should not perform dial")
|
||||
f.t.FailNow()
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ ptls.Dialer = (*allowNoDials)(nil)
|
||||
|
||||
func TestController(t *testing.T) {
|
||||
require.Equal(t, 6, countExpectedConditions)
|
||||
|
||||
@@ -406,7 +432,7 @@ func TestController(t *testing.T) {
|
||||
name string
|
||||
githubIdentityProviders []runtime.Object
|
||||
secretsAndConfigMaps []runtime.Object
|
||||
mockDialer func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
mockDialer func(*testing.T) ptls.Dialer
|
||||
preexistingValidatedCache []GitHubValidatedAPICacheKey
|
||||
wantErr string
|
||||
wantLogs []string
|
||||
@@ -555,15 +581,13 @@ func TestController(t *testing.T) {
|
||||
return githubIDP
|
||||
}(),
|
||||
},
|
||||
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
mockDialer: func(t *testing.T) ptls.Dialer {
|
||||
t.Helper()
|
||||
|
||||
return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
require.Equal(t, "api.github.com:443", addr)
|
||||
// don't actually dial github.com to avoid making external network calls in unit test
|
||||
configClone := config.Clone()
|
||||
configClone.RootCAs = goodServerCertPool
|
||||
return tls.Dial(network, goodServerDomain, configClone)
|
||||
return &fakeGithubDialer{
|
||||
t: t,
|
||||
realAddress: goodServerDomain,
|
||||
realCertPool: goodServerCertPool,
|
||||
}
|
||||
},
|
||||
wantResultingCache: []*upstreamgithub.ProviderConfig{
|
||||
@@ -638,15 +662,13 @@ func TestController(t *testing.T) {
|
||||
return githubIDP
|
||||
}(),
|
||||
},
|
||||
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
mockDialer: func(t *testing.T) ptls.Dialer {
|
||||
t.Helper()
|
||||
|
||||
return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
require.Equal(t, "api.github.com:443", addr)
|
||||
// don't actually dial github.com to avoid making external network calls in unit test
|
||||
configClone := config.Clone()
|
||||
configClone.RootCAs = goodServerCertPool
|
||||
return tls.Dial(network, goodServerDomain, configClone)
|
||||
return &fakeGithubDialer{
|
||||
t: t,
|
||||
realAddress: goodServerDomain,
|
||||
realCertPool: goodServerCertPool,
|
||||
}
|
||||
},
|
||||
wantResultingCache: []*upstreamgithub.ProviderConfig{
|
||||
@@ -721,15 +743,13 @@ func TestController(t *testing.T) {
|
||||
return githubIDP
|
||||
}(),
|
||||
},
|
||||
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
mockDialer: func(t *testing.T) ptls.Dialer {
|
||||
t.Helper()
|
||||
|
||||
return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
require.Equal(t, "api.github.com:443", addr)
|
||||
// don't actually dial github.com to avoid making external network calls in unit test
|
||||
configClone := config.Clone()
|
||||
configClone.RootCAs = goodServerCertPool
|
||||
return tls.Dial(network, goodServerDomain, configClone)
|
||||
return &fakeGithubDialer{
|
||||
t: t,
|
||||
realAddress: goodServerDomain,
|
||||
realCertPool: goodServerCertPool,
|
||||
}
|
||||
},
|
||||
wantResultingCache: []*upstreamgithub.ProviderConfig{
|
||||
@@ -804,15 +824,13 @@ func TestController(t *testing.T) {
|
||||
return githubIDP
|
||||
}(),
|
||||
},
|
||||
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
mockDialer: func(t *testing.T) ptls.Dialer {
|
||||
t.Helper()
|
||||
|
||||
return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
require.Equal(t, "api.github.com:443", addr)
|
||||
// don't actually dial github.com to avoid making external network calls in unit test
|
||||
configClone := config.Clone()
|
||||
configClone.RootCAs = goodServerCertPool
|
||||
return tls.Dial(network, goodServerDomain, configClone)
|
||||
return &fakeGithubDialer{
|
||||
t: t,
|
||||
realAddress: goodServerDomain,
|
||||
realCertPool: goodServerCertPool,
|
||||
}
|
||||
},
|
||||
wantResultingCache: []*upstreamgithub.ProviderConfig{
|
||||
@@ -887,15 +905,13 @@ func TestController(t *testing.T) {
|
||||
return githubIDP
|
||||
}(),
|
||||
},
|
||||
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
mockDialer: func(t *testing.T) ptls.Dialer {
|
||||
t.Helper()
|
||||
|
||||
return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
require.Equal(t, "api.github.com:443", addr)
|
||||
// don't actually dial github.com to avoid making external network calls in unit test
|
||||
configClone := config.Clone()
|
||||
configClone.RootCAs = goodServerCertPool
|
||||
return tls.Dial(network, goodServerDomain, configClone)
|
||||
return &fakeGithubDialer{
|
||||
t: t,
|
||||
realAddress: goodServerDomain,
|
||||
realCertPool: goodServerCertPool,
|
||||
}
|
||||
},
|
||||
wantResultingCache: []*upstreamgithub.ProviderConfig{
|
||||
@@ -1379,14 +1395,10 @@ func TestController(t *testing.T) {
|
||||
name: "happy path with previously validated address/CA Bundle does not validate again",
|
||||
secretsAndConfigMaps: []runtime.Object{goodClientCredentialsSecret},
|
||||
githubIdentityProviders: []runtime.Object{validFilledOutIDP},
|
||||
mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
mockDialer: func(t *testing.T) ptls.Dialer {
|
||||
t.Helper()
|
||||
|
||||
return func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
t.Errorf("this test should not perform dial")
|
||||
t.FailNow()
|
||||
return nil, nil
|
||||
}
|
||||
return &allowNoDials{t: t}
|
||||
},
|
||||
preexistingValidatedCache: []GitHubValidatedAPICacheKey{
|
||||
{
|
||||
@@ -2479,7 +2491,7 @@ func TestController(t *testing.T) {
|
||||
|
||||
gitHubIdentityProviderInformer := supervisorInformers.IDP().V1alpha1().GitHubIdentityProviders()
|
||||
|
||||
dialer := tls.Dial
|
||||
var dialer ptls.Dialer = ptls.NewDialer()
|
||||
if tt.mockDialer != nil {
|
||||
dialer = tt.mockDialer(t)
|
||||
}
|
||||
@@ -2882,7 +2894,7 @@ func TestController_OnlyWantActions(t *testing.T) {
|
||||
logger,
|
||||
controllerlib.WithInformer,
|
||||
frozenClockForLastTransitionTime,
|
||||
tls.Dial,
|
||||
ptls.NewDialer(),
|
||||
cache.NewExpiring(),
|
||||
)
|
||||
|
||||
@@ -3006,7 +3018,7 @@ func TestGitHubUpstreamWatcherControllerFilterSecret(t *testing.T) {
|
||||
logger,
|
||||
observableInformers.WithInformer,
|
||||
clock.RealClock{},
|
||||
tls.Dial,
|
||||
ptls.NewDialer(),
|
||||
cache.NewExpiring(),
|
||||
)
|
||||
|
||||
@@ -3063,7 +3075,7 @@ func TestGitHubUpstreamWatcherControllerFilterConfigMaps(t *testing.T) {
|
||||
logger,
|
||||
observableInformers.WithInformer,
|
||||
clock.RealClock{},
|
||||
tls.Dial,
|
||||
ptls.NewDialer(),
|
||||
cache.NewExpiring(),
|
||||
)
|
||||
|
||||
@@ -3120,7 +3132,7 @@ func TestGitHubUpstreamWatcherControllerFilterGitHubIDP(t *testing.T) {
|
||||
logger,
|
||||
observableInformers.WithInformer,
|
||||
clock.RealClock{},
|
||||
tls.Dial,
|
||||
ptls.NewDialer(),
|
||||
cache.NewExpiring(),
|
||||
)
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"go.pinniped.dev/internal/controller/serviceaccounttokencleanup"
|
||||
"go.pinniped.dev/internal/controllerinit"
|
||||
"go.pinniped.dev/internal/controllerlib"
|
||||
"go.pinniped.dev/internal/crypto/ptls"
|
||||
"go.pinniped.dev/internal/deploymentref"
|
||||
"go.pinniped.dev/internal/downward"
|
||||
"go.pinniped.dev/internal/dynamiccert"
|
||||
@@ -244,6 +245,7 @@ func PrepareControllers(c *Config) (controllerinit.RunnerBuilder, error) { //nol
|
||||
controllerlib.WithInformer,
|
||||
clock.RealClock{},
|
||||
plog.New(),
|
||||
ptls.NewDialer(),
|
||||
),
|
||||
singletonWorker,
|
||||
).
|
||||
|
||||
62
internal/crypto/ptls/dialer.go
Normal file
62
internal/crypto/ptls/dialer.go
Normal file
@@ -0,0 +1,62 @@
|
||||
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ptls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"time"
|
||||
|
||||
"go.pinniped.dev/internal/plog"
|
||||
)
|
||||
|
||||
type Dialer interface {
|
||||
IsReachableAndTLSValidationSucceeds(
|
||||
ctx context.Context,
|
||||
address string,
|
||||
certPool *x509.CertPool,
|
||||
logger plog.Logger,
|
||||
) error
|
||||
}
|
||||
|
||||
type internalDialer struct {
|
||||
}
|
||||
|
||||
func NewDialer() *internalDialer {
|
||||
return &internalDialer{}
|
||||
}
|
||||
|
||||
func (i *internalDialer) IsReachableAndTLSValidationSucceeds(
|
||||
ctx context.Context,
|
||||
address string,
|
||||
certPool *x509.CertPool,
|
||||
logger plog.Logger,
|
||||
) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
_, hasDeadline := ctx.Deadline()
|
||||
if !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
dialer := tls.Dialer{
|
||||
Config: Default(certPool),
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
// Don't wrap this error message since this is just a helper function.
|
||||
return err
|
||||
}
|
||||
err = conn.Close()
|
||||
if err != nil { // untested
|
||||
// Log it just so that it doesn't completely disappear.
|
||||
logger.Error("Failed to close connection: ", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
173
internal/crypto/ptls/dialer_test.go
Normal file
173
internal/crypto/ptls/dialer_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright 2024 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Use this package to avoid import loops with internal/testutil/tlsserver
|
||||
package ptls_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.pinniped.dev/internal/crypto/ptls"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
"go.pinniped.dev/internal/testutil/tlsserver"
|
||||
)
|
||||
|
||||
func TestDialer(t *testing.T) {
|
||||
secureServerIPv4, secureServerIPv4CA := tlsserver.TestServerIPv4(t, nil, nil)
|
||||
secureServerIPv6, secureServerIPv6CA := tlsserver.TestServerIPv6(t, nil, nil)
|
||||
insecureServer := httptest.NewServer(nil)
|
||||
|
||||
fakeCert, _, err := testutil.CreateCertificate(time.Now().Add(-1*time.Hour), time.Now().Add(time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fullURL string
|
||||
certPool *x509.CertPool
|
||||
wantError string
|
||||
}{
|
||||
{
|
||||
name: "happy path with TLS-enabled IPv4",
|
||||
fullURL: secureServerIPv4.URL,
|
||||
certPool: bytesToCertPool(secureServerIPv4CA),
|
||||
},
|
||||
{
|
||||
name: "happy path with TLS-enabled IPv6",
|
||||
fullURL: secureServerIPv6.URL,
|
||||
certPool: bytesToCertPool(secureServerIPv6CA),
|
||||
},
|
||||
{
|
||||
name: "returns error when connecting to a non-TLS server",
|
||||
fullURL: insecureServer.URL,
|
||||
wantError: "tls: first record does not look like a TLS handshake",
|
||||
},
|
||||
{
|
||||
name: "returns error when using the wrong bundle",
|
||||
fullURL: secureServerIPv4.URL,
|
||||
certPool: bytesToCertPool(fakeCert),
|
||||
wantError: "tls: failed to verify certificate: x509: certificate signed by unknown authority",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dialer := ptls.NewDialer()
|
||||
|
||||
var log bytes.Buffer
|
||||
logger := plog.TestLogger(t, &log)
|
||||
|
||||
err := dialer.IsReachableAndTLSValidationSucceeds(
|
||||
context.Background(),
|
||||
urlToAddress(t, test.fullURL),
|
||||
test.certPool,
|
||||
logger,
|
||||
)
|
||||
if test.wantError != "" {
|
||||
require.EqualError(t, err, test.wantError)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialer_AppliesTimeouts(t *testing.T) {
|
||||
setupHangingServer := func(t *testing.T) string {
|
||||
startedTLSListener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
// This causes the dial to hang. I'm actually not quite sure why.
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, startedTLSListener.Close())
|
||||
})
|
||||
|
||||
return startedTLSListener.Addr().String()
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
maxTestDuration time.Duration
|
||||
makeContext func(*testing.T) context.Context
|
||||
}{
|
||||
{
|
||||
name: "timeout after 15s",
|
||||
maxTestDuration: 20 * time.Second,
|
||||
makeContext: func(t *testing.T) context.Context {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
return ctx
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uses 30s timeout without a deadline",
|
||||
maxTestDuration: 35 * time.Second,
|
||||
makeContext: func(t *testing.T) context.Context {
|
||||
return context.Background()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uses 30s timeout with a nil context",
|
||||
maxTestDuration: 35 * time.Second,
|
||||
makeContext: func(t *testing.T) context.Context {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
maxTimeForTest := time.After(test.maxTestDuration)
|
||||
testPassed := make(chan bool)
|
||||
go func() {
|
||||
var log bytes.Buffer
|
||||
logger := plog.TestLogger(t, &log)
|
||||
|
||||
dialer := ptls.NewDialer()
|
||||
err := dialer.IsReachableAndTLSValidationSucceeds(
|
||||
test.makeContext(t),
|
||||
setupHangingServer(t),
|
||||
nil,
|
||||
logger,
|
||||
)
|
||||
require.EqualError(t, err, "context deadline exceeded")
|
||||
testPassed <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-maxTimeForTest:
|
||||
t.Errorf("timeout not honored: test did not complete within %s", test.maxTestDuration)
|
||||
t.FailNow()
|
||||
case <-testPassed:
|
||||
t.Log("everything ok!")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func urlToAddress(t *testing.T, urlAsString string) string {
|
||||
u, err := url.Parse(urlAsString)
|
||||
require.NoError(t, err)
|
||||
return u.Host
|
||||
}
|
||||
|
||||
func bytesToCertPool(ca []byte) *x509.CertPool {
|
||||
x509CertPool := x509.NewCertPool()
|
||||
x509CertPool.AppendCertsFromPEM(ca)
|
||||
return x509CertPool
|
||||
}
|
||||
@@ -342,7 +342,7 @@ func prepareControllers(
|
||||
plog.New(),
|
||||
controllerlib.WithInformer,
|
||||
clock.RealClock{},
|
||||
tls.Dial,
|
||||
ptls.NewDialer(),
|
||||
cache.NewExpiring(),
|
||||
),
|
||||
singletonWorker).
|
||||
|
||||
Reference in New Issue
Block a user