From 005dbf3aa86990fe57e5a14d2a60af837ecd35ba Mon Sep 17 00:00:00 2001 From: Ashish Amarnath Date: Thu, 25 Jul 2024 19:20:57 -0700 Subject: [PATCH] refactor tlsconfigutil to return a caBundle type Signed-off-by: Ashish Amarnath --- .../jwtcachefiller/jwtcachefiller.go | 14 +-- .../webhookcachefiller/webhookcachefiller.go | 4 +- .../github_upstream_watcher.go | 6 +- .../oidc_upstream_watcher.go | 7 +- .../upstreamwatchers/upstream_watchers.go | 4 +- .../tlsconfigutil/tls_config_util.go | 50 +++++++- .../tlsconfigutil/tls_config_util_test.go | 118 ++++++++++++++---- 7 files changed, 155 insertions(+), 48 deletions(-) diff --git a/internal/controller/authenticator/jwtcachefiller/jwtcachefiller.go b/internal/controller/authenticator/jwtcachefiller/jwtcachefiller.go index ee11a13e6..2db9bea1e 100644 --- a/internal/controller/authenticator/jwtcachefiller/jwtcachefiller.go +++ b/internal/controller/authenticator/jwtcachefiller/jwtcachefiller.go @@ -7,8 +7,6 @@ package jwtcachefiller import ( "context" - "crypto/sha256" - "crypto/x509" "errors" "fmt" "net/http" @@ -223,8 +221,8 @@ func (c *jwtCacheFillerController) syncIndividualJWTAuthenticator(ctx context.Co } conditions := make([]*metav1.Condition, 0) - certPool, caBundlePEM, conditions, tlsBundleOk := c.validateTLSBundle(jwtAuthenticator.Spec.TLS, conditions) - caBundlePEMSHA256 := sha256.Sum256(caBundlePEM) // note that this will always return the same hash for nil input + caBundle, conditions, tlsBundleOk := c.validateTLSBundle(jwtAuthenticator.Spec.TLS, conditions) + caBundlePEMSHA256 := caBundle.GetCABundleHash() // Only revalidate and update the cache if the cached authenticator is different from the desired authenticator. // There is no need to repeat validations for a spec that was already successfully validated. We are making a @@ -252,7 +250,7 @@ func (c *jwtCacheFillerController) syncIndividualJWTAuthenticator(ctx context.Co _, conditions, issuerOk := c.validateIssuer(jwtAuthenticator.Spec.Issuer, conditions) okSoFar := tlsBundleOk && issuerOk - client := phttp.Default(certPool) + client := phttp.Default(caBundle.GetCertPool()) client.Timeout = 30 * time.Second // copied from Kube OIDC code coreOSCtx := coreosoidc.ClientContext(context.Background(), client) @@ -317,8 +315,8 @@ func (c *jwtCacheFillerController) cacheValueAsJWTAuthenticator(value authncache return jwtAuthenticator } -func (c *jwtCacheFillerController) validateTLSBundle(tlsSpec *authenticationv1alpha1.TLSSpec, conditions []*metav1.Condition) (*x509.CertPool, []byte, []*metav1.Condition, bool) { - condition, pemBundle, certPool := tlsconfigutil.ValidateTLSConfig( +func (c *jwtCacheFillerController) validateTLSBundle(tlsSpec *authenticationv1alpha1.TLSSpec, conditions []*metav1.Condition) (*tlsconfigutil.CABundle, []*metav1.Condition, bool) { + condition, caBundle := tlsconfigutil.ValidateTLSConfig( tlsconfigutil.TLSSpecForConcierge(tlsSpec), "spec.tls", c.namespace, @@ -326,7 +324,7 @@ func (c *jwtCacheFillerController) validateTLSBundle(tlsSpec *authenticationv1al c.configMapInformer) conditions = append(conditions, condition) - return certPool, pemBundle, conditions, condition.Status == metav1.ConditionTrue + return caBundle, conditions, condition.Status == metav1.ConditionTrue } func (c *jwtCacheFillerController) validateIssuer(issuer string, conditions []*metav1.Condition) (*url.URL, []*metav1.Condition, bool) { diff --git a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go index fa47f1c74..9c2eddd59 100644 --- a/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go +++ b/internal/controller/authenticator/webhookcachefiller/webhookcachefiller.go @@ -244,7 +244,7 @@ func (c *webhookCacheFillerController) cacheValueAsWebhookAuthenticator(value au } func (c *webhookCacheFillerController) validateTLSBundle(tlsSpec *authenticationv1alpha1.TLSSpec, conditions []*metav1.Condition) (*x509.CertPool, []byte, []*metav1.Condition, bool) { - condition, pemBundle, certPool := tlsconfigutil.ValidateTLSConfig( + condition, caBundle := tlsconfigutil.ValidateTLSConfig( tlsconfigutil.TLSSpecForConcierge(tlsSpec), "spec.tls", c.namespace, @@ -252,7 +252,7 @@ func (c *webhookCacheFillerController) validateTLSBundle(tlsSpec *authentication c.configMapInformer) conditions = append(conditions, condition) - return certPool, pemBundle, conditions, condition.Status == metav1.ConditionTrue + return caBundle.GetCertPool(), caBundle.GetCABundle(), conditions, condition.Status == metav1.ConditionTrue } // newWebhookAuthenticator creates a webhook from the provided API server url and caBundle diff --git a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go index 7d8da94be..db1301376 100644 --- a/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go +++ b/internal/controller/supervisorconfig/githubupstreamwatcher/github_upstream_watcher.go @@ -325,7 +325,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro hostCondition, hostPort := validateHost(upstream.Spec.GitHubAPI) conditions = append(conditions, hostCondition) - tlsConfigCondition, caBundlePEM, certPool := tlsconfigutil.ValidateTLSConfig( + tlsConfigCondition, caBundle := tlsconfigutil.ValidateTLSConfig( tlsconfigutil.TLSSpecForSupervisor(upstream.Spec.GitHubAPI.TLS), "spec.githubAPI.tls", c.namespace, @@ -335,8 +335,8 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro githubConnectionCondition, hostURL, httpClient, githubConnectionErr := c.validateGitHubConnection( hostPort, - caBundlePEM, - certPool, + caBundle.GetCABundle(), + caBundle.GetCertPool(), hostCondition.Status == metav1.ConditionTrue, tlsConfigCondition.Status == metav1.ConditionTrue, ) diff --git a/internal/controller/supervisorconfig/oidcupstreamwatcher/oidc_upstream_watcher.go b/internal/controller/supervisorconfig/oidcupstreamwatcher/oidc_upstream_watcher.go index a2a85f21c..10ad7ad9a 100644 --- a/internal/controller/supervisorconfig/oidcupstreamwatcher/oidc_upstream_watcher.go +++ b/internal/controller/supervisorconfig/oidcupstreamwatcher/oidc_upstream_watcher.go @@ -6,7 +6,6 @@ package oidcupstreamwatcher import ( "context" - "crypto/sha256" "crypto/x509" "fmt" "net/http" @@ -334,7 +333,7 @@ func (c *oidcWatcherController) validateSecret(upstream *idpv1alpha1.OIDCIdentit // validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition. func (c *oidcWatcherController) validateIssuer(ctx context.Context, upstream *idpv1alpha1.OIDCIdentityProvider, result *upstreamoidc.ProviderConfig) []*metav1.Condition { - tlsCondition, caBundlePEM, certPool := tlsconfigutil.ValidateTLSConfig( + tlsCondition, caBundle := tlsconfigutil.ValidateTLSConfig( tlsconfigutil.TLSSpecForSupervisor(upstream.Spec.TLS), "spec.tls", upstream.Namespace, @@ -360,7 +359,7 @@ func (c *oidcWatcherController) validateIssuer(ctx context.Context, upstream *id // Get the discovered provider and HTTP client from cache, if they are found in the cache. cacheKey := oidcDiscoveryCacheKey{ issuer: upstream.Spec.Issuer, - caBundleHash: sha256.Sum256(caBundlePEM), // note that this will always return the same hash for nil input + caBundleHash: caBundle.GetCABundleHash(), // note that this will always return the same hash for nil input } if cacheEntry := c.validatorCache.getProvider(cacheKey); cacheEntry != nil { discoveredProvider = cacheEntry.provider @@ -374,7 +373,7 @@ func (c *oidcWatcherController) validateIssuer(ctx context.Context, upstream *id // If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache. if discoveredProvider == nil { - httpClient = defaultClientShortTimeout(certPool) + httpClient = defaultClientShortTimeout(caBundle.GetCertPool()) _, issuerURLCondition := validateHTTPSURL(upstream.Spec.Issuer, "issuer", reasonUnreachable) if issuerURLCondition != nil { diff --git a/internal/controller/supervisorconfig/upstreamwatchers/upstream_watchers.go b/internal/controller/supervisorconfig/upstreamwatchers/upstream_watchers.go index 7bb9abff6..183203941 100644 --- a/internal/controller/supervisorconfig/upstreamwatchers/upstream_watchers.go +++ b/internal/controller/supervisorconfig/upstreamwatchers/upstream_watchers.go @@ -256,9 +256,9 @@ func ValidateGenericLDAP( conditions.Append(secretValidCondition, true) tlsSpec := tlsconfigutil.TLSSpecForSupervisor(upstream.Spec().TLSSpec()) - tlsValidCondition, caBundle, _ := tlsconfigutil.ValidateTLSConfig(tlsSpec, "spec.tls", upstream.Namespace(), secretInformer, configMapInformer) + tlsValidCondition, caBundle := tlsconfigutil.ValidateTLSConfig(tlsSpec, "spec.tls", upstream.Namespace(), secretInformer, configMapInformer) conditions.Append(tlsValidCondition, true) - config.CABundle = caBundle + config.CABundle = caBundle.GetCABundle() var ldapConnectionValidCondition, searchBaseFoundCondition *metav1.Condition // No point in trying to connect to the server if the config was already determined to be invalid. diff --git a/internal/controller/tlsconfigutil/tls_config_util.go b/internal/controller/tlsconfigutil/tls_config_util.go index f41bb9d7c..a26e7738e 100644 --- a/internal/controller/tlsconfigutil/tls_config_util.go +++ b/internal/controller/tlsconfigutil/tls_config_util.go @@ -4,6 +4,7 @@ package tlsconfigutil import ( + "crypto/sha256" "crypto/x509" "encoding/base64" "fmt" @@ -80,18 +81,54 @@ func TLSSpecForConcierge(source *authenticationv1alpha1.TLSSpec) *TLSSpec { return dest } +// CABundle abstracts the internal representation of CA certificate bundles. +type CABundle struct { + caBundle []byte + caCertPool *x509.CertPool +} + +// GetCABundle returns the CA certificate bundle PEM bytes. +func (c *CABundle) GetCABundle() []byte { + return c.caBundle +} + +// GetCABundlePemString returns the certificate bundle PEM formatted as a string. +func (c *CABundle) GetCABundlePemString() string { + return string(c.caBundle) +} + +// GetCertPool returns a X509 cert pool with the CA certificate bundle. +func (c *CABundle) GetCertPool() *x509.CertPool { + return c.caCertPool +} + +// GetCABundleHash returns a sha256 sum of the CA bundle bytes. +func (c *CABundle) GetCABundleHash() [32]byte { + return sha256.Sum256(c.caBundle) // note that this will always return the same hash for nil input +} + +// IsEqual returns whether a CABundle has the same CA certificate bundle as another. +func (l *CABundle) IsEqual(r *CABundle) bool { + if l == nil && r == nil { + return true + } + if l == nil || r == nil { + return false + } + return sha256.Sum256(l.caBundle) == sha256.Sum256(r.GetCABundle()) +} + // ValidateTLSConfig reads ca bundle in the tlsSpec, supplied either inline using the CertificateAuthorityDate // or as a reference to a kubernetes secret or configmap using the CertificateAuthorityDataSource, and returns // - a condition of type TLSConfigurationValid based on the validity of the ca bundle, -// - a pem encoded ca bundle -// - a X509 cert pool with the ca bundle. +// - a CABundle - an abstraction of internal representation of CA certificate bundles. func ValidateTLSConfig( tlsSpec *TLSSpec, conditionPrefix string, namespace string, secretInformer corev1informers.SecretInformer, configMapInformer corev1informers.ConfigMapInformer, -) (*metav1.Condition, []byte, *x509.CertPool) { +) (*metav1.Condition, *CABundle) { // TODO: This func should return a struct that abstracts away the internals of how a CA bundle is held in memory // and can return the CA bundle as string PEM, []byte base64-encoded, CertPool, hash, etc, as well as compare itself // to either a different struct instance or a hash. @@ -100,13 +137,14 @@ func ValidateTLSConfig( certPool, bundle, err := getCertPool(tlsSpec, conditionPrefix, namespace, secretInformer, configMapInformer) if err != nil { - return invalidTLSCondition(err.Error()), nil, nil + return invalidTLSCondition(err.Error()), &CABundle{} } if bundle == nil { // An empty or nil CA bundle results in a valid TLS condition which indicates that no CA data was supplied. - return validTLSCondition(fmt.Sprintf("%s is valid: %s", conditionPrefix, noTLSConfigurationMessage)), nil, nil + return validTLSCondition(fmt.Sprintf("%s is valid: %s", conditionPrefix, noTLSConfigurationMessage)), nil } - return validTLSCondition(fmt.Sprintf("%s is valid: %s", conditionPrefix, loadedTLSConfigurationMessage)), bundle, certPool + return validTLSCondition(fmt.Sprintf("%s is valid: %s", conditionPrefix, loadedTLSConfigurationMessage)), + &CABundle{bundle, certPool} } // getCertPool reads the unified tlsSpec and returns an X509 cert pool with the CA data that is read either from diff --git a/internal/controller/tlsconfigutil/tls_config_util_test.go b/internal/controller/tlsconfigutil/tls_config_util_test.go index f7ba30ac3..eb626013a 100644 --- a/internal/controller/tlsconfigutil/tls_config_util_test.go +++ b/internal/controller/tlsconfigutil/tls_config_util_test.go @@ -27,18 +27,16 @@ import ( func TestValidateTLSConfig(t *testing.T) { testCA, err := certauthority.New("Test CA", 1*time.Hour) require.NoError(t, err) - bundle := testCA.Bundle() certPool := x509.NewCertPool() - require.True(t, certPool.AppendCertsFromPEM(bundle)) - base64EncodedBundle := base64.StdEncoding.EncodeToString(bundle) + require.True(t, certPool.AppendCertsFromPEM(testCA.Bundle())) + base64EncodedBundle := base64.StdEncoding.EncodeToString(testCA.Bundle()) tests := []struct { name string tlsSpec *TLSSpec namespace string k8sObjects []runtime.Object - expectedBundle []byte - expectedCertPool *x509.CertPool + expectedCABundle *CABundle expectedCondition *metav1.Condition }{ { @@ -66,8 +64,10 @@ func TestValidateTLSConfig(t *testing.T) { tlsSpec: &TLSSpec{ CertificateAuthorityData: base64EncodedBundle, }, - expectedBundle: bundle, - expectedCertPool: certPool, + expectedCABundle: &CABundle{ + caBundle: testCA.Bundle(), + caCertPool: certPool, + }, expectedCondition: &metav1.Condition{ Type: typeTLSConfigurationValid, Status: metav1.ConditionTrue, @@ -134,12 +134,14 @@ func TestValidateTLSConfig(t *testing.T) { }, Type: corev1.SecretTypeTLS, Data: map[string][]byte{ - "ca-bundle": bundle, + "ca-bundle": testCA.Bundle(), }, }, }, - expectedBundle: bundle, - expectedCertPool: certPool, + expectedCABundle: &CABundle{ + caBundle: testCA.Bundle(), + caCertPool: certPool, + }, expectedCondition: &metav1.Condition{ Type: typeTLSConfigurationValid, Status: metav1.ConditionTrue, @@ -165,12 +167,14 @@ func TestValidateTLSConfig(t *testing.T) { }, Type: corev1.SecretTypeOpaque, Data: map[string][]byte{ - "ca-bundle": bundle, + "ca-bundle": testCA.Bundle(), }, }, }, - expectedBundle: bundle, - expectedCertPool: certPool, + expectedCABundle: &CABundle{ + caBundle: testCA.Bundle(), + caCertPool: certPool, + }, expectedCondition: &metav1.Condition{ Type: typeTLSConfigurationValid, Status: metav1.ConditionTrue, @@ -196,7 +200,7 @@ func TestValidateTLSConfig(t *testing.T) { }, Type: corev1.SecretTypeBasicAuth, Data: map[string][]byte{ - "ca-bundle": bundle, + "ca-bundle": testCA.Bundle(), }, }, }, @@ -225,7 +229,7 @@ func TestValidateTLSConfig(t *testing.T) { }, Type: corev1.SecretTypeOpaque, Data: map[string][]byte{ - "wrong-key": bundle, + "wrong-key": testCA.Bundle(), }, }, }, @@ -311,7 +315,7 @@ func TestValidateTLSConfig(t *testing.T) { Namespace: "awesome-namespace", }, Data: map[string]string{ - "wrong-key": string(bundle), + "wrong-key": string(testCA.Bundle()), }, }, }, @@ -395,12 +399,14 @@ func TestValidateTLSConfig(t *testing.T) { Namespace: "awesome-namespace", }, Data: map[string]string{ - "ca-bundle": string(bundle), + "ca-bundle": string(testCA.Bundle()), }, }, }, - expectedBundle: bundle, - expectedCertPool: certPool, + expectedCABundle: &CABundle{ + caBundle: testCA.Bundle(), + caCertPool: certPool, + }, expectedCondition: &metav1.Condition{ Type: typeTLSConfigurationValid, Status: metav1.ConditionTrue, @@ -461,7 +467,7 @@ func TestValidateTLSConfig(t *testing.T) { Namespace: "awesome-namespace", }, Data: map[string]string{ - "ca-bundle": string(bundle), + "ca-bundle": string(testCA.Bundle()), }, }, }, @@ -499,11 +505,12 @@ func TestValidateTLSConfig(t *testing.T) { // which would do this same call for us. sharedInformers.WaitForCacheSync(ctx.Done()) - actualCondition, actualBundle, actualCertPool := ValidateTLSConfig(tt.tlsSpec, "spec.foo.tls", tt.namespace, secretsInformer, configMapInformer) + actualCondition, actualBundle := ValidateTLSConfig(tt.tlsSpec, "spec.foo.tls", tt.namespace, secretsInformer, configMapInformer) require.Equal(t, tt.expectedCondition, actualCondition) - require.Equal(t, tt.expectedBundle, actualBundle) - require.True(t, tt.expectedCertPool.Equal(actualCertPool), "expectedCertPool did not equal actualCertPool") + if tt.expectedCABundle != nil { + require.True(t, tt.expectedCABundle.IsEqual(actualBundle), "expectedCertPool did not equal actualCertPool") + } }) } } @@ -655,3 +662,68 @@ func TestTLSSpecForConcierge(t *testing.T) { }) } } + +func TestCABundleIsEqual(t *testing.T) { + testCA, err := certauthority.New("Test CA", 1*time.Hour) + require.NoError(t, err) + certPool := x509.NewCertPool() + require.True(t, certPool.AppendCertsFromPEM(testCA.Bundle())) + + tests := []struct { + name string + left *CABundle + right *CABundle + expected bool + }{ + { + name: "should return equal when left and right are nil", + left: nil, + right: nil, + expected: true, + }, + { + name: "should return not equal when left is nil and right is not", + left: nil, + right: &CABundle{}, + expected: false, + }, + { + name: "should return not equal when right is nil and left is not", + left: &CABundle{}, + right: nil, + expected: false, + }, + { + name: "should return equal when both left and right have same CA certificate bytes", + left: &CABundle{ + caBundle: testCA.Bundle(), + caCertPool: certPool, + }, + right: &CABundle{ + caBundle: testCA.Bundle(), + caCertPool: certPool, + }, + expected: true, + }, + { + name: "should return not equal when both left and right do not have same CA certificate bytes", + left: &CABundle{ + caBundle: testCA.Bundle(), + caCertPool: certPool, + }, + right: &CABundle{ + caBundle: []byte("something that is not a cert"), + caCertPool: nil, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + actual := tt.left.IsEqual(tt.right) + require.Equal(t, tt.expected, actual) + }) + } +}