diff --git a/internal/controller/tlsconfigutil/ca_bundle.go b/internal/controller/tlsconfigutil/ca_bundle.go index 48df83c96..b229fd0fd 100644 --- a/internal/controller/tlsconfigutil/ca_bundle.go +++ b/internal/controller/tlsconfigutil/ca_bundle.go @@ -15,12 +15,13 @@ type CABundle struct { certPool *x509.CertPool } -func NewCABundle(caBundle []byte, certPool *x509.CertPool) *CABundle { +func NewCABundle(caBundle []byte) (*CABundle, bool) { + certPool := x509.NewCertPool() return &CABundle{ caBundle: caBundle, sha256: sha256.Sum256(caBundle), certPool: certPool, - } + }, certPool.AppendCertsFromPEM(caBundle) } // GetCABundle returns the CA certificate bundle PEM bytes. diff --git a/internal/controller/tlsconfigutil/ca_bundle_test.go b/internal/controller/tlsconfigutil/ca_bundle_test.go index 942a01ef5..40022ddf6 100644 --- a/internal/controller/tlsconfigutil/ca_bundle_test.go +++ b/internal/controller/tlsconfigutil/ca_bundle_test.go @@ -1,6 +1,7 @@ package tlsconfigutil import ( + "crypto/sha256" "crypto/x509" "testing" "time" @@ -10,9 +11,34 @@ import ( "go.pinniped.dev/internal/certauthority" ) +func TestNewCABundle(t *testing.T) { + testCA, err := certauthority.New("Test CA", 1*time.Hour) + require.NoError(t, err) + + t.Run("generates the certPool and hash for certificate input", func(t *testing.T) { + caBundle, ok := NewCABundle(testCA.Bundle()) + require.True(t, ok) + + require.Equal(t, testCA.Bundle(), caBundle.GetCABundle()) + require.Equal(t, sha256.Sum256(testCA.Bundle()), caBundle.GetCABundleHash()) + require.Equal(t, string(testCA.Bundle()), caBundle.GetCABundlePemString()) + require.True(t, testCA.Pool().Equal(caBundle.GetCertPool()), "should be the cert pool of the testCA") + }) + + t.Run("returns false for non-certificate input", func(t *testing.T) { + caBundle, ok := NewCABundle([]byte("here are some bytes")) + require.False(t, ok) + + require.Equal(t, []byte("here are some bytes"), caBundle.GetCABundle()) + require.Equal(t, sha256.Sum256([]byte("here are some bytes")), caBundle.GetCABundleHash()) + require.Equal(t, "here are some bytes", caBundle.GetCABundlePemString()) + require.True(t, x509.NewCertPool().Equal(caBundle.GetCertPool()), "should be an empty cert pool") + }) +} + func TestGetCABundle(t *testing.T) { t.Run("returns the CA bundle", func(t *testing.T) { - caBundle := NewCABundle([]byte("here are some bytes"), nil) + caBundle, _ := NewCABundle([]byte("here are some bytes")) require.Equal(t, []byte("here are some bytes"), caBundle.GetCABundle()) }) @@ -26,7 +52,7 @@ func TestGetCABundle(t *testing.T) { func TestGetCABundlePemString(t *testing.T) { t.Run("returns the CA bundle PEM string", func(t *testing.T) { - caBundle := NewCABundle([]byte("here is a string"), nil) + caBundle, _ := NewCABundle([]byte("here is a string")) require.Equal(t, "here is a string", caBundle.GetCABundlePemString()) }) @@ -38,11 +64,10 @@ func TestGetCABundlePemString(t *testing.T) { } func TestGetCertPool(t *testing.T) { - t.Run("returns the cert pool", func(t *testing.T) { - aCertPool := x509.NewCertPool() - caBundle := NewCABundle(nil, aCertPool) + t.Run("returns the generated cert pool", func(t *testing.T) { + caBundle, _ := NewCABundle(nil) - require.Equal(t, aCertPool, caBundle.GetCertPool()) + require.Equal(t, x509.NewCertPool(), caBundle.GetCertPool()) }) t.Run("handles nil receiver by returning nil", func(t *testing.T) { @@ -62,7 +87,7 @@ func TestGetCABundleHash(t *testing.T) { sha256OfTest := [32]byte{159, 134, 208, 129, 136, 76, 125, 101, 154, 47, 234, 160, 197, 90, 208, 21, 163, 191, 79, 27, 43, 11, 130, 44, 209, 93, 108, 21, 176, 240, 10, 8} t.Run("returns the SHA256", func(t *testing.T) { - caBundle := NewCABundle([]byte("test"), nil) + caBundle, _ := NewCABundle([]byte("test")) require.Equal(t, sha256OfTest, caBundle.GetCABundleHash()) }) @@ -119,15 +144,27 @@ func TestCABundleIsEqual(t *testing.T) { expected: true, }, { - name: "should return equal when both left and right have same CA certificate bytes", - left: NewCABundle(testCA.Bundle(), certPool), - right: NewCABundle(testCA.Bundle(), certPool), + name: "should return equal when both left and right have same CA certificate bytes", + left: func() *CABundle { + caBundle, _ := NewCABundle(testCA.Bundle()) + return caBundle + }(), + right: func() *CABundle { + caBundle, _ := NewCABundle(testCA.Bundle()) + return caBundle + }(), expected: true, }, { - name: "should return not equal when both left and right do not have same CA certificate bytes", - left: NewCABundle(testCA.Bundle(), certPool), - right: NewCABundle([]byte("something that is not a cert"), certPool), + name: "should return not equal when both left and right do not have same CA certificate bytes", + left: func() *CABundle { + caBundle, _ := NewCABundle(testCA.Bundle()) + return caBundle + }(), + right: func() *CABundle { + caBundle, _ := NewCABundle([]byte("something that is not a cert")) + return caBundle + }(), expected: false, }, } diff --git a/internal/controller/tlsconfigutil/tls_config_util.go b/internal/controller/tlsconfigutil/tls_config_util.go index be327535e..275fc7114 100644 --- a/internal/controller/tlsconfigutil/tls_config_util.go +++ b/internal/controller/tlsconfigutil/tls_config_util.go @@ -4,7 +4,6 @@ package tlsconfigutil import ( - "crypto/x509" "encoding/base64" "fmt" "strings" @@ -184,14 +183,12 @@ func buildCABundle( return nil, nil } - // try to create a cert pool with the read ca data to determine validity of the ca bundle read from the tlsSpec. - certPool := x509.NewCertPool() - ok := certPool.AppendCertsFromPEM(caBundleAsBytes) + caBundle, ok := NewCABundle(caBundleAsBytes) if !ok { return nil, generateErrorForNoCertsInNonEmptyBundle() } - return NewCABundle(caBundleAsBytes, certPool), nil + return caBundle, nil } func readCABundleFromSource(source *caBundleSource, namespace string, secretInformer corev1informers.SecretInformer, configMapInformer corev1informers.ConfigMapInformer) (string, error) { diff --git a/internal/controller/tlsconfigutil/tls_config_util_test.go b/internal/controller/tlsconfigutil/tls_config_util_test.go index 1ffe8b729..4d528ec7f 100644 --- a/internal/controller/tlsconfigutil/tls_config_util_test.go +++ b/internal/controller/tlsconfigutil/tls_config_util_test.go @@ -5,7 +5,6 @@ package tlsconfigutil import ( "context" - "crypto/x509" "encoding/base64" "testing" "time" @@ -27,10 +26,11 @@ import ( func TestValidateTLSConfig(t *testing.T) { testCA, err := certauthority.New("Test CA", 1*time.Hour) require.NoError(t, err) - certPool := x509.NewCertPool() - require.True(t, certPool.AppendCertsFromPEM(testCA.Bundle())) base64EncodedBundle := base64.StdEncoding.EncodeToString(testCA.Bundle()) + testCABundle, ok := NewCABundle(testCA.Bundle()) + require.True(t, ok) + tests := []struct { name string tlsSpec *TLSSpec @@ -64,7 +64,7 @@ func TestValidateTLSConfig(t *testing.T) { tlsSpec: &TLSSpec{ CertificateAuthorityData: base64EncodedBundle, }, - expectedCABundle: NewCABundle(testCA.Bundle(), certPool), + expectedCABundle: testCABundle, expectedCondition: &metav1.Condition{ Type: typeTLSConfigurationValid, Status: metav1.ConditionTrue, @@ -135,7 +135,7 @@ func TestValidateTLSConfig(t *testing.T) { }, }, }, - expectedCABundle: NewCABundle(testCA.Bundle(), certPool), + expectedCABundle: testCABundle, expectedCondition: &metav1.Condition{ Type: typeTLSConfigurationValid, Status: metav1.ConditionTrue, @@ -165,7 +165,7 @@ func TestValidateTLSConfig(t *testing.T) { }, }, }, - expectedCABundle: NewCABundle(testCA.Bundle(), certPool), + expectedCABundle: testCABundle, expectedCondition: &metav1.Condition{ Type: typeTLSConfigurationValid, Status: metav1.ConditionTrue, @@ -394,7 +394,7 @@ func TestValidateTLSConfig(t *testing.T) { }, }, }, - expectedCABundle: NewCABundle(testCA.Bundle(), certPool), + expectedCABundle: testCABundle, expectedCondition: &metav1.Condition{ Type: typeTLSConfigurationValid, Status: metav1.ConditionTrue,