Refactor provider_test to not use pool.Subjects()

This commit is contained in:
Joshua Casey
2024-08-19 12:19:52 -05:00
parent 0ee8ee80e1
commit a1dafcf45a

View File

@@ -28,54 +28,61 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
name string name string
f func(t *testing.T, ca Provider, certKey Private) (wantClientCASubjects [][]byte, wantCerts []tls.Certificate) buildCertPool func(t *testing.T, ca Provider) *x509.CertPool
buildServingCerts func(t *testing.T, certKey Private) []tls.Certificate
}{ }{
{ {
name: "no-op leave everything alone", name: "no-op leave everything alone",
f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool {
pool := x509.NewCertPool() pool := x509.NewCertPool()
ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent()) ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent())
require.True(t, ok, "should have valid non-empty CA bundle") require.True(t, ok, "should have valid non-empty CA bundle")
certPEM, keyPEM := certKey.CurrentCertKeyContent() return pool
cert, err := tls.X509KeyPair(certPEM, keyPEM) },
buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate {
cert, err := tls.X509KeyPair(certKey.CurrentCertKeyContent())
require.NoError(t, err) require.NoError(t, err)
return pool.Subjects(), []tls.Certificate{cert} return []tls.Certificate{cert}
}, },
}, },
{ {
name: "unset the CA", name: "unset the CA",
f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool {
ca.UnsetCertKeyContent() ca.UnsetCertKeyContent()
certPEM, keyPEM := certKey.CurrentCertKeyContent() return nil
cert, err := tls.X509KeyPair(certPEM, keyPEM) },
buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate {
cert, err := tls.X509KeyPair(certKey.CurrentCertKeyContent())
require.NoError(t, err) require.NoError(t, err)
return nil, []tls.Certificate{cert} return []tls.Certificate{cert}
}, },
}, },
{ {
name: "unset the serving cert - still serves the old content", name: "unset the serving cert - still serves the old content",
f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool {
pool := x509.NewCertPool() pool := x509.NewCertPool()
ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent()) ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent())
require.True(t, ok, "should have valid non-empty CA bundle") require.True(t, ok, "should have valid non-empty CA bundle")
certPEM, keyPEM := certKey.CurrentCertKeyContent() return pool
cert, err := tls.X509KeyPair(certPEM, keyPEM) },
buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate {
cert, err := tls.X509KeyPair(certKey.CurrentCertKeyContent())
require.NoError(t, err) require.NoError(t, err)
certKey.UnsetCertKeyContent() certKey.UnsetCertKeyContent()
return pool.Subjects(), []tls.Certificate{cert} return []tls.Certificate{cert}
}, },
}, },
{ {
name: "change to a new CA", name: "change to a new CA",
f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool {
// use unique names for all CAs to make sure the pool subjects are different // use unique names for all CAs to make sure the pool subjects are different
newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour) newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour)
require.NoError(t, err) require.NoError(t, err)
@@ -84,16 +91,25 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) {
err = ca.SetCertKeyContent(newCA.Bundle(), caKey) err = ca.SetCertKeyContent(newCA.Bundle(), caKey)
require.NoError(t, err) require.NoError(t, err)
certPEM, keyPEM := certKey.CurrentCertKeyContent() return newCA.Pool()
cert, err := tls.X509KeyPair(certPEM, keyPEM) },
buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate {
cert, err := tls.X509KeyPair(certKey.CurrentCertKeyContent())
require.NoError(t, err) require.NoError(t, err)
return newCA.Pool().Subjects(), []tls.Certificate{cert} return []tls.Certificate{cert}
}, },
}, },
{ {
name: "change to new serving cert", name: "change to new serving cert",
f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool {
pool := x509.NewCertPool()
ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent())
require.True(t, ok, "should have valid non-empty CA bundle")
return pool
},
buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate {
// use unique names for all CAs to make sure the pool subjects are different // use unique names for all CAs to make sure the pool subjects are different
newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour) newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour)
require.NoError(t, err) require.NoError(t, err)
@@ -107,16 +123,23 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) {
cert, err := tls.X509KeyPair(certPEM, keyPEM) cert, err := tls.X509KeyPair(certPEM, keyPEM)
require.NoError(t, err) require.NoError(t, err)
pool := x509.NewCertPool() return []tls.Certificate{cert}
ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent())
require.True(t, ok, "should have valid non-empty CA bundle")
return pool.Subjects(), []tls.Certificate{cert}
}, },
}, },
{ {
name: "change both CA and serving cert", name: "change both CA and serving cert",
f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { buildCertPool: func(t *testing.T, ca Provider) *x509.CertPool {
// use unique names for all CAs to make sure the pool subjects are different
newOtherCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-other-ca"), time.Hour)
require.NoError(t, err)
caKey, err := newOtherCA.PrivateKeyToPEM()
require.NoError(t, err)
err = ca.SetCertKeyContent(newOtherCA.Bundle(), caKey)
require.NoError(t, err)
return newOtherCA.Pool()
},
buildServingCerts: func(t *testing.T, certKey Private) []tls.Certificate {
// use unique names for all CAs to make sure the pool subjects are different // use unique names for all CAs to make sure the pool subjects are different
newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour) newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour)
require.NoError(t, err) require.NoError(t, err)
@@ -130,15 +153,7 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) {
cert, err := tls.X509KeyPair(certPEM, keyPEM) cert, err := tls.X509KeyPair(certPEM, keyPEM)
require.NoError(t, err) require.NoError(t, err)
// use unique names for all CAs to make sure the pool subjects are different return []tls.Certificate{cert}
newOtherCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-other-ca"), time.Hour)
require.NoError(t, err)
caKey, err := newOtherCA.PrivateKeyToPEM()
require.NoError(t, err)
err = ca.SetCertKeyContent(newOtherCA.Bundle(), caKey)
require.NoError(t, err)
return newOtherCA.Pool().Subjects(), []tls.Certificate{cert}
}, },
}, },
} }
@@ -184,7 +199,8 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) {
tlsConfig.GetConfigForClient = dynamicCertificateController.GetConfigForClient tlsConfig.GetConfigForClient = dynamicCertificateController.GetConfigForClient
wantClientCASubjects, wantCerts := tt.f(t, caContent, certKeyContent) wantClientPool := tt.buildCertPool(t, caContent)
wantServingCerts := tt.buildServingCerts(t, certKeyContent)
var lastTLSConfig *tls.Config var lastTLSConfig *tls.Config
@@ -197,18 +213,14 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) {
lastTLSConfig = actualTLSConfig lastTLSConfig = actualTLSConfig
return reflect.DeepEqual(wantClientCASubjects, poolSubjects(actualTLSConfig.ClientCAs)) && return wantClientPool.Equal(actualTLSConfig.ClientCAs) &&
reflect.DeepEqual(wantCerts, actualTLSConfig.Certificates), nil reflect.DeepEqual(wantServingCerts, actualTLSConfig.Certificates), nil
}) })
if err != nil && lastTLSConfig != nil { if err != nil && lastTLSConfig != nil {
// for debugging failures // for debugging failures
t.Log("diff between client CAs:\n", cmp.Diff(
testlib.Sdump(wantClientCASubjects),
testlib.Sdump(poolSubjects(lastTLSConfig.ClientCAs)),
))
t.Log("diff between serving certs:\n", cmp.Diff( t.Log("diff between serving certs:\n", cmp.Diff(
testlib.Sdump(wantCerts), testlib.Sdump(wantServingCerts),
testlib.Sdump(lastTLSConfig.Certificates), testlib.Sdump(lastTLSConfig.Certificates),
)) ))
} }
@@ -217,13 +229,6 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) {
} }
} }
func poolSubjects(pool *x509.CertPool) [][]byte {
if pool == nil {
return nil
}
return pool.Subjects()
}
func TestNewServingCert(t *testing.T) { func TestNewServingCert(t *testing.T) {
got := NewServingCert("") got := NewServingCert("")