From ab2c2e30cbe1de5bd9f559ebb389acd5f5b744b3 Mon Sep 17 00:00:00 2001 From: Ashish Amarnath Date: Thu, 19 Sep 2024 13:18:07 -0700 Subject: [PATCH] refactor and fix comments Signed-off-by: Ashish Amarnath --- internal/upstreamldap/upstreamldap.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index 9beee329c..ac5695844 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -330,18 +330,12 @@ func (p *Provider) dialTLS(ctx context.Context, addr endpointaddr.HostPort) (Con return conn, nil } -// dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is StartTLS. +// dialStartTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is StartTLS. // Unfortunately, the go-ldap library does not seem to support dialing with a context.Context, // so we implement it ourselves, heavily inspired by ldap.DialURL. func (p *Provider) dialStartTLS(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { - tlsConfig, err := p.tlsConfig() - if err != nil { - return nil, ldap.NewError(ldap.ErrorNetwork, err) - } - - // Unfortunately, this seems to be required for StartTLS, even though it is not needed for regular TLS. - tlsConfig.ServerName = addr.Host - + // start with a plaintext tcp connection which will then be upgraded to + // and LDAP connection over TLS. c, err := netDialer().DialContext(ctx, "tcp", addr.Endpoint()) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) @@ -349,6 +343,12 @@ func (p *Provider) dialStartTLS(ctx context.Context, addr endpointaddr.HostPort) conn := ldap.NewConn(c, false) conn.Start() + tlsConfig, err := p.tlsConfig() + if err != nil { + return nil, ldap.NewError(ldap.ErrorNetwork, err) + } + // Unfortunately, this seems to be required for StartTLS, even though it is not needed for regular TLS. + tlsConfig.ServerName = addr.Host err = conn.StartTLS(tlsConfig) if err != nil { return nil, err