diff --git a/internal/controller/tlsconfigutil/ca_bundle.go b/internal/controller/tlsconfigutil/ca_bundle.go index 7b53a697f..7ac7dd412 100644 --- a/internal/controller/tlsconfigutil/ca_bundle.go +++ b/internal/controller/tlsconfigutil/ca_bundle.go @@ -16,12 +16,19 @@ type CABundle struct { } func NewCABundle(caBundle []byte) (*CABundle, bool) { - certPool := x509.NewCertPool() + var certPool *x509.CertPool + ok := true + + if len(caBundle) > 0 { + certPool = x509.NewCertPool() + ok = certPool.AppendCertsFromPEM(caBundle) + } + return &CABundle{ caBundle: caBundle, sha256: sha256.Sum256(caBundle), certPool: certPool, - }, certPool.AppendCertsFromPEM(caBundle) + }, ok } // PEMBytes returns the CA certificate bundle PEM bytes. @@ -59,8 +66,3 @@ func (c *CABundle) Hash() [32]byte { } return c.sha256 // note that this will always return the same hash for nil input } - -// IsEqual returns whether a CABundle has the same CA certificate bundle as another. -func (c *CABundle) IsEqual(other *CABundle) bool { - return c.Hash() == other.Hash() -} diff --git a/internal/controller/tlsconfigutil/ca_bundle_test.go b/internal/controller/tlsconfigutil/ca_bundle_test.go index c11e21e09..d49f9b156 100644 --- a/internal/controller/tlsconfigutil/ca_bundle_test.go +++ b/internal/controller/tlsconfigutil/ca_bundle_test.go @@ -43,10 +43,19 @@ func TestPEMBytes(t *testing.T) { require.Equal(t, []byte("here are some bytes"), caBundle.PEMBytes()) }) + t.Run("handles nil bundle by returning nil", func(t *testing.T) { + caBundle, _ := NewCABundle(nil) + require.Nil(t, caBundle.PEMBytes()) + }) + + t.Run("handles empty bundle by returning empty byte array", func(t *testing.T) { + caBundle, _ := NewCABundle([]byte{}) + require.Equal(t, []byte{}, caBundle.PEMBytes()) + }) + t.Run("handles nil receiver by returning nil", func(t *testing.T) { var nilCABundle *CABundle - var expected []byte - require.Equal(t, expected, nilCABundle.PEMBytes()) + require.Nil(t, nilCABundle.PEMBytes()) }) } @@ -57,23 +66,49 @@ func TestPEMString(t *testing.T) { require.Equal(t, "here is a string", caBundle.PEMString()) }) - t.Run("handles nil receiver by returning empty sstring", func(t *testing.T) { + t.Run("handles nil bundle by returning empty string", func(t *testing.T) { + caBundle, _ := NewCABundle(nil) + + require.Equal(t, "", caBundle.PEMString()) + }) + + t.Run("handles empty bundle by returning empty string", func(t *testing.T) { + caBundle, _ := NewCABundle([]byte{}) + + require.Equal(t, "", caBundle.PEMString()) + }) + + t.Run("handles nil receiver by returning empty string", func(t *testing.T) { var nilCABundle *CABundle - require.Equal(t, "", nilCABundle.PEMString()) + require.Empty(t, nilCABundle.PEMString()) }) } func TestCertPool(t *testing.T) { - t.Run("returns the generated cert pool", func(t *testing.T) { + t.Run("returns the certPool when the caBundle is valid", func(t *testing.T) { + testCA, err := certauthority.New("Test CA", 1*time.Hour) + require.NoError(t, err) + + caBundle, _ := NewCABundle(testCA.Bundle()) + + require.True(t, testCA.Pool().Equal(caBundle.CertPool())) + }) + + t.Run("returns a nil certPool when the caBundle is nil", func(t *testing.T) { caBundle, _ := NewCABundle(nil) - require.Equal(t, x509.NewCertPool(), caBundle.CertPool()) + require.Nil(t, caBundle.CertPool()) + }) + + t.Run("returns a nil certPool when the caBundle is empty", func(t *testing.T) { + caBundle, _ := NewCABundle([]byte{}) + + require.Nil(t, caBundle.CertPool()) }) t.Run("handles nil receiver by returning nil", func(t *testing.T) { var nilCABundle *CABundle - var expected *x509.CertPool - require.Equal(t, expected, nilCABundle.CertPool()) + require.Nil(t, nilCABundle.CertPool()) }) } @@ -92,6 +127,18 @@ func TestHash(t *testing.T) { require.Equal(t, sha256OfTest, caBundle.Hash()) }) + t.Run("returns the SHA256 when the PEM is nil", func(t *testing.T) { + caBundle, _ := NewCABundle(nil) + + require.Equal(t, sha256OfNil, caBundle.Hash()) + }) + + t.Run("returns the SHA256 when the PEM is empty", func(t *testing.T) { + caBundle, _ := NewCABundle([]byte{}) + + require.Equal(t, sha256OfNil, caBundle.Hash()) + }) + t.Run("handles nil receiver by returning the hash of nil", func(t *testing.T) { var nilCABundle *CABundle @@ -112,68 +159,3 @@ func TestHash(t *testing.T) { require.Equal(t, sha256OfTest, caBundle.Hash()) }) } - -func TestCABundleIsEqual(t *testing.T) { - testCA, err := certauthority.New("Test CA", 1*time.Hour) - require.NoError(t, err) - certPool := x509.NewCertPool() - require.True(t, certPool.AppendCertsFromPEM(testCA.Bundle())) - - tests := []struct { - name string - left *CABundle - right *CABundle - expected bool - }{ - { - name: "should return equal when left and right are nil", - left: nil, - right: nil, - expected: true, - }, - { - name: "should return equal when left is nil and right is empty", - left: nil, - right: &CABundle{}, - expected: true, - }, - { - name: "should return equal when right is nil and left is empty", - left: &CABundle{}, - right: nil, - expected: true, - }, - { - name: "should return equal when both left and right have same CA certificate bytes", - left: func() *CABundle { - caBundle, _ := NewCABundle(testCA.Bundle()) - return caBundle - }(), - right: func() *CABundle { - caBundle, _ := NewCABundle(testCA.Bundle()) - return caBundle - }(), - expected: true, - }, - { - name: "should return not equal when both left and right do not have same CA certificate bytes", - left: func() *CABundle { - caBundle, _ := NewCABundle(testCA.Bundle()) - return caBundle - }(), - right: func() *CABundle { - caBundle, _ := NewCABundle([]byte("something that is not a cert")) - return caBundle - }(), - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - require.Equal(t, tt.expected, tt.left.IsEqual(tt.right)) - require.Equal(t, tt.expected, tt.right.IsEqual(tt.left)) - }) - } -} diff --git a/internal/controller/tlsconfigutil/tls_config_util_test.go b/internal/controller/tlsconfigutil/tls_config_util_test.go index 4d528ec7f..0fae9f05c 100644 --- a/internal/controller/tlsconfigutil/tls_config_util_test.go +++ b/internal/controller/tlsconfigutil/tls_config_util_test.go @@ -497,7 +497,9 @@ func TestValidateTLSConfig(t *testing.T) { require.Equal(t, tt.expectedCondition, actualCondition) if tt.expectedCABundle != nil { - require.True(t, tt.expectedCABundle.IsEqual(actualBundle), "expectedCertPool did not equal actualCertPool") + require.Equal(t, tt.expectedCABundle.Hash(), actualBundle.Hash()) + require.Equal(t, tt.expectedCABundle.PEMBytes(), actualBundle.PEMBytes()) + require.True(t, tt.expectedCABundle.CertPool().Equal(actualBundle.CertPool())) } }) }