From 76a116641fa2dc0da4a540bc190745939087856a Mon Sep 17 00:00:00 2001 From: Joshua Casey Date: Tue, 3 Sep 2024 14:45:14 -0500 Subject: [PATCH 1/4] Add ptls.Dialer to provide some common configuration for tls.Dial operations --- .../webhookcachefiller/webhookcachefiller.go | 15 +- .../webhookcachefiller_test.go | 9 +- .../github_upstream_watcher.go | 11 +- .../github_upstream_watcher_test.go | 108 +++++++----- .../controllermanager/prepare_controllers.go | 2 + internal/crypto/ptls/dialer.go | 58 +++++++ internal/crypto/ptls/dialer_test.go | 164 ++++++++++++++++++ internal/supervisor/server/server.go | 2 +- 8 files changed, 300 insertions(+), 69 deletions(-) create mode 100644 internal/crypto/ptls/dialer.go create mode 100644 internal/crypto/ptls/dialer_test.go diff --git a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go index eacddfc4c..5f7344171 100644 --- a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go +++ b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go @@ -6,7 +6,6 @@ package webhookcachefiller import ( "context" - "crypto/tls" "crypto/x509" "fmt" "net/url" @@ -77,6 +76,7 @@ func New( withInformer pinnipedcontroller.WithInformerOptionFunc, clock clock.Clock, log plog.Logger, + dialer ptls.Dialer, ) controllerlib.Controller { return controllerlib.New( controllerlib.Config{ @@ -90,6 +90,7 @@ func New( configMapInformer: configMapInformer, clock: clock, log: log.WithName(controllerName), + dialer: dialer, }, }, withInformer( @@ -125,6 +126,7 @@ type webhookCacheFillerController struct { client conciergeclientset.Interface clock clock.Clock log plog.Logger + dialer ptls.Dialer } // Sync implements controllerlib.Syncer. @@ -428,11 +430,11 @@ func (c *webhookCacheFillerController) validateConnection( return conditions, nil } - conn, err := tls.Dial("tcp", endpointHostPort.Endpoint(), ptls.Default(certPool)) + err := c.dialer.IsReachableAndTLSValidationSucceeds(endpointHostPort.Endpoint(), certPool, logger) if err != nil { errText := "cannot dial server" - msg := fmt.Sprintf("%s: %s", errText, err.Error()) + msg := fmt.Sprintf("%s: %s", errText, err) conditions = append(conditions, &metav1.Condition{ Type: typeWebhookConnectionValid, Status: metav1.ConditionFalse, @@ -442,13 +444,6 @@ func (c *webhookCacheFillerController) validateConnection( return conditions, fmt.Errorf("%s: %w", errText, err) } - // this error should never be significant - err = conn.Close() - if err != nil { - // no unit test for this failure - logger.Error("error closing dialer", err) - } - conditions = append(conditions, successfulWebhookConnectionValidCondition()) return conditions, nil } diff --git a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go index 24b2b1c3a..e75e51607 100644 --- a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go +++ b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go @@ -1934,7 +1934,8 @@ func TestController(t *testing.T) { kubeInformers.Core().V1().ConfigMaps(), controllerlib.WithInformer, frozenClock, - logger) + logger, + ptls.NewDialer()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2177,7 +2178,8 @@ func TestControllerFilterSecret(t *testing.T) { configMapInformer, observableInformers.WithInformer, frozenClock, - logger) + logger, + ptls.NewDialer()) unrelated := &corev1.Secret{} filter := observableInformers.GetFilterForInformer(secretInformer) @@ -2238,7 +2240,8 @@ func TestControllerFilterConfigMap(t *testing.T) { configMapInformer, observableInformers.WithInformer, frozenClock, - logger) + logger, + ptls.NewDialer()) unrelated := &corev1.ConfigMap{} filter := observableInformers.GetFilterForInformer(configMapInformer) diff --git a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go index e90c1fc66..5326ec366 100644 --- a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go +++ b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go @@ -6,7 +6,6 @@ package githubupstreamwatcher import ( "context" - "crypto/tls" "errors" "fmt" "net" @@ -117,7 +116,7 @@ type gitHubWatcherController struct { secretInformer corev1informers.SecretInformer configMapInformer corev1informers.ConfigMapInformer clock clock.Clock - dialFunc func(network, addr string, config *tls.Config) (*tls.Conn, error) + dialer ptls.Dialer validatedCache GitHubValidatedAPICacheI } @@ -132,7 +131,7 @@ func New( log plog.Logger, withInformer pinnipedcontroller.WithInformerOptionFunc, clock clock.Clock, - dialFunc func(network, addr string, config *tls.Config) (*tls.Conn, error), + dialer ptls.Dialer, validatedCache *cache.Expiring, ) controllerlib.Controller { c := gitHubWatcherController{ @@ -144,7 +143,7 @@ func New( secretInformer: secretInformer, configMapInformer: configMapInformer, clock: clock, - dialFunc: dialFunc, + dialer: dialer, validatedCache: NewGitHubValidatedAPICache(validatedCache), } @@ -471,7 +470,7 @@ func (c *gitHubWatcherController) validateGitHubConnection( apiAddress := apiHostPort.Endpoint() if !c.validatedCache.IsValid(apiAddress, caBundle.Hash()) { - conn, tlsDialErr := c.dialFunc("tcp", apiAddress, ptls.Default(caBundle.CertPool())) + tlsDialErr := c.dialer.IsReachableAndTLSValidationSucceeds(apiAddress, caBundle.CertPool(), c.log) if tlsDialErr != nil { return &metav1.Condition{ Type: GitHubConnectionValid, @@ -481,8 +480,6 @@ func (c *gitHubWatcherController) validateGitHubConnection( apiAddress, *specifiedHost, buildDialErrorMessage(tlsDialErr)), }, nil, tlsDialErr } - // Any error should be ignored. We have performed a successful Dial, so no need to requeue this Sync. - _ = conn.Close() } c.validatedCache.MarkAsValidated(apiAddress, caBundle.Hash()) diff --git a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go index 63a2cd50b..d070f463e 100644 --- a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go +++ b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go @@ -6,7 +6,6 @@ package githubupstreamwatcher import ( "bytes" "context" - "crypto/tls" "crypto/x509" "encoding/base64" "fmt" @@ -41,6 +40,7 @@ import ( "go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers" "go.pinniped.dev/internal/controller/tlsconfigutil" "go.pinniped.dev/internal/controllerlib" + "go.pinniped.dev/internal/crypto/ptls" "go.pinniped.dev/internal/federationdomain/dynamicupstreamprovider" "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/net/phttp" @@ -61,6 +61,32 @@ var ( githubIDPKind = idpv1alpha1.SchemeGroupVersion.WithKind("GitHubIdentityProvider") ) +type fakeGithubDialer struct { + t *testing.T + realAddress string + realCertPool *x509.CertPool +} + +func (f fakeGithubDialer) IsReachableAndTLSValidationSucceeds(address string, _ *x509.CertPool, logger ptls.ErrorOnlyLogger) error { + require.Equal(f.t, "api.github.com:443", address) + + return ptls.NewDialer().IsReachableAndTLSValidationSucceeds(f.realAddress, f.realCertPool, logger) +} + +var _ ptls.Dialer = (*fakeGithubDialer)(nil) + +type allowNoDials struct { + t *testing.T +} + +func (f allowNoDials) IsReachableAndTLSValidationSucceeds(_ string, _ *x509.CertPool, _ ptls.ErrorOnlyLogger) error { + f.t.Errorf("this test should not perform dial") + f.t.FailNow() + return nil +} + +var _ ptls.Dialer = (*allowNoDials)(nil) + func TestController(t *testing.T) { require.Equal(t, 6, countExpectedConditions) @@ -406,7 +432,7 @@ func TestController(t *testing.T) { name string githubIdentityProviders []runtime.Object secretsAndConfigMaps []runtime.Object - mockDialer func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) + mockDialer func(*testing.T) ptls.Dialer preexistingValidatedCache []GitHubValidatedAPICacheKey wantErr string wantLogs []string @@ -555,15 +581,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -638,15 +662,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -721,15 +743,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -804,15 +824,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -887,15 +905,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -1379,14 +1395,10 @@ func TestController(t *testing.T) { name: "happy path with previously validated address/CA Bundle does not validate again", secretsAndConfigMaps: []runtime.Object{goodClientCredentialsSecret}, githubIdentityProviders: []runtime.Object{validFilledOutIDP}, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - t.Errorf("this test should not perform dial") - t.FailNow() - return nil, nil - } + return &allowNoDials{t: t} }, preexistingValidatedCache: []GitHubValidatedAPICacheKey{ { @@ -2479,7 +2491,7 @@ func TestController(t *testing.T) { gitHubIdentityProviderInformer := supervisorInformers.IDP().V1alpha1().GitHubIdentityProviders() - dialer := tls.Dial + var dialer ptls.Dialer = ptls.NewDialer() if tt.mockDialer != nil { dialer = tt.mockDialer(t) } @@ -2882,7 +2894,7 @@ func TestController_OnlyWantActions(t *testing.T) { logger, controllerlib.WithInformer, frozenClockForLastTransitionTime, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) @@ -3006,7 +3018,7 @@ func TestGitHubUpstreamWatcherControllerFilterSecret(t *testing.T) { logger, observableInformers.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) @@ -3063,7 +3075,7 @@ func TestGitHubUpstreamWatcherControllerFilterConfigMaps(t *testing.T) { logger, observableInformers.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) @@ -3120,7 +3132,7 @@ func TestGitHubUpstreamWatcherControllerFilterGitHubIDP(t *testing.T) { logger, observableInformers.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) diff --git a/internal/controllermanager/prepare_controllers.go b/internal/controllermanager/prepare_controllers.go index 6d5144216..0dfc2396a 100644 --- a/internal/controllermanager/prepare_controllers.go +++ b/internal/controllermanager/prepare_controllers.go @@ -28,6 +28,7 @@ import ( "go.pinniped.dev/internal/controller/serviceaccounttokencleanup" "go.pinniped.dev/internal/controllerinit" "go.pinniped.dev/internal/controllerlib" + "go.pinniped.dev/internal/crypto/ptls" "go.pinniped.dev/internal/deploymentref" "go.pinniped.dev/internal/downward" "go.pinniped.dev/internal/dynamiccert" @@ -244,6 +245,7 @@ func PrepareControllers(c *Config) (controllerinit.RunnerBuilder, error) { //nol controllerlib.WithInformer, clock.RealClock{}, plog.New(), + ptls.NewDialer(), ), singletonWorker, ). diff --git a/internal/crypto/ptls/dialer.go b/internal/crypto/ptls/dialer.go new file mode 100644 index 000000000..ba0309338 --- /dev/null +++ b/internal/crypto/ptls/dialer.go @@ -0,0 +1,58 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package ptls + +import ( + "crypto/tls" + "crypto/x509" + "net" + "time" +) + +type Dialer interface { + IsReachableAndTLSValidationSucceeds( + address string, + certPool *x509.CertPool, + logger ErrorOnlyLogger, + ) error +} + +type ErrorOnlyLogger interface { + Error(msg string, err error, keysAndValues ...any) +} + +type internalDialer struct { + dialer *net.Dialer +} + +func NewDialer() *internalDialer { + return &internalDialer{ + dialer: &net.Dialer{ + Timeout: 15 * time.Second, + }, + } +} + +func (i *internalDialer) WithTimeout(timeout time.Duration) Dialer { + i.dialer.Timeout = timeout + return i +} + +func (i *internalDialer) IsReachableAndTLSValidationSucceeds( + address string, + certPool *x509.CertPool, + logger ErrorOnlyLogger, +) error { + connection, err := tls.DialWithDialer(i.dialer, "tcp", address, Default(certPool)) + if err != nil { + // Don't wrap this error message since this is just a helper function. + return err + } + err = connection.Close() + if err != nil { // untested + // Log it just so that it doesn't completely disappear. + logger.Error("Failed to close connection: ", err) + } + return nil +} diff --git a/internal/crypto/ptls/dialer_test.go b/internal/crypto/ptls/dialer_test.go new file mode 100644 index 000000000..f40de7a6c --- /dev/null +++ b/internal/crypto/ptls/dialer_test.go @@ -0,0 +1,164 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Use this package to avoid import loops with internal/testutil/tlsserver +package ptls_test + +import ( + "crypto/tls" + "crypto/x509" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/crypto/ptls" + "go.pinniped.dev/internal/testutil" + "go.pinniped.dev/internal/testutil/tlsserver" +) + +type fakeerroronlylogger struct { +} + +func (_ *fakeerroronlylogger) Error(msg string, err error, keysAndValues ...any) { + // NOOP +} + +var _ ptls.ErrorOnlyLogger = (*fakeerroronlylogger)(nil) + +func TestDialer(t *testing.T) { + secureServerIPv4, secureServerIPv4CA := tlsserver.TestServerIPv4(t, nil, nil) + secureServerIPv6, secureServerIPv6CA := tlsserver.TestServerIPv6(t, nil, nil) + insecureServer := httptest.NewServer(nil) + + fakeCert, _, err := testutil.CreateCertificate(time.Now().Add(-1*time.Hour), time.Now().Add(time.Hour)) + require.NoError(t, err) + + tests := []struct { + name string + fullURL string + certPool *x509.CertPool + wantError string + }{ + { + name: "happy path with TLS-enabled IPv4", + fullURL: secureServerIPv4.URL, + certPool: bytesToCertPool(secureServerIPv4CA), + }, + { + name: "happy path with TLS-enabled IPv6", + fullURL: secureServerIPv6.URL, + certPool: bytesToCertPool(secureServerIPv6CA), + }, + { + name: "returns error when connecting to a non-TLS server", + fullURL: insecureServer.URL, + wantError: "tls: first record does not look like a TLS handshake", + }, + { + name: "returns error when using the wrong bundle", + fullURL: secureServerIPv4.URL, + certPool: bytesToCertPool(fakeCert), + wantError: "tls: failed to verify certificate: x509: certificate signed by unknown authority", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + dialer := ptls.NewDialer() + + err := dialer.IsReachableAndTLSValidationSucceeds( + urlToAddress(t, test.fullURL), + test.certPool, + &fakeerroronlylogger{}, + ) + if test.wantError != "" { + require.EqualError(t, err, test.wantError) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestDialer_TimeoutAfter15s(t *testing.T) { + t.Parallel() + + dialer := ptls.NewDialer() + + timeout := time.After(30 * time.Second) + testDone := make(chan bool) + go func() { + err := dialer.IsReachableAndTLSValidationSucceeds( + setupHangingServer(t), + nil, + &fakeerroronlylogger{}, + ) + require.EqualError(t, err, "context deadline exceeded") + testDone <- true + }() + + select { + case <-timeout: + t.Errorf("test did not complete within 30 seconds") + t.FailNow() + case <-testDone: + t.Log("everything ok!") + } +} + +func TestDialer_WithCustomTimeTimeoutAfter2s(t *testing.T) { + t.Parallel() + + dialer := ptls.NewDialer().WithTimeout(2 * time.Second) + + timeout := time.After(5 * time.Second) + testDone := make(chan bool) + go func() { + err := dialer.IsReachableAndTLSValidationSucceeds( + setupHangingServer(t), + nil, + &fakeerroronlylogger{}, + ) + require.EqualError(t, err, "context deadline exceeded") + testDone <- true + }() + + select { + case <-timeout: + t.Errorf("test did not complete within 5 seconds") + t.FailNow() + case <-testDone: + t.Log("everything ok!") + } +} + +func setupHangingServer(t *testing.T) string { + startedTLSListener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + MinVersion: tls.VersionTLS12, + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + // This causes the dial to hang. I'm actually not quite sure why. + return nil, nil + }, + }) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, startedTLSListener.Close()) + }) + + return startedTLSListener.Addr().String() +} + +func urlToAddress(t *testing.T, urlAsString string) string { + u, err := url.Parse(urlAsString) + require.NoError(t, err) + return u.Host +} + +func bytesToCertPool(ca []byte) *x509.CertPool { + x509CertPool := x509.NewCertPool() + x509CertPool.AppendCertsFromPEM(ca) + return x509CertPool +} diff --git a/internal/supervisor/server/server.go b/internal/supervisor/server/server.go index 5aa4270af..ae2b65167 100644 --- a/internal/supervisor/server/server.go +++ b/internal/supervisor/server/server.go @@ -342,7 +342,7 @@ func prepareControllers( plog.New(), controllerlib.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ), singletonWorker). From f7fd209f29423b8d5a3b23acf035c3b403081ad6 Mon Sep 17 00:00:00 2001 From: Joshua Casey Date: Thu, 19 Sep 2024 16:29:56 -0500 Subject: [PATCH 2/4] Address PR feedback --- .../webhookcachefiller/webhookcachefiller.go | 10 +- .../github_upstream_watcher.go | 13 ++- .../github_upstream_watcher_test.go | 6 +- internal/crypto/ptls/dialer.go | 44 +++++---- internal/crypto/ptls/dialer_test.go | 97 +++++++++++++------ 5 files changed, 112 insertions(+), 58 deletions(-) diff --git a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go index 5f7344171..edd11b423 100644 --- a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go +++ b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go @@ -199,7 +199,7 @@ func (c *webhookCacheFillerController) syncIndividualWebhookAuthenticator(ctx co ) } else { // Run all remaining validations. - a, moreConditions, moreErrs := c.doExpensiveValidations(webhookAuthenticator, endpointHostPort, caBundle, okSoFar, logger) + a, moreConditions, moreErrs := c.doExpensiveValidations(ctx, webhookAuthenticator, endpointHostPort, caBundle, okSoFar, logger) newWebhookAuthenticatorForCache = a conditions = append(conditions, moreConditions...) errs = append(errs, moreErrs...) @@ -238,6 +238,7 @@ func (c *webhookCacheFillerController) syncIndividualWebhookAuthenticator(ctx co } func (c *webhookCacheFillerController) doExpensiveValidations( + ctx context.Context, webhookAuthenticator *authenticationv1alpha1.WebhookAuthenticator, endpointHostPort *endpointaddr.HostPort, caBundle *tlsconfigutil.CABundle, @@ -248,7 +249,7 @@ func (c *webhookCacheFillerController) doExpensiveValidations( var conditions []*metav1.Condition var errs []error - conditions, tlsNegotiateErr := c.validateConnection(caBundle.CertPool(), endpointHostPort, conditions, okSoFar, logger) + conditions, tlsNegotiateErr := c.validateConnection(ctx, caBundle.CertPool(), endpointHostPort, conditions, okSoFar, logger) errs = append(errs, tlsNegotiateErr) okSoFar = okSoFar && tlsNegotiateErr == nil @@ -414,6 +415,7 @@ func successfulWebhookConnectionValidCondition() *metav1.Condition { } func (c *webhookCacheFillerController) validateConnection( + ctx context.Context, certPool *x509.CertPool, endpointHostPort *endpointaddr.HostPort, conditions []*metav1.Condition, @@ -430,7 +432,9 @@ func (c *webhookCacheFillerController) validateConnection( return conditions, nil } - err := c.dialer.IsReachableAndTLSValidationSucceeds(endpointHostPort.Endpoint(), certPool, logger) + dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second) + defer dialCancel() + err := c.dialer.IsReachableAndTLSValidationSucceeds(dialCtx, endpointHostPort.Endpoint(), certPool, logger) if err != nil { errText := "cannot dial server" diff --git a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go index 5326ec366..71cd980f3 100644 --- a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go +++ b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go @@ -189,7 +189,7 @@ func (c *gitHubWatcherController) Sync(ctx controllerlib.Context) error { var applicationErrors []error validatedUpstreams := make([]upstreamprovider.UpstreamGithubIdentityProviderI, 0, len(actualUpstreams)) for _, upstream := range actualUpstreams { - validatedUpstream, applicationErr := c.validateUpstreamAndUpdateConditions(ctx, upstream) + validatedUpstream, applicationErr := c.validateUpstreamAndUpdateConditions(ctx.Context, upstream) if applicationErr != nil { applicationErrors = append(applicationErrors, applicationErr) } else if validatedUpstream != nil { @@ -297,7 +297,7 @@ func validateOrganizationsPolicy(organizationsSpec *idpv1alpha1.GitHubOrganizati } } -func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx controllerlib.Context, upstream *idpv1alpha1.GitHubIdentityProvider) ( +func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx context.Context, upstream *idpv1alpha1.GitHubIdentityProvider) ( *upstreamgithub.Provider, // If validated, returns the config error, // This error will only refer to programmatic errors such as inability to perform a Dial or dereference a pointer, not configuration errors ) { @@ -331,6 +331,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro conditions = append(conditions, tlsConfigCondition) githubConnectionCondition, httpClient, githubConnectionErr := c.validateGitHubConnection( + ctx, apiHostPort, upstream.Spec.GitHubAPI.Host, caBundle, @@ -349,7 +350,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro applicationErrors = append(applicationErrors, fmt.Errorf("expected %d conditions but found %d conditions", countExpectedConditions, len(conditions))) return nil, utilerrors.NewAggregate(applicationErrors) } - hadErrorCondition, updateStatusErr := c.updateStatus(ctx.Context, upstream, conditions) + hadErrorCondition, updateStatusErr := c.updateStatus(ctx, upstream, conditions) if updateStatusErr != nil { applicationErrors = append(applicationErrors, updateStatusErr) } @@ -453,6 +454,7 @@ func validateHost(specifiedHost *string) (*metav1.Condition, *endpointaddr.HostP } func (c *gitHubWatcherController) validateGitHubConnection( + ctx context.Context, apiHostPort *endpointaddr.HostPort, specifiedHost *string, caBundle *tlsconfigutil.CABundle, @@ -470,7 +472,10 @@ func (c *gitHubWatcherController) validateGitHubConnection( apiAddress := apiHostPort.Endpoint() if !c.validatedCache.IsValid(apiAddress, caBundle.Hash()) { - tlsDialErr := c.dialer.IsReachableAndTLSValidationSucceeds(apiAddress, caBundle.CertPool(), c.log) + dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second) + defer dialCancel() + + tlsDialErr := c.dialer.IsReachableAndTLSValidationSucceeds(dialCtx, apiAddress, caBundle.CertPool(), c.log) if tlsDialErr != nil { return &metav1.Condition{ Type: GitHubConnectionValid, diff --git a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go index d070f463e..74e5e6151 100644 --- a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go +++ b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go @@ -67,10 +67,10 @@ type fakeGithubDialer struct { realCertPool *x509.CertPool } -func (f fakeGithubDialer) IsReachableAndTLSValidationSucceeds(address string, _ *x509.CertPool, logger ptls.ErrorOnlyLogger) error { +func (f fakeGithubDialer) IsReachableAndTLSValidationSucceeds(ctx context.Context, address string, _ *x509.CertPool, logger plog.Logger) error { require.Equal(f.t, "api.github.com:443", address) - return ptls.NewDialer().IsReachableAndTLSValidationSucceeds(f.realAddress, f.realCertPool, logger) + return ptls.NewDialer().IsReachableAndTLSValidationSucceeds(ctx, f.realAddress, f.realCertPool, logger) } var _ ptls.Dialer = (*fakeGithubDialer)(nil) @@ -79,7 +79,7 @@ type allowNoDials struct { t *testing.T } -func (f allowNoDials) IsReachableAndTLSValidationSucceeds(_ string, _ *x509.CertPool, _ ptls.ErrorOnlyLogger) error { +func (f allowNoDials) IsReachableAndTLSValidationSucceeds(_ context.Context, _ string, _ *x509.CertPool, _ plog.Logger) error { f.t.Errorf("this test should not perform dial") f.t.FailNow() return nil diff --git a/internal/crypto/ptls/dialer.go b/internal/crypto/ptls/dialer.go index ba0309338..9e86bb386 100644 --- a/internal/crypto/ptls/dialer.go +++ b/internal/crypto/ptls/dialer.go @@ -4,52 +4,56 @@ package ptls import ( + "context" "crypto/tls" "crypto/x509" - "net" "time" + + "go.pinniped.dev/internal/plog" ) type Dialer interface { IsReachableAndTLSValidationSucceeds( + ctx context.Context, address string, certPool *x509.CertPool, - logger ErrorOnlyLogger, + logger plog.Logger, ) error } -type ErrorOnlyLogger interface { - Error(msg string, err error, keysAndValues ...any) -} - type internalDialer struct { - dialer *net.Dialer } func NewDialer() *internalDialer { - return &internalDialer{ - dialer: &net.Dialer{ - Timeout: 15 * time.Second, - }, - } -} - -func (i *internalDialer) WithTimeout(timeout time.Duration) Dialer { - i.dialer.Timeout = timeout - return i + return &internalDialer{} } func (i *internalDialer) IsReachableAndTLSValidationSucceeds( + ctx context.Context, address string, certPool *x509.CertPool, - logger ErrorOnlyLogger, + logger plog.Logger, ) error { - connection, err := tls.DialWithDialer(i.dialer, "tcp", address, Default(certPool)) + if ctx == nil { + ctx = context.Background() + } + + _, hasDeadline := ctx.Deadline() + if !hasDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 30*time.Second) + defer cancel() + } + + dialer := tls.Dialer{ + Config: Default(certPool), + } + conn, err := dialer.DialContext(ctx, "tcp", address) if err != nil { // Don't wrap this error message since this is just a helper function. return err } - err = connection.Close() + err = conn.Close() if err != nil { // untested // Log it just so that it doesn't completely disappear. logger.Error("Failed to close connection: ", err) diff --git a/internal/crypto/ptls/dialer_test.go b/internal/crypto/ptls/dialer_test.go index f40de7a6c..8bb15e917 100644 --- a/internal/crypto/ptls/dialer_test.go +++ b/internal/crypto/ptls/dialer_test.go @@ -5,6 +5,8 @@ package ptls_test import ( + "bytes" + "context" "crypto/tls" "crypto/x509" "net/http/httptest" @@ -15,19 +17,11 @@ import ( "github.com/stretchr/testify/require" "go.pinniped.dev/internal/crypto/ptls" + "go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/tlsserver" ) -type fakeerroronlylogger struct { -} - -func (_ *fakeerroronlylogger) Error(msg string, err error, keysAndValues ...any) { - // NOOP -} - -var _ ptls.ErrorOnlyLogger = (*fakeerroronlylogger)(nil) - func TestDialer(t *testing.T) { secureServerIPv4, secureServerIPv4CA := tlsserver.TestServerIPv4(t, nil, nil) secureServerIPv6, secureServerIPv6CA := tlsserver.TestServerIPv6(t, nil, nil) @@ -69,10 +63,14 @@ func TestDialer(t *testing.T) { t.Parallel() dialer := ptls.NewDialer() + var log bytes.Buffer + logger := plog.TestLogger(t, &log) + err := dialer.IsReachableAndTLSValidationSucceeds( + context.Background(), urlToAddress(t, test.fullURL), test.certPool, - &fakeerroronlylogger{}, + logger, ) if test.wantError != "" { require.EqualError(t, err, test.wantError) @@ -86,51 +84,94 @@ func TestDialer(t *testing.T) { func TestDialer_TimeoutAfter15s(t *testing.T) { t.Parallel() - dialer := ptls.NewDialer() + dialTimeout := 15 * time.Second - timeout := time.After(30 * time.Second) - testDone := make(chan bool) + maxDurationForTest := 2 * dialTimeout + maxTimeForTest := time.After(maxDurationForTest) + testPassed := make(chan bool) go func() { + ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) + defer cancel() + + var log bytes.Buffer + logger := plog.TestLogger(t, &log) + + dialer := ptls.NewDialer() err := dialer.IsReachableAndTLSValidationSucceeds( + ctx, // replace with context.Background() to verify that this hangs indefinitely setupHangingServer(t), nil, - &fakeerroronlylogger{}, + logger, ) require.EqualError(t, err, "context deadline exceeded") - testDone <- true + testPassed <- true }() select { - case <-timeout: - t.Errorf("test did not complete within 30 seconds") + case <-maxTimeForTest: + t.Errorf("timeout not honored: test did not complete within %s", maxDurationForTest) t.FailNow() - case <-testDone: + case <-testPassed: t.Log("everything ok!") } } -func TestDialer_WithCustomTimeTimeoutAfter2s(t *testing.T) { +func TestDialer_WithoutDeadline_Uses30sTimeout(t *testing.T) { t.Parallel() - dialer := ptls.NewDialer().WithTimeout(2 * time.Second) - - timeout := time.After(5 * time.Second) - testDone := make(chan bool) + maxDurationForTest := 40 * time.Second + maxTimeForTest := time.After(maxDurationForTest) + testPassed := make(chan bool) go func() { + var log bytes.Buffer + logger := plog.TestLogger(t, &log) + + dialer := ptls.NewDialer() err := dialer.IsReachableAndTLSValidationSucceeds( + context.Background(), setupHangingServer(t), nil, - &fakeerroronlylogger{}, + logger, ) require.EqualError(t, err, "context deadline exceeded") - testDone <- true + testPassed <- true }() select { - case <-timeout: - t.Errorf("test did not complete within 5 seconds") + case <-maxTimeForTest: + t.Errorf("timeout not honored: test did not complete within %s", maxDurationForTest) t.FailNow() - case <-testDone: + case <-testPassed: + t.Log("everything ok!") + } +} + +func TestDialer_WithNilContext_Uses30sTimeout(t *testing.T) { + t.Parallel() + + maxDurationForTest := 40 * time.Second + maxTimeForTest := time.After(maxDurationForTest) + testPassed := make(chan bool) + go func() { + var log bytes.Buffer + logger := plog.TestLogger(t, &log) + + dialer := ptls.NewDialer() + err := dialer.IsReachableAndTLSValidationSucceeds( + nil, + setupHangingServer(t), + nil, + logger, + ) + require.EqualError(t, err, "context deadline exceeded") + testPassed <- true + }() + + select { + case <-maxTimeForTest: + t.Errorf("timeout not honored: test did not complete within %s", maxDurationForTest) + t.FailNow() + case <-testPassed: t.Log("everything ok!") } } From 0fab37c0898b8cfff3cbcce444375abe184883db Mon Sep 17 00:00:00 2001 From: Ashish Amarnath Date: Tue, 24 Sep 2024 10:52:23 -0700 Subject: [PATCH 3/4] Update internal/crypto/ptls/dialer_test.go ignore lint error on nil context in unit test validating nil context --- internal/crypto/ptls/dialer_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/crypto/ptls/dialer_test.go b/internal/crypto/ptls/dialer_test.go index 8bb15e917..3065ef377 100644 --- a/internal/crypto/ptls/dialer_test.go +++ b/internal/crypto/ptls/dialer_test.go @@ -158,7 +158,7 @@ func TestDialer_WithNilContext_Uses30sTimeout(t *testing.T) { dialer := ptls.NewDialer() err := dialer.IsReachableAndTLSValidationSucceeds( - nil, + nil, //nolint:staticcheck // Unit testing nil handling. setupHangingServer(t), nil, logger, From 01c2377de03e763bf64f160ed7064226e4b351bf Mon Sep 17 00:00:00 2001 From: Joshua Casey Date: Tue, 24 Sep 2024 14:41:46 -0500 Subject: [PATCH 4/4] Refactor tests to use a table --- internal/crypto/ptls/dialer_test.go | 176 ++++++++++++---------------- 1 file changed, 72 insertions(+), 104 deletions(-) diff --git a/internal/crypto/ptls/dialer_test.go b/internal/crypto/ptls/dialer_test.go index 3065ef377..42ab985d1 100644 --- a/internal/crypto/ptls/dialer_test.go +++ b/internal/crypto/ptls/dialer_test.go @@ -81,115 +81,83 @@ func TestDialer(t *testing.T) { } } -func TestDialer_TimeoutAfter15s(t *testing.T) { - t.Parallel() +func TestDialer_AppliesTimeouts(t *testing.T) { + setupHangingServer := func(t *testing.T) string { + startedTLSListener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + MinVersion: tls.VersionTLS12, + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + // This causes the dial to hang. I'm actually not quite sure why. + return nil, nil + }, + }) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, startedTLSListener.Close()) + }) - dialTimeout := 15 * time.Second - - maxDurationForTest := 2 * dialTimeout - maxTimeForTest := time.After(maxDurationForTest) - testPassed := make(chan bool) - go func() { - ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) - defer cancel() - - var log bytes.Buffer - logger := plog.TestLogger(t, &log) - - dialer := ptls.NewDialer() - err := dialer.IsReachableAndTLSValidationSucceeds( - ctx, // replace with context.Background() to verify that this hangs indefinitely - setupHangingServer(t), - nil, - logger, - ) - require.EqualError(t, err, "context deadline exceeded") - testPassed <- true - }() - - select { - case <-maxTimeForTest: - t.Errorf("timeout not honored: test did not complete within %s", maxDurationForTest) - t.FailNow() - case <-testPassed: - t.Log("everything ok!") + return startedTLSListener.Addr().String() } -} -func TestDialer_WithoutDeadline_Uses30sTimeout(t *testing.T) { - t.Parallel() - - maxDurationForTest := 40 * time.Second - maxTimeForTest := time.After(maxDurationForTest) - testPassed := make(chan bool) - go func() { - var log bytes.Buffer - logger := plog.TestLogger(t, &log) - - dialer := ptls.NewDialer() - err := dialer.IsReachableAndTLSValidationSucceeds( - context.Background(), - setupHangingServer(t), - nil, - logger, - ) - require.EqualError(t, err, "context deadline exceeded") - testPassed <- true - }() - - select { - case <-maxTimeForTest: - t.Errorf("timeout not honored: test did not complete within %s", maxDurationForTest) - t.FailNow() - case <-testPassed: - t.Log("everything ok!") - } -} - -func TestDialer_WithNilContext_Uses30sTimeout(t *testing.T) { - t.Parallel() - - maxDurationForTest := 40 * time.Second - maxTimeForTest := time.After(maxDurationForTest) - testPassed := make(chan bool) - go func() { - var log bytes.Buffer - logger := plog.TestLogger(t, &log) - - dialer := ptls.NewDialer() - err := dialer.IsReachableAndTLSValidationSucceeds( - nil, //nolint:staticcheck // Unit testing nil handling. - setupHangingServer(t), - nil, - logger, - ) - require.EqualError(t, err, "context deadline exceeded") - testPassed <- true - }() - - select { - case <-maxTimeForTest: - t.Errorf("timeout not honored: test did not complete within %s", maxDurationForTest) - t.FailNow() - case <-testPassed: - t.Log("everything ok!") - } -} - -func setupHangingServer(t *testing.T) string { - startedTLSListener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ - MinVersion: tls.VersionTLS12, - GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - // This causes the dial to hang. I'm actually not quite sure why. - return nil, nil + tests := []struct { + name string + maxTestDuration time.Duration + makeContext func(*testing.T) context.Context + }{ + { + name: "timeout after 15s", + maxTestDuration: 20 * time.Second, + makeContext: func(t *testing.T) context.Context { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + t.Cleanup(cancel) + return ctx + }, }, - }) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, startedTLSListener.Close()) - }) + { + name: "uses 30s timeout without a deadline", + maxTestDuration: 35 * time.Second, + makeContext: func(t *testing.T) context.Context { + return context.Background() + }, + }, + { + name: "uses 30s timeout with a nil context", + maxTestDuration: 35 * time.Second, + makeContext: func(t *testing.T) context.Context { + return nil + }, + }, + } - return startedTLSListener.Addr().String() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + maxTimeForTest := time.After(test.maxTestDuration) + testPassed := make(chan bool) + go func() { + var log bytes.Buffer + logger := plog.TestLogger(t, &log) + + dialer := ptls.NewDialer() + err := dialer.IsReachableAndTLSValidationSucceeds( + test.makeContext(t), + setupHangingServer(t), + nil, + logger, + ) + require.EqualError(t, err, "context deadline exceeded") + testPassed <- true + }() + + select { + case <-maxTimeForTest: + t.Errorf("timeout not honored: test did not complete within %s", test.maxTestDuration) + t.FailNow() + case <-testPassed: + t.Log("everything ok!") + } + }) + } } func urlToAddress(t *testing.T, urlAsString string) string {