From 76a116641fa2dc0da4a540bc190745939087856a Mon Sep 17 00:00:00 2001 From: Joshua Casey Date: Tue, 3 Sep 2024 14:45:14 -0500 Subject: [PATCH] Add ptls.Dialer to provide some common configuration for tls.Dial operations --- .../webhookcachefiller/webhookcachefiller.go | 15 +- .../webhookcachefiller_test.go | 9 +- .../github_upstream_watcher.go | 11 +- .../github_upstream_watcher_test.go | 108 +++++++----- .../controllermanager/prepare_controllers.go | 2 + internal/crypto/ptls/dialer.go | 58 +++++++ internal/crypto/ptls/dialer_test.go | 164 ++++++++++++++++++ internal/supervisor/server/server.go | 2 +- 8 files changed, 300 insertions(+), 69 deletions(-) create mode 100644 internal/crypto/ptls/dialer.go create mode 100644 internal/crypto/ptls/dialer_test.go diff --git a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go index eacddfc4c..5f7344171 100644 --- a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go +++ b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go @@ -6,7 +6,6 @@ package webhookcachefiller import ( "context" - "crypto/tls" "crypto/x509" "fmt" "net/url" @@ -77,6 +76,7 @@ func New( withInformer pinnipedcontroller.WithInformerOptionFunc, clock clock.Clock, log plog.Logger, + dialer ptls.Dialer, ) controllerlib.Controller { return controllerlib.New( controllerlib.Config{ @@ -90,6 +90,7 @@ func New( configMapInformer: configMapInformer, clock: clock, log: log.WithName(controllerName), + dialer: dialer, }, }, withInformer( @@ -125,6 +126,7 @@ type webhookCacheFillerController struct { client conciergeclientset.Interface clock clock.Clock log plog.Logger + dialer ptls.Dialer } // Sync implements controllerlib.Syncer. @@ -428,11 +430,11 @@ func (c *webhookCacheFillerController) validateConnection( return conditions, nil } - conn, err := tls.Dial("tcp", endpointHostPort.Endpoint(), ptls.Default(certPool)) + err := c.dialer.IsReachableAndTLSValidationSucceeds(endpointHostPort.Endpoint(), certPool, logger) if err != nil { errText := "cannot dial server" - msg := fmt.Sprintf("%s: %s", errText, err.Error()) + msg := fmt.Sprintf("%s: %s", errText, err) conditions = append(conditions, &metav1.Condition{ Type: typeWebhookConnectionValid, Status: metav1.ConditionFalse, @@ -442,13 +444,6 @@ func (c *webhookCacheFillerController) validateConnection( return conditions, fmt.Errorf("%s: %w", errText, err) } - // this error should never be significant - err = conn.Close() - if err != nil { - // no unit test for this failure - logger.Error("error closing dialer", err) - } - conditions = append(conditions, successfulWebhookConnectionValidCondition()) return conditions, nil } diff --git a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go index 24b2b1c3a..e75e51607 100644 --- a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go +++ b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller_test.go @@ -1934,7 +1934,8 @@ func TestController(t *testing.T) { kubeInformers.Core().V1().ConfigMaps(), controllerlib.WithInformer, frozenClock, - logger) + logger, + ptls.NewDialer()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2177,7 +2178,8 @@ func TestControllerFilterSecret(t *testing.T) { configMapInformer, observableInformers.WithInformer, frozenClock, - logger) + logger, + ptls.NewDialer()) unrelated := &corev1.Secret{} filter := observableInformers.GetFilterForInformer(secretInformer) @@ -2238,7 +2240,8 @@ func TestControllerFilterConfigMap(t *testing.T) { configMapInformer, observableInformers.WithInformer, frozenClock, - logger) + logger, + ptls.NewDialer()) unrelated := &corev1.ConfigMap{} filter := observableInformers.GetFilterForInformer(configMapInformer) diff --git a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go index e90c1fc66..5326ec366 100644 --- a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go +++ b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go @@ -6,7 +6,6 @@ package githubupstreamwatcher import ( "context" - "crypto/tls" "errors" "fmt" "net" @@ -117,7 +116,7 @@ type gitHubWatcherController struct { secretInformer corev1informers.SecretInformer configMapInformer corev1informers.ConfigMapInformer clock clock.Clock - dialFunc func(network, addr string, config *tls.Config) (*tls.Conn, error) + dialer ptls.Dialer validatedCache GitHubValidatedAPICacheI } @@ -132,7 +131,7 @@ func New( log plog.Logger, withInformer pinnipedcontroller.WithInformerOptionFunc, clock clock.Clock, - dialFunc func(network, addr string, config *tls.Config) (*tls.Conn, error), + dialer ptls.Dialer, validatedCache *cache.Expiring, ) controllerlib.Controller { c := gitHubWatcherController{ @@ -144,7 +143,7 @@ func New( secretInformer: secretInformer, configMapInformer: configMapInformer, clock: clock, - dialFunc: dialFunc, + dialer: dialer, validatedCache: NewGitHubValidatedAPICache(validatedCache), } @@ -471,7 +470,7 @@ func (c *gitHubWatcherController) validateGitHubConnection( apiAddress := apiHostPort.Endpoint() if !c.validatedCache.IsValid(apiAddress, caBundle.Hash()) { - conn, tlsDialErr := c.dialFunc("tcp", apiAddress, ptls.Default(caBundle.CertPool())) + tlsDialErr := c.dialer.IsReachableAndTLSValidationSucceeds(apiAddress, caBundle.CertPool(), c.log) if tlsDialErr != nil { return &metav1.Condition{ Type: GitHubConnectionValid, @@ -481,8 +480,6 @@ func (c *gitHubWatcherController) validateGitHubConnection( apiAddress, *specifiedHost, buildDialErrorMessage(tlsDialErr)), }, nil, tlsDialErr } - // Any error should be ignored. We have performed a successful Dial, so no need to requeue this Sync. - _ = conn.Close() } c.validatedCache.MarkAsValidated(apiAddress, caBundle.Hash()) diff --git a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go index 63a2cd50b..d070f463e 100644 --- a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go +++ b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher_test.go @@ -6,7 +6,6 @@ package githubupstreamwatcher import ( "bytes" "context" - "crypto/tls" "crypto/x509" "encoding/base64" "fmt" @@ -41,6 +40,7 @@ import ( "go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers" "go.pinniped.dev/internal/controller/tlsconfigutil" "go.pinniped.dev/internal/controllerlib" + "go.pinniped.dev/internal/crypto/ptls" "go.pinniped.dev/internal/federationdomain/dynamicupstreamprovider" "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/net/phttp" @@ -61,6 +61,32 @@ var ( githubIDPKind = idpv1alpha1.SchemeGroupVersion.WithKind("GitHubIdentityProvider") ) +type fakeGithubDialer struct { + t *testing.T + realAddress string + realCertPool *x509.CertPool +} + +func (f fakeGithubDialer) IsReachableAndTLSValidationSucceeds(address string, _ *x509.CertPool, logger ptls.ErrorOnlyLogger) error { + require.Equal(f.t, "api.github.com:443", address) + + return ptls.NewDialer().IsReachableAndTLSValidationSucceeds(f.realAddress, f.realCertPool, logger) +} + +var _ ptls.Dialer = (*fakeGithubDialer)(nil) + +type allowNoDials struct { + t *testing.T +} + +func (f allowNoDials) IsReachableAndTLSValidationSucceeds(_ string, _ *x509.CertPool, _ ptls.ErrorOnlyLogger) error { + f.t.Errorf("this test should not perform dial") + f.t.FailNow() + return nil +} + +var _ ptls.Dialer = (*allowNoDials)(nil) + func TestController(t *testing.T) { require.Equal(t, 6, countExpectedConditions) @@ -406,7 +432,7 @@ func TestController(t *testing.T) { name string githubIdentityProviders []runtime.Object secretsAndConfigMaps []runtime.Object - mockDialer func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) + mockDialer func(*testing.T) ptls.Dialer preexistingValidatedCache []GitHubValidatedAPICacheKey wantErr string wantLogs []string @@ -555,15 +581,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -638,15 +662,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -721,15 +743,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -804,15 +824,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -887,15 +905,13 @@ func TestController(t *testing.T) { return githubIDP }(), }, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - require.Equal(t, "api.github.com:443", addr) - // don't actually dial github.com to avoid making external network calls in unit test - configClone := config.Clone() - configClone.RootCAs = goodServerCertPool - return tls.Dial(network, goodServerDomain, configClone) + return &fakeGithubDialer{ + t: t, + realAddress: goodServerDomain, + realCertPool: goodServerCertPool, } }, wantResultingCache: []*upstreamgithub.ProviderConfig{ @@ -1379,14 +1395,10 @@ func TestController(t *testing.T) { name: "happy path with previously validated address/CA Bundle does not validate again", secretsAndConfigMaps: []runtime.Object{goodClientCredentialsSecret}, githubIdentityProviders: []runtime.Object{validFilledOutIDP}, - mockDialer: func(t *testing.T) func(network, addr string, config *tls.Config) (*tls.Conn, error) { + mockDialer: func(t *testing.T) ptls.Dialer { t.Helper() - return func(network, addr string, config *tls.Config) (*tls.Conn, error) { - t.Errorf("this test should not perform dial") - t.FailNow() - return nil, nil - } + return &allowNoDials{t: t} }, preexistingValidatedCache: []GitHubValidatedAPICacheKey{ { @@ -2479,7 +2491,7 @@ func TestController(t *testing.T) { gitHubIdentityProviderInformer := supervisorInformers.IDP().V1alpha1().GitHubIdentityProviders() - dialer := tls.Dial + var dialer ptls.Dialer = ptls.NewDialer() if tt.mockDialer != nil { dialer = tt.mockDialer(t) } @@ -2882,7 +2894,7 @@ func TestController_OnlyWantActions(t *testing.T) { logger, controllerlib.WithInformer, frozenClockForLastTransitionTime, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) @@ -3006,7 +3018,7 @@ func TestGitHubUpstreamWatcherControllerFilterSecret(t *testing.T) { logger, observableInformers.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) @@ -3063,7 +3075,7 @@ func TestGitHubUpstreamWatcherControllerFilterConfigMaps(t *testing.T) { logger, observableInformers.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) @@ -3120,7 +3132,7 @@ func TestGitHubUpstreamWatcherControllerFilterGitHubIDP(t *testing.T) { logger, observableInformers.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ) diff --git a/internal/controllermanager/prepare_controllers.go b/internal/controllermanager/prepare_controllers.go index 6d5144216..0dfc2396a 100644 --- a/internal/controllermanager/prepare_controllers.go +++ b/internal/controllermanager/prepare_controllers.go @@ -28,6 +28,7 @@ import ( "go.pinniped.dev/internal/controller/serviceaccounttokencleanup" "go.pinniped.dev/internal/controllerinit" "go.pinniped.dev/internal/controllerlib" + "go.pinniped.dev/internal/crypto/ptls" "go.pinniped.dev/internal/deploymentref" "go.pinniped.dev/internal/downward" "go.pinniped.dev/internal/dynamiccert" @@ -244,6 +245,7 @@ func PrepareControllers(c *Config) (controllerinit.RunnerBuilder, error) { //nol controllerlib.WithInformer, clock.RealClock{}, plog.New(), + ptls.NewDialer(), ), singletonWorker, ). diff --git a/internal/crypto/ptls/dialer.go b/internal/crypto/ptls/dialer.go new file mode 100644 index 000000000..ba0309338 --- /dev/null +++ b/internal/crypto/ptls/dialer.go @@ -0,0 +1,58 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package ptls + +import ( + "crypto/tls" + "crypto/x509" + "net" + "time" +) + +type Dialer interface { + IsReachableAndTLSValidationSucceeds( + address string, + certPool *x509.CertPool, + logger ErrorOnlyLogger, + ) error +} + +type ErrorOnlyLogger interface { + Error(msg string, err error, keysAndValues ...any) +} + +type internalDialer struct { + dialer *net.Dialer +} + +func NewDialer() *internalDialer { + return &internalDialer{ + dialer: &net.Dialer{ + Timeout: 15 * time.Second, + }, + } +} + +func (i *internalDialer) WithTimeout(timeout time.Duration) Dialer { + i.dialer.Timeout = timeout + return i +} + +func (i *internalDialer) IsReachableAndTLSValidationSucceeds( + address string, + certPool *x509.CertPool, + logger ErrorOnlyLogger, +) error { + connection, err := tls.DialWithDialer(i.dialer, "tcp", address, Default(certPool)) + if err != nil { + // Don't wrap this error message since this is just a helper function. + return err + } + err = connection.Close() + if err != nil { // untested + // Log it just so that it doesn't completely disappear. + logger.Error("Failed to close connection: ", err) + } + return nil +} diff --git a/internal/crypto/ptls/dialer_test.go b/internal/crypto/ptls/dialer_test.go new file mode 100644 index 000000000..f40de7a6c --- /dev/null +++ b/internal/crypto/ptls/dialer_test.go @@ -0,0 +1,164 @@ +// Copyright 2024 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Use this package to avoid import loops with internal/testutil/tlsserver +package ptls_test + +import ( + "crypto/tls" + "crypto/x509" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/crypto/ptls" + "go.pinniped.dev/internal/testutil" + "go.pinniped.dev/internal/testutil/tlsserver" +) + +type fakeerroronlylogger struct { +} + +func (_ *fakeerroronlylogger) Error(msg string, err error, keysAndValues ...any) { + // NOOP +} + +var _ ptls.ErrorOnlyLogger = (*fakeerroronlylogger)(nil) + +func TestDialer(t *testing.T) { + secureServerIPv4, secureServerIPv4CA := tlsserver.TestServerIPv4(t, nil, nil) + secureServerIPv6, secureServerIPv6CA := tlsserver.TestServerIPv6(t, nil, nil) + insecureServer := httptest.NewServer(nil) + + fakeCert, _, err := testutil.CreateCertificate(time.Now().Add(-1*time.Hour), time.Now().Add(time.Hour)) + require.NoError(t, err) + + tests := []struct { + name string + fullURL string + certPool *x509.CertPool + wantError string + }{ + { + name: "happy path with TLS-enabled IPv4", + fullURL: secureServerIPv4.URL, + certPool: bytesToCertPool(secureServerIPv4CA), + }, + { + name: "happy path with TLS-enabled IPv6", + fullURL: secureServerIPv6.URL, + certPool: bytesToCertPool(secureServerIPv6CA), + }, + { + name: "returns error when connecting to a non-TLS server", + fullURL: insecureServer.URL, + wantError: "tls: first record does not look like a TLS handshake", + }, + { + name: "returns error when using the wrong bundle", + fullURL: secureServerIPv4.URL, + certPool: bytesToCertPool(fakeCert), + wantError: "tls: failed to verify certificate: x509: certificate signed by unknown authority", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + dialer := ptls.NewDialer() + + err := dialer.IsReachableAndTLSValidationSucceeds( + urlToAddress(t, test.fullURL), + test.certPool, + &fakeerroronlylogger{}, + ) + if test.wantError != "" { + require.EqualError(t, err, test.wantError) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestDialer_TimeoutAfter15s(t *testing.T) { + t.Parallel() + + dialer := ptls.NewDialer() + + timeout := time.After(30 * time.Second) + testDone := make(chan bool) + go func() { + err := dialer.IsReachableAndTLSValidationSucceeds( + setupHangingServer(t), + nil, + &fakeerroronlylogger{}, + ) + require.EqualError(t, err, "context deadline exceeded") + testDone <- true + }() + + select { + case <-timeout: + t.Errorf("test did not complete within 30 seconds") + t.FailNow() + case <-testDone: + t.Log("everything ok!") + } +} + +func TestDialer_WithCustomTimeTimeoutAfter2s(t *testing.T) { + t.Parallel() + + dialer := ptls.NewDialer().WithTimeout(2 * time.Second) + + timeout := time.After(5 * time.Second) + testDone := make(chan bool) + go func() { + err := dialer.IsReachableAndTLSValidationSucceeds( + setupHangingServer(t), + nil, + &fakeerroronlylogger{}, + ) + require.EqualError(t, err, "context deadline exceeded") + testDone <- true + }() + + select { + case <-timeout: + t.Errorf("test did not complete within 5 seconds") + t.FailNow() + case <-testDone: + t.Log("everything ok!") + } +} + +func setupHangingServer(t *testing.T) string { + startedTLSListener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + MinVersion: tls.VersionTLS12, + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + // This causes the dial to hang. I'm actually not quite sure why. + return nil, nil + }, + }) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, startedTLSListener.Close()) + }) + + return startedTLSListener.Addr().String() +} + +func urlToAddress(t *testing.T, urlAsString string) string { + u, err := url.Parse(urlAsString) + require.NoError(t, err) + return u.Host +} + +func bytesToCertPool(ca []byte) *x509.CertPool { + x509CertPool := x509.NewCertPool() + x509CertPool.AppendCertsFromPEM(ca) + return x509CertPool +} diff --git a/internal/supervisor/server/server.go b/internal/supervisor/server/server.go index 5aa4270af..ae2b65167 100644 --- a/internal/supervisor/server/server.go +++ b/internal/supervisor/server/server.go @@ -342,7 +342,7 @@ func prepareControllers( plog.New(), controllerlib.WithInformer, clock.RealClock{}, - tls.Dial, + ptls.NewDialer(), cache.NewExpiring(), ), singletonWorker).