diff --git a/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher.go b/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher.go index 81837769f..47cb404d6 100644 --- a/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher.go +++ b/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher.go @@ -152,8 +152,9 @@ func (c *ldapWatcherController) validateUpstream(ctx context.Context, upstream * spec := upstream.Spec config := &upstreamldap.ProviderConfig{ - Name: upstream.Name, - Host: spec.Host, + Name: upstream.Name, + Host: spec.Host, + ConnectionProtocol: upstreamldap.TLS, UserSearch: upstreamldap.UserSearchConfig{ Base: spec.UserSearch.Base, Filter: spec.UserSearch.Filter, diff --git a/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go b/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go index 99f09f1e8..ae9346f56 100644 --- a/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go +++ b/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go @@ -197,11 +197,12 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { } providerConfigForValidUpstream := &upstreamldap.ProviderConfig{ - Name: testName, - Host: testHost, - CABundle: testCABundle, - BindUsername: testBindUsername, - BindPassword: testBindPassword, + Name: testName, + Host: testHost, + ConnectionProtocol: upstreamldap.TLS, + CABundle: testCABundle, + BindUsername: testBindUsername, + BindPassword: testBindPassword, UserSearch: upstreamldap.UserSearchConfig{ Base: testUserSearchBase, Filter: testUserSearchFilter, @@ -442,11 +443,12 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { }, wantResultingCache: []*upstreamldap.ProviderConfig{ { - Name: testName, - Host: testHost, - CABundle: nil, - BindUsername: testBindUsername, - BindPassword: testBindPassword, + Name: testName, + Host: testHost, + ConnectionProtocol: upstreamldap.TLS, + CABundle: nil, + BindUsername: testBindUsername, + BindPassword: testBindPassword, UserSearch: upstreamldap.UserSearchConfig{ Base: testUserSearchBase, Filter: testUserSearchFilter, @@ -493,11 +495,12 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { }, wantResultingCache: []*upstreamldap.ProviderConfig{ { - Name: testName, - Host: testHost, - CABundle: nil, - BindUsername: testBindUsername, - BindPassword: testBindPassword, + Name: testName, + Host: testHost, + ConnectionProtocol: upstreamldap.TLS, + CABundle: nil, + BindUsername: testBindUsername, + BindPassword: testBindPassword, UserSearch: upstreamldap.UserSearchConfig{ Base: testUserSearchBase, Filter: testUserSearchFilter, diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index 2ffec1ed6..39d51e8cf 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -60,6 +60,13 @@ func (f LDAPDialerFunc) Dial(ctx context.Context, hostAndPort string) (Conn, err return f(ctx, hostAndPort) } +type LDAPConnectionProtocol string + +const ( + StartTLS = LDAPConnectionProtocol("StartTLS") + TLS = LDAPConnectionProtocol("TLS") +) + // ProviderConfig includes all of the settings for connection and searching for users and groups in // the upstream LDAP IDP. It also provides methods for testing the connection and performing logins. // The nested structs are not pointer fields to enable deep copy on function params and return values. @@ -71,6 +78,9 @@ type ProviderConfig struct { // the default LDAP port will be used. Host string + // ConnectionProtocol determines how to establish the connection to the server. Either StartTLS or TLS. + ConnectionProtocol LDAPConnectionProtocol + // PEM-encoded CA cert bundle to trust when connecting to the LDAP server. Can be nil. CABundle []byte @@ -137,33 +147,38 @@ func (p *Provider) GetConfig() ProviderConfig { } func (p *Provider) dial(ctx context.Context) (Conn, error) { - hostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort) + tlsHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } - if p.c.Dialer != nil { - return p.c.Dialer.Dial(ctx, hostAndPort) + + startTLSHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapPort) + if err != nil { + return nil, ldap.NewError(ldap.ErrorNetwork, err) + } + + switch { + case p.c.Dialer != nil: + return p.c.Dialer.Dial(ctx, tlsHostAndPort) + case p.c.ConnectionProtocol == TLS: + return p.dialTLS(ctx, tlsHostAndPort) + case p.c.ConnectionProtocol == StartTLS: + return p.dialStartTLS(ctx, startTLSHostAndPort) + default: + return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("did not specify valid ConnectionProtocol")) } - return p.dialTLS(ctx, hostAndPort) } -// dialTLS is the default implementation of the Dialer, used when Dialer is nil. +// dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is TLS. // Unfortunately, the go-ldap library does not seem to support dialing with a context.Context, // so we implement it ourselves, heavily inspired by ldap.DialURL. func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) { - var rootCAs *x509.CertPool - if p.c.CABundle != nil { - rootCAs = x509.NewCertPool() - if !rootCAs.AppendCertsFromPEM(p.c.CABundle) { - return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("could not parse CA bundle")) - } + tlsConfig, err := p.tlsConfig() + if err != nil { + return nil, ldap.NewError(ldap.ErrorNetwork, err) } - dialer := &tls.Dialer{Config: &tls.Config{ - MinVersion: tls.VersionTLS12, - RootCAs: rootCAs, - }} - + dialer := &tls.Dialer{NetDialer: netDialer(), Config: tlsConfig} c, err := dialer.DialContext(ctx, "tcp", hostAndPort) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) @@ -174,6 +189,52 @@ func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error return conn, nil } +// dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is StartTLS. +// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context, +// so we implement it ourselves, heavily inspired by ldap.DialURL. +func (p *Provider) dialStartTLS(ctx context.Context, hostAndPort string) (Conn, error) { + tlsConfig, err := p.tlsConfig() + if err != nil { + return nil, ldap.NewError(ldap.ErrorNetwork, err) + } + + host, err := hostWithoutPort(hostAndPort) + if err != nil { + return nil, ldap.NewError(ldap.ErrorNetwork, err) + } + // Unfortunately, this seems to be required for StartTLS, even though it is not needed for regular TLS. + tlsConfig.ServerName = host + + c, err := netDialer().DialContext(ctx, "tcp", hostAndPort) + if err != nil { + return nil, ldap.NewError(ldap.ErrorNetwork, err) + } + + conn := ldap.NewConn(c, false) + conn.Start() + err = conn.StartTLS(tlsConfig) + if err != nil { + return nil, err + } + + return conn, nil +} + +func netDialer() *net.Dialer { + return &net.Dialer{Timeout: time.Minute} +} + +func (p *Provider) tlsConfig() (*tls.Config, error) { + var rootCAs *x509.CertPool + if p.c.CABundle != nil { + rootCAs = x509.NewCertPool() + if !rootCAs.AppendCertsFromPEM(p.c.CABundle) { + return nil, fmt.Errorf("could not parse CA bundle") + } + } + return &tls.Config{MinVersion: tls.VersionTLS12, RootCAs: rootCAs}, nil +} + // Adds the default port if hostAndPort did not already include a port. func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, error) { host, port, err := net.SplitHostPort(hostAndPort) @@ -188,7 +249,7 @@ func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, switch { case port != "" && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"): // don't add extra square brackets to an IPv6 address that already has them - return host + ":" + port, nil + return fmt.Sprintf("%s:%s", host, port), nil case port != "": return net.JoinHostPort(host, port), nil default: @@ -196,6 +257,22 @@ func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, } } +// Strip the port from a host or host:port. +func hostWithoutPort(hostAndPort string) (string, error) { + host, _, err := net.SplitHostPort(hostAndPort) + if err != nil { + if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare + return hostAndPort, nil + } + return "", err // hostAndPort argument was not parsable + } + if strings.HasPrefix(hostAndPort, "[") { + // it was an IPv6 address, so preserve the square brackets. + return fmt.Sprintf("[%s]", host), nil + } + return host, nil +} + // A name for this upstream provider. func (p *Provider) GetName() string { return p.c.Name diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index 365618f4f..56203fbbb 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -1132,27 +1132,55 @@ func TestRealTLSDialing(t *testing.T) { tests := []struct { name string host string + connProto LDAPConnectionProtocol caBundle []byte context context.Context wantError string }{ { - name: "happy path", - host: testServerHostAndPort, - caBundle: []byte(testServerCABundle), - context: context.Background(), + name: "happy path", + host: testServerHostAndPort, + caBundle: []byte(testServerCABundle), + connProto: TLS, + context: context.Background(), }, { - name: "invalid CA bundle", + name: "invalid CA bundle with TLS", host: testServerHostAndPort, caBundle: []byte("not a ca bundle"), + connProto: TLS, context: context.Background(), wantError: `LDAP Result Code 200 "Network Error": could not parse CA bundle`, }, + { + name: "invalid CA bundle with StartTLS", + host: testServerHostAndPort, + caBundle: []byte("not a ca bundle"), + connProto: StartTLS, + context: context.Background(), + wantError: `LDAP Result Code 200 "Network Error": could not parse CA bundle`, + }, + { + name: "invalid host with TLS", + host: "this:is:not:a:valid:hostname", + caBundle: []byte(testServerCABundle), + connProto: TLS, + context: context.Background(), + wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`, + }, + { + name: "invalid host with StartTLS", + host: "this:is:not:a:valid:hostname", + caBundle: []byte(testServerCABundle), + connProto: StartTLS, + context: context.Background(), + wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`, + }, { name: "missing CA bundle when it is required because the host is not using a trusted CA", host: testServerHostAndPort, caBundle: nil, + connProto: TLS, context: context.Background(), wantError: `LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`, }, @@ -1161,6 +1189,7 @@ func TestRealTLSDialing(t *testing.T) { // This is assuming that this port was not reclaimed by another app since the test setup ran. Seems safe enough. host: recentlyClaimedHostAndPort, caBundle: []byte(testServerCABundle), + connProto: TLS, context: context.Background(), wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: connect: connection refused`, recentlyClaimedHostAndPort), }, @@ -1168,25 +1197,35 @@ func TestRealTLSDialing(t *testing.T) { name: "pays attention to the passed context", host: testServerHostAndPort, caBundle: []byte(testServerCABundle), + connProto: TLS, context: alreadyCancelledContext, wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: operation was canceled`, testServerHostAndPort), }, + { + name: "unsupported connection protocol", + host: testServerHostAndPort, + caBundle: []byte(testServerCABundle), + connProto: "bad usage of this type", + context: alreadyCancelledContext, + wantError: `LDAP Result Code 200 "Network Error": did not specify valid ConnectionProtocol`, + }, } for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { + tt := test + t.Run(tt.name, func(t *testing.T) { provider := New(ProviderConfig{ - Host: test.host, - CABundle: test.caBundle, - Dialer: nil, // this test is for the default (production) dialer + Host: tt.host, + CABundle: tt.caBundle, + ConnectionProtocol: tt.connProto, + Dialer: nil, // this test is for the default (production) TLS dialer }) - conn, err := provider.dial(test.context) + conn, err := provider.dial(tt.context) if conn != nil { defer conn.Close() } - if test.wantError != "" { + if tt.wantError != "" { require.Nil(t, conn) - require.EqualError(t, err, test.wantError) + require.EqualError(t, err, tt.wantError) } else { require.NoError(t, err) require.NotNil(t, conn) @@ -1231,6 +1270,12 @@ func TestHostAndPortWithDefaultPort(t *testing.T) { defaultPort: "", wantHostAndPort: "host.example.com", }, + { + name: "host has port and default port is empty", + hostAndPort: "host.example.com:42", + defaultPort: "", + wantHostAndPort: "host.example.com:42", + }, { name: "IPv6 host already has port", hostAndPort: "[::1%lo0]:80", @@ -1257,15 +1302,63 @@ func TestHostAndPortWithDefaultPort(t *testing.T) { }, } for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - hostAndPort, err := hostAndPortWithDefaultPort(test.hostAndPort, test.defaultPort) - if test.wantError != "" { - require.EqualError(t, err, test.wantError) + tt := test + t.Run(tt.name, func(t *testing.T) { + hostAndPort, err := hostAndPortWithDefaultPort(tt.hostAndPort, tt.defaultPort) + if tt.wantError != "" { + require.EqualError(t, err, tt.wantError) } else { require.NoError(t, err) } - require.Equal(t, test.wantHostAndPort, hostAndPort) + require.Equal(t, tt.wantHostAndPort, hostAndPort) + }) + } +} + +// Test various cases of host and port parsing. +func TestHostWithoutPort(t *testing.T) { + tests := []struct { + name string + hostAndPort string + wantError string + wantHostAndPort string + }{ + { + name: "host already has port", + hostAndPort: "host.example.com:99", + wantHostAndPort: "host.example.com", + }, + { + name: "host does not have port", + hostAndPort: "host.example.com", + wantHostAndPort: "host.example.com", + }, + { + name: "IPv6 host already has port", + hostAndPort: "[::1%lo0]:80", + wantHostAndPort: "[::1%lo0]", + }, + { + name: "IPv6 host does not have port", + hostAndPort: "[::1%lo0]", + wantHostAndPort: "[::1%lo0]", + }, + { + name: "host is not valid", + hostAndPort: "host.example.com:port1:port2", + wantError: "address host.example.com:port1:port2: too many colons in address", + }, + } + for _, test := range tests { + tt := test + t.Run(tt.name, func(t *testing.T) { + hostAndPort, err := hostWithoutPort(tt.hostAndPort) + if tt.wantError != "" { + require.EqualError(t, err, tt.wantError) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.wantHostAndPort, hostAndPort) }) } } diff --git a/test/integration/ldap_client_test.go b/test/integration/ldap_client_test.go index 3216dde9a..d62b51f4b 100644 --- a/test/integration/ldap_client_test.go +++ b/test/integration/ldap_client_test.go @@ -37,15 +37,19 @@ func TestLDAPSearch(t *testing.T) { cancelFunc() // this will send SIGKILL to the subprocess, just in case }) - hostPorts := findRecentlyUnusedLocalhostPorts(t, 2) - ldapHostPort := hostPorts[0] - unusedHostPort := hostPorts[1] + localhostPorts := findRecentlyUnusedLocalhostPorts(t, 3) + ldapLocalhostPort := localhostPorts[0] + ldapsLocalhostPort := localhostPorts[1] + unusedLocalhostPort := localhostPorts[2] // Expose the the test LDAP server's TLS port on the localhost. - startKubectlPortForward(ctx, t, ldapHostPort, "ldaps", "ldap", env.ToolsNamespace) + startKubectlPortForward(ctx, t, ldapsLocalhostPort, "ldaps", "ldap", env.ToolsNamespace) + + // Expose the the test LDAP server's StartTLS port on the localhost. + startKubectlPortForward(ctx, t, ldapLocalhostPort, "ldap", "ldap", env.ToolsNamespace) providerConfig := func(editFunc func(p *upstreamldap.ProviderConfig)) *upstreamldap.ProviderConfig { - providerConfig := defaultProviderConfig(env, ldapHostPort) + providerConfig := defaultProviderConfig(env, ldapsLocalhostPort) if editFunc != nil { editFunc(providerConfig) } @@ -64,7 +68,7 @@ func TestLDAPSearch(t *testing.T) { wantUnauthenticated bool }{ { - name: "happy path", + name: "happy path with TLS", username: "pinny", password: pinnyPassword, provider: upstreamldap.New(*providerConfig(nil)), @@ -72,6 +76,18 @@ func TestLDAPSearch(t *testing.T) { User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{"ball-game-players", "seals"}}, }, }, + { + name: "happy path with StartTLS", + username: "pinny", + password: pinnyPassword, + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { + p.Host = "127.0.0.1:" + ldapLocalhostPort + p.ConnectionProtocol = upstreamldap.StartTLS + })), + wantAuthResponse: &authenticator.Response{ + User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{"ball-game-players", "seals"}}, + }, + }, { name: "using a different user search base", username: "pinny", @@ -251,6 +267,17 @@ func TestLDAPSearch(t *testing.T) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindPassword = "wrong-password" })), wantError: `error binding as "cn=admin,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `, }, + { + name: "when the bind user username is wrong with StartTLS: example of an error after successful connection with StartTLS", + username: "pinny", + password: pinnyPassword, + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { + p.Host = "127.0.0.1:" + ldapLocalhostPort + p.ConnectionProtocol = upstreamldap.StartTLS + p.BindUsername = "cn=wrong,dc=pinniped,dc=dev" + })), + wantError: `error binding as "cn=wrong,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `, + }, { name: "when the end user password is wrong", username: "pinny", @@ -296,32 +323,89 @@ func TestLDAPSearch(t *testing.T) { wantError: `error searching for user "pinny": LDAP Result Code 4 "Size Limit Exceeded": `, }, { - name: "when the server is unreachable", + name: "when the server is unreachable with TLS", username: "pinny", password: pinnyPassword, - provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + unusedHostPort })), - wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedHostPort, unusedHostPort), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + unusedLocalhostPort })), + wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedLocalhostPort, unusedLocalhostPort), }, { - name: "when the server is not parsable", + name: "when the server is unreachable with StartTLS", + username: "pinny", + password: pinnyPassword, + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { + p.Host = "127.0.0.1:" + unusedLocalhostPort + p.ConnectionProtocol = upstreamldap.StartTLS + })), + wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedLocalhostPort, unusedLocalhostPort), + }, + { + name: "when the server is not parsable with TLS", username: "pinny", password: pinnyPassword, provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })), wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`, }, { - name: "when the CA bundle is not parsable", + name: "when the server is not parsable with StartTLS", + username: "pinny", + password: pinnyPassword, + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { + p.Host = "127.0.0.1:" + ldapLocalhostPort + p.ConnectionProtocol = upstreamldap.StartTLS + p.Host = "too:many:ports" + })), + wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`, + }, + { + name: "when the CA bundle is not parsable with TLS", username: "pinny", password: pinnyPassword, provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = []byte("invalid-pem") })), - wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapHostPort), + wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapsLocalhostPort), }, { - name: "when the CA bundle does not cause the host to be trusted", + name: "when the CA bundle is not parsable with StartTLS", + username: "pinny", + password: pinnyPassword, + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { + p.Host = "127.0.0.1:" + ldapLocalhostPort + p.ConnectionProtocol = upstreamldap.StartTLS + p.CABundle = []byte("invalid-pem") + })), + wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapLocalhostPort), + }, + { + name: "when the CA bundle does not cause the host to be trusted with TLS", username: "pinny", password: pinnyPassword, provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = nil })), - wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`, ldapHostPort), + wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`, ldapsLocalhostPort), + }, + { + name: "when the CA bundle does not cause the host to be trusted with StartTLS", + username: "pinny", + password: pinnyPassword, + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { + p.Host = "127.0.0.1:" + ldapLocalhostPort + p.ConnectionProtocol = upstreamldap.StartTLS + p.CABundle = nil + })), + wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": TLS handshake failed (x509: certificate signed by unknown authority)`, ldapLocalhostPort), + }, + { + name: "when trying to use TLS to connect to a port which only supports StartTLS", + username: "pinny", + password: pinnyPassword, + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + ldapLocalhostPort })), + wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": EOF`, ldapLocalhostPort), + }, + { + name: "when trying to use StartTLS to connect to a port which only supports TLS", + username: "pinny", + password: pinnyPassword, + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.ConnectionProtocol = upstreamldap.StartTLS })), + wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": unable to read LDAP response packet: unexpected EOF`, ldapsLocalhostPort), }, { name: "when the UsernameAttribute attribute has multiple values in the entry", @@ -541,13 +625,14 @@ type authUserResult struct { err error } -func defaultProviderConfig(env *library.TestEnv, ldapHostPort string) *upstreamldap.ProviderConfig { +func defaultProviderConfig(env *library.TestEnv, port string) *upstreamldap.ProviderConfig { return &upstreamldap.ProviderConfig{ - Name: "test-ldap-provider", - Host: "127.0.0.1:" + ldapHostPort, - CABundle: []byte(env.SupervisorUpstreamLDAP.CABundle), - BindUsername: "cn=admin,dc=pinniped,dc=dev", - BindPassword: "password", + Name: "test-ldap-provider", + Host: "127.0.0.1:" + port, + ConnectionProtocol: upstreamldap.TLS, + CABundle: []byte(env.SupervisorUpstreamLDAP.CABundle), + BindUsername: "cn=admin,dc=pinniped,dc=dev", + BindPassword: "password", UserSearch: upstreamldap.UserSearchConfig{ Base: "ou=users,dc=pinniped,dc=dev", Filter: "", // defaults to UsernameAttribute={}, i.e. "cn={}" in this case