diff --git a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go index eacddfc4c..edd11b423 100644 --- a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go +++ b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go @@ -6,7 +6,6 @@ package webhookcachefiller import ( "context" - "crypto/tls" "crypto/x509" "fmt" "net/url" @@ -77,6 +76,7 @@ func New( withInformer pinnipedcontroller.WithInformerOptionFunc, clock clock.Clock, log plog.Logger, + dialer ptls.Dialer, ) controllerlib.Controller { return controllerlib.New( controllerlib.Config{ @@ -90,6 +90,7 @@ func New( configMapInformer: configMapInformer, clock: clock, log: log.WithName(controllerName), + dialer: dialer, }, }, withInformer( @@ -125,6 +126,7 @@ type webhookCacheFillerController struct { client conciergeclientset.Interface clock clock.Clock log plog.Logger + dialer ptls.Dialer } // Sync implements controllerlib.Syncer. @@ -197,7 +199,7 @@ func (c *webhookCacheFillerController) syncIndividualWebhookAuthenticator(ctx co ) } else { // Run all remaining validations. - a, moreConditions, moreErrs := c.doExpensiveValidations(webhookAuthenticator, endpointHostPort, caBundle, okSoFar, logger) + a, moreConditions, moreErrs := c.doExpensiveValidations(ctx, webhookAuthenticator, endpointHostPort, caBundle, okSoFar, logger) newWebhookAuthenticatorForCache = a conditions = append(conditions, moreConditions...) errs = append(errs, moreErrs...) @@ -236,6 +238,7 @@ func (c *webhookCacheFillerController) syncIndividualWebhookAuthenticator(ctx co } func (c *webhookCacheFillerController) doExpensiveValidations( + ctx context.Context, webhookAuthenticator *authenticationv1alpha1.WebhookAuthenticator, endpointHostPort *endpointaddr.HostPort, caBundle *tlsconfigutil.CABundle, @@ -246,7 +249,7 @@ func (c *webhookCacheFillerController) doExpensiveValidations( var conditions []*metav1.Condition var errs []error - conditions, tlsNegotiateErr := c.validateConnection(caBundle.CertPool(), endpointHostPort, conditions, okSoFar, logger) + conditions, tlsNegotiateErr := c.validateConnection(ctx, caBundle.CertPool(), endpointHostPort, conditions, okSoFar, logger) errs = append(errs, tlsNegotiateErr) okSoFar = okSoFar && tlsNegotiateErr == nil @@ -412,6 +415,7 @@ func successfulWebhookConnectionValidCondition() *metav1.Condition { } func (c *webhookCacheFillerController) validateConnection( + ctx context.Context, certPool *x509.CertPool, endpointHostPort *endpointaddr.HostPort, conditions []*metav1.Condition, @@ -428,11 +432,13 @@ func (c *webhookCacheFillerController) validateConnection( return conditions, nil } - conn, err := tls.Dial("tcp", endpointHostPort.Endpoint(), ptls.Default(certPool)) + dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second) + defer dialCancel() + err := c.dialer.IsReachableAndTLSValidationSucceeds(dialCtx, endpointHostPort.Endpoint(), certPool, logger) if err != nil { errText := "cannot dial server" - msg := fmt.Sprintf("%s: %s", errText, err.Error()) + msg := fmt.Sprintf("%s: %s", errText, err) conditions = append(conditions, &metav1.Condition{ Type: typeWebhookConnectionValid, Status: metav1.ConditionFalse, @@ -442,13 +448,6 @@ func (c *webhookCacheFillerController) validateConnection( return conditions, fmt.Errorf("%s: %w", errText, err) } - // this error should never be significant - err = conn.Close() - if err != nil { - // no unit test for this failure - logger.Error("error closing dialer", err) - } - conditions = append(conditions, successfulWebhookConnectionValidCondition()) return conditions, nil } diff --git a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go index 24b2b1c3a..e75e51607 100644 --- a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go +++ b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go @@ -1934,7 +1934,8 @@ func TestController(t *testing.T) { kubeInformers.Core().V1().ConfigMaps(), controllerlib.WithInformer, frozenClock, - logger) + logger, + ptls.NewDialer()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2177,7 +2178,8 @@ func TestControllerFilterSecret(t *testing.T) { configMapInformer, observableInformers.WithInformer, frozenClock, - logger) + logger, + ptls.NewDialer()) unrelated := &corev1.Secret{} filter := observableInformers.GetFilterForInformer(secretInformer) @@ -2238,7 +2240,8 @@ func TestControllerFilterConfigMap(t *testing.T) { configMapInformer, observableInformers.WithInformer, frozenClock, - logger) + logger, + ptls.NewDialer()) unrelated := &corev1.ConfigMap{} filter := observableInformers.GetFilterForInformer(configMapInformer) diff --git a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go index e90c1fc66..71cd980f3 100644 --- a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go +++ b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go @@ -6,7 +6,6 @@ package githubupstreamwatcher import ( "context" - "crypto/tls" "errors" "fmt" "net" @@ -117,7 +116,7 @@ type gitHubWatcherController struct { secretInformer corev1informers.SecretInformer configMapInformer corev1informers.ConfigMapInformer clock clock.Clock - dialFunc func(network, addr string, config *tls.Config) (*tls.Conn, error) + dialer ptls.Dialer validatedCache GitHubValidatedAPICacheI } @@ -132,7 +131,7 @@ func New( log plog.Logger, withInformer pinnipedcontroller.WithInformerOptionFunc, clock clock.Clock, - dialFunc func(network, addr string, config *tls.Config) (*tls.Conn, error), + dialer ptls.Dialer, validatedCache *cache.Expiring, ) controllerlib.Controller { c := gitHubWatcherController{ @@ -144,7 +143,7 @@ func New( secretInformer: secretInformer, configMapInformer: configMapInformer, clock: clock, - dialFunc: dialFunc, + dialer: dialer, validatedCache: NewGitHubValidatedAPICache(validatedCache), } @@ -190,7 +189,7 @@ func (c *gitHubWatcherController) Sync(ctx controllerlib.Context) error { var applicationErrors []error validatedUpstreams := make([]upstreamprovider.UpstreamGithubIdentityProviderI, 0, len(actualUpstreams)) for _, upstream := range actualUpstreams { - validatedUpstream, applicationErr := c.validateUpstreamAndUpdateConditions(ctx, upstream) + validatedUpstream, applicationErr := c.validateUpstreamAndUpdateConditions(ctx.Context, upstream) if applicationErr != nil { applicationErrors = append(applicationErrors, applicationErr) } else if validatedUpstream != nil { @@ -298,7 +297,7 @@ func validateOrganizationsPolicy(organizationsSpec *idpv1alpha1.GitHubOrganizati } } -func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx controllerlib.Context, upstream *idpv1alpha1.GitHubIdentityProvider) ( +func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx context.Context, upstream *idpv1alpha1.GitHubIdentityProvider) ( *upstreamgithub.Provider, // If validated, returns the config error, // This error will only refer to programmatic errors such as inability to perform a Dial or dereference a pointer, not configuration errors ) { @@ -332,6 +331,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro conditions = append(conditions, tlsConfigCondition) githubConnectionCondition, httpClient, githubConnectionErr := c.validateGitHubConnection( + ctx, apiHostPort, upstream.Spec.GitHubAPI.Host, caBundle, @@ -350,7 +350,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro applicationErrors = append(applicationErrors, fmt.Errorf("expected %d conditions but found %d conditions", countExpectedConditions, len(conditions))) return nil, utilerrors.NewAggregate(applicationErrors) } - hadErrorCondition, updateStatusErr := c.updateStatus(ctx.Context, upstream, conditions) + hadErrorCondition, updateStatusErr := c.updateStatus(ctx, upstream, conditions) if updateStatusErr != nil { applicationErrors = append(applicationErrors, updateStatusErr) } @@ -454,6 +454,7 @@ func validateHost(specifiedHost *string) (*metav1.Condition, *endpointaddr.HostP } func (c *gitHubWatcherController) validateGitHubConnection( + ctx context.Context, apiHostPort *endpointaddr.HostPort, specifiedHost *string, caBundle *tlsconfigutil.CABundle, @@ -471,7 +472,10 @@ func (c *gitHubWatcherController) validateGitHubConnection( apiAddress := apiHostPort.Endpoint() if !c.validatedCache.IsValid(apiAddress, caBundle.Hash()) { - conn, tlsDialErr := c.dialFunc("tcp", apiAddress, ptls.Default(caBundle.CertPool())) + dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second) + defer dialCancel() + + tlsDialErr := c.dialer.IsReachableAndTLSValidationSucceeds(dialCtx, apiAddress, caBundle.CertPool(), c.log) if tlsDialErr != nil { return &metav1.Condition{ Type: GitHubConnectionValid, @@ -481,8 +485,6 @@ func (c *gitHubWatcherController) validateGitHubConnection( apiAddress, *specifiedHost, buildDialErrorMessage(tlsDialErr)), }, nil, tlsDialErr } - // Any error should be ignored. We have performed a successful Dial, so no need to requeue this Sync. - _ = conn.Close() } c.validatedCache.MarkAsValidated(apiAddress, caBundle.Hash()) diff --git a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go index 63a2cd50b..74e5e6151 100644 --- a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go +++ b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go @@ -6,7 +6,6 @@ package githubupstreamwatcher import ( "bytes" "context" - "crypto/tls" "crypto/x509" "encoding/base64" "fmt" @@ -41,6 +40,7 @@ import ( "go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers" "go.pinniped.dev/internal/controller/tlsconfigutil" "go.pinniped.dev/internal/controllerlib" + "go.pinniped.dev/internal/crypto/ptls" "go.pinniped.dev/internal/federationdomain/dynamicupstreamprovider" "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/net/phttp" @@ -61,6 +61,32 @@ var ( githubIDPKind = idpv1alpha1.SchemeGroupVersion.WithKind("GitHubIdentityProvider") ) +type fakeGithubDialer struct { + t *testing.T + realAddress string + realCertPool *x509.CertPool +} + +func (f fakeGithubDialer) IsReachableAndTLSValidationSucceeds(ctx context.Context, address string, _ *x509.CertPool, logger plog.Logger) error { + require.Equal(f.t, "api.github.com:443", address) + + return ptls.NewDialer().IsReachableAndTLSValidationSucceeds(ctx, f.realAddress, f.realCertPool, logger) +} + +var _ ptls.Dialer = (*fakeGithubDialer)(nil) + +type allowNoDials struct { + t *testing.T +} + +func (f allowNoDials) IsReachableAndTLSValidationSucceeds(_ context.Context, _ string, _ *x509.CertPool, _ plog.Logger) error { + f.t.Errorf("this test should not perform dial") + f.t.FailNow() + return nil +} + +var _ ptls.Dialer = (*allowNoDials)(nil) + func TestController(t *testing.T) { require.Equal(t, 6, countExpectedConditions) @@ -406,7 +432,7 @@ func TestController(t *testing.T) { name string githubIdentityProviders []runtime.Object secretsAndConfigMaps []runtime.Object - mockDialer func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) + mockDialer func(*testing.T) ptls.Dialer preexistingValidatedCache []GitHubValidatedAPICacheKey wantErr string wantLogs []string @@ -555,15 +581,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -638,15 +662,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -721,15 +743,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -804,15 +824,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -887,15 +905,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -1379,14 +1395,10 @@ func TestController(t *testing.T) { name: "happy path with previously validated address/CA Bundle does not validate again", secretsAndConfigMaps: []runtime.Object{goodClientCredentialsSecret}, githubIdentityProviders: []runtime.Object{validFilledOutIDP}, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - t.Errorf("this test should not perform dial") - t.FailNow() - return nil, nil - } + return &allowNoDials{t: t} }, preexistingValidatedCache: []GitHubValidatedAPICacheKey{ { @@ -2479,7 +2491,7 @@ func TestController(t *testing.T) { gitHubIdentityProviderInformer := supervisorInformers.IDP().V1alpha1().GitHubIdentityProviders() - dialer := tls.Dial + var dialer ptls.Dialer = ptls.NewDialer() if tt.mockDialer != nil { dialer = tt.mockDialer(t) } @@ -2882,7 +2894,7 @@ func TestController_OnlyWantActions(t *testing.T) { logger, controllerlib.WithInformer, frozenClockForLastTransitionTime, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) @@ -3006,7 +3018,7 @@ func TestGitHubUpstreamWatcherControllerFilterSecret(t *testing.T) { logger, observableInformers.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) @@ -3063,7 +3075,7 @@ func TestGitHubUpstreamWatcherControllerFilterConfigMaps(t *testing.T) { logger, observableInformers.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) @@ -3120,7 +3132,7 @@ func TestGitHubUpstreamWatcherControllerFilterGitHubIDP(t *testing.T) { logger, observableInformers.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) diff --git a/internal/controllermanager/prepare_controllers.go b/internal/controllermanager/prepare_controllers.go index 6d5144216..0dfc2396a 100644 --- a/internal/controllermanager/prepare_controllers.go +++ b/internal/controllermanager/prepare_controllers.go @@ -28,6 +28,7 @@ import ( "go.pinniped.dev/internal/controller/serviceaccounttokencleanup" "go.pinniped.dev/internal/controllerinit" "go.pinniped.dev/internal/controllerlib" + "go.pinniped.dev/internal/crypto/ptls" "go.pinniped.dev/internal/deploymentref" "go.pinniped.dev/internal/downward" "go.pinniped.dev/internal/dynamiccert" @@ -244,6 +245,7 @@ func PrepareControllers(c *Config) (controllerinit.RunnerBuilder, error) { //nol controllerlib.WithInformer, clock.RealClock{}, plog.New(), + ptls.NewDialer(), ), singletonWorker, ). diff --git a/internal/crypto/ptls/dialer.go b/internal/crypto/ptls/dialer.go new file mode 100644 index 000000000..9e86bb386 --- /dev/null +++ b/internal/crypto/ptls/dialer.go @@ -0,0 +1,62 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package ptls + +import ( + "context" + "crypto/tls" + "crypto/x509" + "time" + + "go.pinniped.dev/internal/plog" +) + +type Dialer interface { + IsReachableAndTLSValidationSucceeds( + ctx context.Context, + address string, + certPool *x509.CertPool, + logger plog.Logger, + ) error +} + +type internalDialer struct { +} + +func NewDialer() *internalDialer { + return &internalDialer{} +} + +func (i *internalDialer) IsReachableAndTLSValidationSucceeds( + ctx context.Context, + address string, + certPool *x509.CertPool, + logger plog.Logger, +) error { + if ctx == nil { + ctx = context.Background() + } + + _, hasDeadline := ctx.Deadline() + if !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 30*time.Second) + defer cancel() + } + + dialer := tls.Dialer{ + Config: Default(certPool), + } + conn, err := dialer.DialContext(ctx, "tcp", address) + if err != nil { + // Don't wrap this error message since this is just a helper function. + return err + } + err = conn.Close() + if err != nil { // untested + // Log it just so that it doesn't completely disappear. + logger.Error("Failed to close connection: ", err) + } + return nil +} diff --git a/internal/crypto/ptls/dialer_test.go b/internal/crypto/ptls/dialer_test.go new file mode 100644 index 000000000..42ab985d1 --- /dev/null +++ b/internal/crypto/ptls/dialer_test.go @@ -0,0 +1,173 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Use this package to avoid import loops with internal/testutil/tlsserver +package ptls_test + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/crypto/ptls" + "go.pinniped.dev/internal/plog" + "go.pinniped.dev/internal/testutil" + "go.pinniped.dev/internal/testutil/tlsserver" +) + +func TestDialer(t *testing.T) { + secureServerIPv4, secureServerIPv4CA := tlsserver.TestServerIPv4(t, nil, nil) + secureServerIPv6, secureServerIPv6CA := tlsserver.TestServerIPv6(t, nil, nil) + insecureServer := httptest.NewServer(nil) + + fakeCert, _, err := testutil.CreateCertificate(time.Now().Add(-1*time.Hour), time.Now().Add(time.Hour)) + require.NoError(t, err) + + tests := []struct { + name string + fullURL string + certPool *x509.CertPool + wantError string + }{ + { + name: "happy path with TLS-enabled IPv4", + fullURL: secureServerIPv4.URL, + certPool: bytesToCertPool(secureServerIPv4CA), + }, + { + name: "happy path with TLS-enabled IPv6", + fullURL: secureServerIPv6.URL, + certPool: bytesToCertPool(secureServerIPv6CA), + }, + { + name: "returns error when connecting to a non-TLS server", + fullURL: insecureServer.URL, + wantError: "tls: first record does not look like a TLS handshake", + }, + { + name: "returns error when using the wrong bundle", + fullURL: secureServerIPv4.URL, + certPool: bytesToCertPool(fakeCert), + wantError: "tls: failed to verify certificate: x509: certificate signed by unknown authority", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + dialer := ptls.NewDialer() + + var log bytes.Buffer + logger := plog.TestLogger(t, &log) + + err := dialer.IsReachableAndTLSValidationSucceeds( + context.Background(), + urlToAddress(t, test.fullURL), + test.certPool, + logger, + ) + if test.wantError != "" { + require.EqualError(t, err, test.wantError) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestDialer_AppliesTimeouts(t *testing.T) { + setupHangingServer := func(t *testing.T) string { + startedTLSListener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + MinVersion: tls.VersionTLS12, + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + // This causes the dial to hang. I'm actually not quite sure why. + return nil, nil + }, + }) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, startedTLSListener.Close()) + }) + + return startedTLSListener.Addr().String() + } + + tests := []struct { + name string + maxTestDuration time.Duration + makeContext func(*testing.T) context.Context + }{ + { + name: "timeout after 15s", + maxTestDuration: 20 * time.Second, + makeContext: func(t *testing.T) context.Context { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + t.Cleanup(cancel) + return ctx + }, + }, + { + name: "uses 30s timeout without a deadline", + maxTestDuration: 35 * time.Second, + makeContext: func(t *testing.T) context.Context { + return context.Background() + }, + }, + { + name: "uses 30s timeout with a nil context", + maxTestDuration: 35 * time.Second, + makeContext: func(t *testing.T) context.Context { + return nil + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + maxTimeForTest := time.After(test.maxTestDuration) + testPassed := make(chan bool) + go func() { + var log bytes.Buffer + logger := plog.TestLogger(t, &log) + + dialer := ptls.NewDialer() + err := dialer.IsReachableAndTLSValidationSucceeds( + test.makeContext(t), + setupHangingServer(t), + nil, + logger, + ) + require.EqualError(t, err, "context deadline exceeded") + testPassed <- true + }() + + select { + case <-maxTimeForTest: + t.Errorf("timeout not honored: test did not complete within %s", test.maxTestDuration) + t.FailNow() + case <-testPassed: + t.Log("everything ok!") + } + }) + } +} + +func urlToAddress(t *testing.T, urlAsString string) string { + u, err := url.Parse(urlAsString) + require.NoError(t, err) + return u.Host +} + +func bytesToCertPool(ca []byte) *x509.CertPool { + x509CertPool := x509.NewCertPool() + x509CertPool.AppendCertsFromPEM(ca) + return x509CertPool +} diff --git a/internal/supervisor/server/server.go b/internal/supervisor/server/server.go index 5aa4270af..ae2b65167 100644 --- a/internal/supervisor/server/server.go +++ b/internal/supervisor/server/server.go @@ -342,7 +342,7 @@ func prepareControllers( plog.New(), controllerlib.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ), singletonWorker).