From a1dafcf45a82201a60dd55011eba6785fffa4700 Mon Sep 17 00:00:00 2001 From: Joshua Casey Date: Mon, 19 Aug 2024 12:19:52 -0500 Subject: [PATCH] Refactor provider_test to not use pool.Subjects() --- internal/dynamiccert/provider_test.go | 103 ++++++++++++++------------ 1 file changed, 54 insertions(+), 49 deletions(-) diff --git a/internal/dynamiccert/provider_test.go b/internal/dynamiccert/provider_test.go index f4e9c3fd9..385adc364 100644 --- a/internal/dynamiccert/provider_test.go +++ b/internal/dynamiccert/provider_test.go @@ -28,54 +28,61 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) { t.Parallel() tests := []struct { - name string - f func(t *testing.T, ca Provider, certKey Private) (wantClientCASubjects [][]byte, wantCerts []tls.Certificate) + name string + buildCertPool func(t *testing.T, ca Provider) *x509.CertPool + buildServingCerts func(t *testing.T, certKey Private) []tls.Certificate }{ { name: "no-op leave everything alone", - f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { + buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool { pool := x509.NewCertPool() ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent()) require.True(t, ok, "should have valid non-empty CA bundle") - certPEM, keyPEM := certKey.CurrentCertKeyContent() - cert, err := tls.X509KeyPair(certPEM, keyPEM) + return pool + }, + buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate { + cert, err := tls.X509KeyPair(certKey.CurrentCertKeyContent()) require.NoError(t, err) - return pool.Subjects(), []tls.Certificate{cert} + return []tls.Certificate{cert} }, }, { name: "unset the CA", - f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { + buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool { ca.UnsetCertKeyContent() - certPEM, keyPEM := certKey.CurrentCertKeyContent() - cert, err := tls.X509KeyPair(certPEM, keyPEM) + return nil + }, + buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate { + cert, err := tls.X509KeyPair(certKey.CurrentCertKeyContent()) require.NoError(t, err) - return nil, []tls.Certificate{cert} + return []tls.Certificate{cert} }, }, { name: "unset the serving cert - still serves the old content", - f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { + buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool { pool := x509.NewCertPool() ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent()) require.True(t, ok, "should have valid non-empty CA bundle") - certPEM, keyPEM := certKey.CurrentCertKeyContent() - cert, err := tls.X509KeyPair(certPEM, keyPEM) + return pool + }, + buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate { + cert, err := tls.X509KeyPair(certKey.CurrentCertKeyContent()) require.NoError(t, err) certKey.UnsetCertKeyContent() - return pool.Subjects(), []tls.Certificate{cert} + return []tls.Certificate{cert} }, }, { name: "change to a new CA", - f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { + buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool { // use unique names for all CAs to make sure the pool subjects are different newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour) require.NoError(t, err) @@ -84,16 +91,25 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) { err = ca.SetCertKeyContent(newCA.Bundle(), caKey) require.NoError(t, err) - certPEM, keyPEM := certKey.CurrentCertKeyContent() - cert, err := tls.X509KeyPair(certPEM, keyPEM) + return newCA.Pool() + }, + buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate { + cert, err := tls.X509KeyPair(certKey.CurrentCertKeyContent()) require.NoError(t, err) - return newCA.Pool().Subjects(), []tls.Certificate{cert} + return []tls.Certificate{cert} }, }, { name: "change to new serving cert", - f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { + buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool { + pool := x509.NewCertPool() + ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent()) + require.True(t, ok, "should have valid non-empty CA bundle") + + return pool + }, + buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate { // use unique names for all CAs to make sure the pool subjects are different newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour) require.NoError(t, err) @@ -107,16 +123,23 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) { cert, err := tls.X509KeyPair(certPEM, keyPEM) require.NoError(t, err) - pool := x509.NewCertPool() - ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent()) - require.True(t, ok, "should have valid non-empty CA bundle") - - return pool.Subjects(), []tls.Certificate{cert} + return []tls.Certificate{cert} }, }, { name: "change both CA and serving cert", - f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { + buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool { + // use unique names for all CAs to make sure the pool subjects are different + newOtherCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-other-ca"), time.Hour) + require.NoError(t, err) + caKey, err := newOtherCA.PrivateKeyToPEM() + require.NoError(t, err) + err = ca.SetCertKeyContent(newOtherCA.Bundle(), caKey) + require.NoError(t, err) + + return newOtherCA.Pool() + }, + buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate { // use unique names for all CAs to make sure the pool subjects are different newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour) require.NoError(t, err) @@ -130,15 +153,7 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) { cert, err := tls.X509KeyPair(certPEM, keyPEM) require.NoError(t, err) - // use unique names for all CAs to make sure the pool subjects are different - newOtherCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-other-ca"), time.Hour) - require.NoError(t, err) - caKey, err := newOtherCA.PrivateKeyToPEM() - require.NoError(t, err) - err = ca.SetCertKeyContent(newOtherCA.Bundle(), caKey) - require.NoError(t, err) - - return newOtherCA.Pool().Subjects(), []tls.Certificate{cert} + return []tls.Certificate{cert} }, }, } @@ -184,7 +199,8 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) { tlsConfig.GetConfigForClient = dynamicCertificateController.GetConfigForClient - wantClientCASubjects, wantCerts := tt.f(t, caContent, certKeyContent) + wantClientPool := tt.buildCertPool(t, caContent) + wantServingCerts := tt.buildServingCerts(t, certKeyContent) var lastTLSConfig *tls.Config @@ -197,18 +213,14 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) { lastTLSConfig = actualTLSConfig - return reflect.DeepEqual(wantClientCASubjects, poolSubjects(actualTLSConfig.ClientCAs)) && - reflect.DeepEqual(wantCerts, actualTLSConfig.Certificates), nil + return wantClientPool.Equal(actualTLSConfig.ClientCAs) && + reflect.DeepEqual(wantServingCerts, actualTLSConfig.Certificates), nil }) if err != nil && lastTLSConfig != nil { // for debugging failures - t.Log("diff between client CAs:\n", cmp.Diff( - testlib.Sdump(wantClientCASubjects), - testlib.Sdump(poolSubjects(lastTLSConfig.ClientCAs)), - )) t.Log("diff between serving certs:\n", cmp.Diff( - testlib.Sdump(wantCerts), + testlib.Sdump(wantServingCerts), testlib.Sdump(lastTLSConfig.Certificates), )) } @@ -217,13 +229,6 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) { } } -func poolSubjects(pool *x509.CertPool) [][]byte { - if pool == nil { - return nil - } - return pool.Subjects() -} - func TestNewServingCert(t *testing.T) { got := NewServingCert("")