mirror of
https://github.com/vmware-tanzu/pinniped.git
synced 2026-01-03 11:45:45 +00:00
Address PR feedback
This commit is contained in:
committed by
Joshua Casey
parent
76a116641f
commit
f7fd209f29
@@ -199,7 +199,7 @@ func (c *webhookCacheFillerController) syncIndividualWebhookAuthenticator(ctx co
|
||||
)
|
||||
} else {
|
||||
// Run all remaining validations.
|
||||
a, moreConditions, moreErrs := c.doExpensiveValidations(webhookAuthenticator, endpointHostPort, caBundle, okSoFar, logger)
|
||||
a, moreConditions, moreErrs := c.doExpensiveValidations(ctx, webhookAuthenticator, endpointHostPort, caBundle, okSoFar, logger)
|
||||
newWebhookAuthenticatorForCache = a
|
||||
conditions = append(conditions, moreConditions...)
|
||||
errs = append(errs, moreErrs...)
|
||||
@@ -238,6 +238,7 @@ func (c *webhookCacheFillerController) syncIndividualWebhookAuthenticator(ctx co
|
||||
}
|
||||
|
||||
func (c *webhookCacheFillerController) doExpensiveValidations(
|
||||
ctx context.Context,
|
||||
webhookAuthenticator *authenticationv1alpha1.WebhookAuthenticator,
|
||||
endpointHostPort *endpointaddr.HostPort,
|
||||
caBundle *tlsconfigutil.CABundle,
|
||||
@@ -248,7 +249,7 @@ func (c *webhookCacheFillerController) doExpensiveValidations(
|
||||
var conditions []*metav1.Condition
|
||||
var errs []error
|
||||
|
||||
conditions, tlsNegotiateErr := c.validateConnection(caBundle.CertPool(), endpointHostPort, conditions, okSoFar, logger)
|
||||
conditions, tlsNegotiateErr := c.validateConnection(ctx, caBundle.CertPool(), endpointHostPort, conditions, okSoFar, logger)
|
||||
errs = append(errs, tlsNegotiateErr)
|
||||
okSoFar = okSoFar && tlsNegotiateErr == nil
|
||||
|
||||
@@ -414,6 +415,7 @@ func successfulWebhookConnectionValidCondition() *metav1.Condition {
|
||||
}
|
||||
|
||||
func (c *webhookCacheFillerController) validateConnection(
|
||||
ctx context.Context,
|
||||
certPool *x509.CertPool,
|
||||
endpointHostPort *endpointaddr.HostPort,
|
||||
conditions []*metav1.Condition,
|
||||
@@ -430,7 +432,9 @@ func (c *webhookCacheFillerController) validateConnection(
|
||||
return conditions, nil
|
||||
}
|
||||
|
||||
err := c.dialer.IsReachableAndTLSValidationSucceeds(endpointHostPort.Endpoint(), certPool, logger)
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer dialCancel()
|
||||
err := c.dialer.IsReachableAndTLSValidationSucceeds(dialCtx, endpointHostPort.Endpoint(), certPool, logger)
|
||||
|
||||
if err != nil {
|
||||
errText := "cannot dial server"
|
||||
|
||||
@@ -189,7 +189,7 @@ func (c *gitHubWatcherController) Sync(ctx controllerlib.Context) error {
|
||||
var applicationErrors []error
|
||||
validatedUpstreams := make([]upstreamprovider.UpstreamGithubIdentityProviderI, 0, len(actualUpstreams))
|
||||
for _, upstream := range actualUpstreams {
|
||||
validatedUpstream, applicationErr := c.validateUpstreamAndUpdateConditions(ctx, upstream)
|
||||
validatedUpstream, applicationErr := c.validateUpstreamAndUpdateConditions(ctx.Context, upstream)
|
||||
if applicationErr != nil {
|
||||
applicationErrors = append(applicationErrors, applicationErr)
|
||||
} else if validatedUpstream != nil {
|
||||
@@ -297,7 +297,7 @@ func validateOrganizationsPolicy(organizationsSpec *idpv1alpha1.GitHubOrganizati
|
||||
}
|
||||
}
|
||||
|
||||
func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx controllerlib.Context, upstream *idpv1alpha1.GitHubIdentityProvider) (
|
||||
func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx context.Context, upstream *idpv1alpha1.GitHubIdentityProvider) (
|
||||
*upstreamgithub.Provider, // If validated, returns the config
|
||||
error, // This error will only refer to programmatic errors such as inability to perform a Dial or dereference a pointer, not configuration errors
|
||||
) {
|
||||
@@ -331,6 +331,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro
|
||||
conditions = append(conditions, tlsConfigCondition)
|
||||
|
||||
githubConnectionCondition, httpClient, githubConnectionErr := c.validateGitHubConnection(
|
||||
ctx,
|
||||
apiHostPort,
|
||||
upstream.Spec.GitHubAPI.Host,
|
||||
caBundle,
|
||||
@@ -349,7 +350,7 @@ func (c *gitHubWatcherController) validateUpstreamAndUpdateConditions(ctx contro
|
||||
applicationErrors = append(applicationErrors, fmt.Errorf("expected %d conditions but found %d conditions", countExpectedConditions, len(conditions)))
|
||||
return nil, utilerrors.NewAggregate(applicationErrors)
|
||||
}
|
||||
hadErrorCondition, updateStatusErr := c.updateStatus(ctx.Context, upstream, conditions)
|
||||
hadErrorCondition, updateStatusErr := c.updateStatus(ctx, upstream, conditions)
|
||||
if updateStatusErr != nil {
|
||||
applicationErrors = append(applicationErrors, updateStatusErr)
|
||||
}
|
||||
@@ -453,6 +454,7 @@ func validateHost(specifiedHost *string) (*metav1.Condition, *endpointaddr.HostP
|
||||
}
|
||||
|
||||
func (c *gitHubWatcherController) validateGitHubConnection(
|
||||
ctx context.Context,
|
||||
apiHostPort *endpointaddr.HostPort,
|
||||
specifiedHost *string,
|
||||
caBundle *tlsconfigutil.CABundle,
|
||||
@@ -470,7 +472,10 @@ func (c *gitHubWatcherController) validateGitHubConnection(
|
||||
apiAddress := apiHostPort.Endpoint()
|
||||
|
||||
if !c.validatedCache.IsValid(apiAddress, caBundle.Hash()) {
|
||||
tlsDialErr := c.dialer.IsReachableAndTLSValidationSucceeds(apiAddress, caBundle.CertPool(), c.log)
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer dialCancel()
|
||||
|
||||
tlsDialErr := c.dialer.IsReachableAndTLSValidationSucceeds(dialCtx, apiAddress, caBundle.CertPool(), c.log)
|
||||
if tlsDialErr != nil {
|
||||
return &metav1.Condition{
|
||||
Type: GitHubConnectionValid,
|
||||
|
||||
@@ -67,10 +67,10 @@ type fakeGithubDialer struct {
|
||||
realCertPool *x509.CertPool
|
||||
}
|
||||
|
||||
func (f fakeGithubDialer) IsReachableAndTLSValidationSucceeds(address string, _ *x509.CertPool, logger ptls.ErrorOnlyLogger) error {
|
||||
func (f fakeGithubDialer) IsReachableAndTLSValidationSucceeds(ctx context.Context, address string, _ *x509.CertPool, logger plog.Logger) error {
|
||||
require.Equal(f.t, "api.github.com:443", address)
|
||||
|
||||
return ptls.NewDialer().IsReachableAndTLSValidationSucceeds(f.realAddress, f.realCertPool, logger)
|
||||
return ptls.NewDialer().IsReachableAndTLSValidationSucceeds(ctx, f.realAddress, f.realCertPool, logger)
|
||||
}
|
||||
|
||||
var _ ptls.Dialer = (*fakeGithubDialer)(nil)
|
||||
@@ -79,7 +79,7 @@ type allowNoDials struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (f allowNoDials) IsReachableAndTLSValidationSucceeds(_ string, _ *x509.CertPool, _ ptls.ErrorOnlyLogger) error {
|
||||
func (f allowNoDials) IsReachableAndTLSValidationSucceeds(_ context.Context, _ string, _ *x509.CertPool, _ plog.Logger) error {
|
||||
f.t.Errorf("this test should not perform dial")
|
||||
f.t.FailNow()
|
||||
return nil
|
||||
|
||||
@@ -4,52 +4,56 @@
|
||||
package ptls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"go.pinniped.dev/internal/plog"
|
||||
)
|
||||
|
||||
type Dialer interface {
|
||||
IsReachableAndTLSValidationSucceeds(
|
||||
ctx context.Context,
|
||||
address string,
|
||||
certPool *x509.CertPool,
|
||||
logger ErrorOnlyLogger,
|
||||
logger plog.Logger,
|
||||
) error
|
||||
}
|
||||
|
||||
type ErrorOnlyLogger interface {
|
||||
Error(msg string, err error, keysAndValues ...any)
|
||||
}
|
||||
|
||||
type internalDialer struct {
|
||||
dialer *net.Dialer
|
||||
}
|
||||
|
||||
func NewDialer() *internalDialer {
|
||||
return &internalDialer{
|
||||
dialer: &net.Dialer{
|
||||
Timeout: 15 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (i *internalDialer) WithTimeout(timeout time.Duration) Dialer {
|
||||
i.dialer.Timeout = timeout
|
||||
return i
|
||||
return &internalDialer{}
|
||||
}
|
||||
|
||||
func (i *internalDialer) IsReachableAndTLSValidationSucceeds(
|
||||
ctx context.Context,
|
||||
address string,
|
||||
certPool *x509.CertPool,
|
||||
logger ErrorOnlyLogger,
|
||||
logger plog.Logger,
|
||||
) error {
|
||||
connection, err := tls.DialWithDialer(i.dialer, "tcp", address, Default(certPool))
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
_, hasDeadline := ctx.Deadline()
|
||||
if !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
dialer := tls.Dialer{
|
||||
Config: Default(certPool),
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
// Don't wrap this error message since this is just a helper function.
|
||||
return err
|
||||
}
|
||||
err = connection.Close()
|
||||
err = conn.Close()
|
||||
if err != nil { // untested
|
||||
// Log it just so that it doesn't completely disappear.
|
||||
logger.Error("Failed to close connection: ", err)
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
package ptls_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net/http/httptest"
|
||||
@@ -15,19 +17,11 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.pinniped.dev/internal/crypto/ptls"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
"go.pinniped.dev/internal/testutil/tlsserver"
|
||||
)
|
||||
|
||||
type fakeerroronlylogger struct {
|
||||
}
|
||||
|
||||
func (_ *fakeerroronlylogger) Error(msg string, err error, keysAndValues ...any) {
|
||||
// NOOP
|
||||
}
|
||||
|
||||
var _ ptls.ErrorOnlyLogger = (*fakeerroronlylogger)(nil)
|
||||
|
||||
func TestDialer(t *testing.T) {
|
||||
secureServerIPv4, secureServerIPv4CA := tlsserver.TestServerIPv4(t, nil, nil)
|
||||
secureServerIPv6, secureServerIPv6CA := tlsserver.TestServerIPv6(t, nil, nil)
|
||||
@@ -69,10 +63,14 @@ func TestDialer(t *testing.T) {
|
||||
t.Parallel()
|
||||
dialer := ptls.NewDialer()
|
||||
|
||||
var log bytes.Buffer
|
||||
logger := plog.TestLogger(t, &log)
|
||||
|
||||
err := dialer.IsReachableAndTLSValidationSucceeds(
|
||||
context.Background(),
|
||||
urlToAddress(t, test.fullURL),
|
||||
test.certPool,
|
||||
&fakeerroronlylogger{},
|
||||
logger,
|
||||
)
|
||||
if test.wantError != "" {
|
||||
require.EqualError(t, err, test.wantError)
|
||||
@@ -86,51 +84,94 @@ func TestDialer(t *testing.T) {
|
||||
func TestDialer_TimeoutAfter15s(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dialer := ptls.NewDialer()
|
||||
dialTimeout := 15 * time.Second
|
||||
|
||||
timeout := time.After(30 * time.Second)
|
||||
testDone := make(chan bool)
|
||||
maxDurationForTest := 2 * dialTimeout
|
||||
maxTimeForTest := time.After(maxDurationForTest)
|
||||
testPassed := make(chan bool)
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
|
||||
defer cancel()
|
||||
|
||||
var log bytes.Buffer
|
||||
logger := plog.TestLogger(t, &log)
|
||||
|
||||
dialer := ptls.NewDialer()
|
||||
err := dialer.IsReachableAndTLSValidationSucceeds(
|
||||
ctx, // replace with context.Background() to verify that this hangs indefinitely
|
||||
setupHangingServer(t),
|
||||
nil,
|
||||
&fakeerroronlylogger{},
|
||||
logger,
|
||||
)
|
||||
require.EqualError(t, err, "context deadline exceeded")
|
||||
testDone <- true
|
||||
testPassed <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Errorf("test did not complete within 30 seconds")
|
||||
case <-maxTimeForTest:
|
||||
t.Errorf("timeout not honored: test did not complete within %s", maxDurationForTest)
|
||||
t.FailNow()
|
||||
case <-testDone:
|
||||
case <-testPassed:
|
||||
t.Log("everything ok!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialer_WithCustomTimeTimeoutAfter2s(t *testing.T) {
|
||||
func TestDialer_WithoutDeadline_Uses30sTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dialer := ptls.NewDialer().WithTimeout(2 * time.Second)
|
||||
|
||||
timeout := time.After(5 * time.Second)
|
||||
testDone := make(chan bool)
|
||||
maxDurationForTest := 40 * time.Second
|
||||
maxTimeForTest := time.After(maxDurationForTest)
|
||||
testPassed := make(chan bool)
|
||||
go func() {
|
||||
var log bytes.Buffer
|
||||
logger := plog.TestLogger(t, &log)
|
||||
|
||||
dialer := ptls.NewDialer()
|
||||
err := dialer.IsReachableAndTLSValidationSucceeds(
|
||||
context.Background(),
|
||||
setupHangingServer(t),
|
||||
nil,
|
||||
&fakeerroronlylogger{},
|
||||
logger,
|
||||
)
|
||||
require.EqualError(t, err, "context deadline exceeded")
|
||||
testDone <- true
|
||||
testPassed <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Errorf("test did not complete within 5 seconds")
|
||||
case <-maxTimeForTest:
|
||||
t.Errorf("timeout not honored: test did not complete within %s", maxDurationForTest)
|
||||
t.FailNow()
|
||||
case <-testDone:
|
||||
case <-testPassed:
|
||||
t.Log("everything ok!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialer_WithNilContext_Uses30sTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
maxDurationForTest := 40 * time.Second
|
||||
maxTimeForTest := time.After(maxDurationForTest)
|
||||
testPassed := make(chan bool)
|
||||
go func() {
|
||||
var log bytes.Buffer
|
||||
logger := plog.TestLogger(t, &log)
|
||||
|
||||
dialer := ptls.NewDialer()
|
||||
err := dialer.IsReachableAndTLSValidationSucceeds(
|
||||
nil,
|
||||
setupHangingServer(t),
|
||||
nil,
|
||||
logger,
|
||||
)
|
||||
require.EqualError(t, err, "context deadline exceeded")
|
||||
testPassed <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-maxTimeForTest:
|
||||
t.Errorf("timeout not honored: test did not complete within %s", maxDurationForTest)
|
||||
t.FailNow()
|
||||
case <-testPassed:
|
||||
t.Log("everything ok!")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user