mirror of
https://github.com/vmware-tanzu/pinniped.git
synced 2026-01-07 14:05:50 +00:00
Upstream ldap group refresh:
- Doing it inline on the refresh request
This commit is contained in:
@@ -108,7 +108,7 @@ type UpstreamLDAPIdentityProviderI interface {
|
||||
authenticators.UserAuthenticator
|
||||
|
||||
// PerformRefresh performs a refresh against the upstream LDAP identity provider
|
||||
PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) error
|
||||
PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) ([]string, error)
|
||||
}
|
||||
|
||||
type StoredRefreshAttributes struct {
|
||||
|
||||
@@ -301,7 +301,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
}
|
||||
// run PerformRefresh
|
||||
err = p.PerformRefresh(ctx, provider.StoredRefreshAttributes{
|
||||
groups, err := p.PerformRefresh(ctx, provider.StoredRefreshAttributes{
|
||||
Username: username,
|
||||
Subject: subject,
|
||||
DN: dn,
|
||||
@@ -312,6 +312,10 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
|
||||
"Upstream refresh failed.").WithWrap(err).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
// If we got groups back, then replace the old value with the new value.
|
||||
if groups != nil {
|
||||
session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1339,6 +1339,60 @@ func TestRefreshGrant(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "happy path refresh grant when the upstream refresh returns new group memberships from LDAP, it updates groups",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
|
||||
Name: ldapUpstreamName,
|
||||
ResourceUID: ldapUpstreamResourceUID,
|
||||
URL: ldapUpstreamURL,
|
||||
PerformRefreshGroups: []string{"new-group1", "new-group2", "new-group3"},
|
||||
}),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
customSessionData: happyLDAPCustomSessionData,
|
||||
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
|
||||
happyLDAPCustomSessionData,
|
||||
),
|
||||
},
|
||||
refreshRequest: refreshRequestInputs{
|
||||
want: tokenEndpointResponseExpectedValues{
|
||||
wantStatus: http.StatusOK,
|
||||
wantSuccessBodyFields: []string{"refresh_token", "access_token", "id_token", "token_type", "expires_in", "scope"},
|
||||
wantRequestedScopes: []string{"openid", "offline_access"},
|
||||
wantGrantedScopes: []string{"openid", "offline_access"},
|
||||
wantGroups: []string{"new-group1", "new-group2", "new-group3"},
|
||||
wantUpstreamRefreshCall: happyLDAPUpstreamRefreshCall(),
|
||||
wantCustomSessionDataStored: happyLDAPCustomSessionData,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "happy path refresh grant when the upstream refresh returns empty list of group memberships from LDAP, it updates groups to an empty list",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
|
||||
Name: ldapUpstreamName,
|
||||
ResourceUID: ldapUpstreamResourceUID,
|
||||
URL: ldapUpstreamURL,
|
||||
PerformRefreshGroups: []string{},
|
||||
}),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
customSessionData: happyLDAPCustomSessionData,
|
||||
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
|
||||
happyLDAPCustomSessionData,
|
||||
),
|
||||
},
|
||||
refreshRequest: refreshRequestInputs{
|
||||
want: tokenEndpointResponseExpectedValues{
|
||||
wantStatus: http.StatusOK,
|
||||
wantSuccessBodyFields: []string{"refresh_token", "access_token", "id_token", "token_type", "expires_in", "scope"},
|
||||
wantRequestedScopes: []string{"openid", "offline_access"},
|
||||
wantGrantedScopes: []string{"openid", "offline_access"},
|
||||
wantGroups: []string{},
|
||||
wantUpstreamRefreshCall: happyLDAPUpstreamRefreshCall(),
|
||||
wantCustomSessionDataStored: happyLDAPCustomSessionData,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error from refresh grant when the upstream refresh does not return new group memberships from the merged ID token and userinfo results by returning group claim with illegal nil value",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
|
||||
@@ -1967,9 +2021,10 @@ func TestRefreshGrant(t *testing.T) {
|
||||
{
|
||||
name: "upstream ldap refresh happy path",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
|
||||
Name: ldapUpstreamName,
|
||||
ResourceUID: ldapUpstreamResourceUID,
|
||||
URL: ldapUpstreamURL,
|
||||
Name: ldapUpstreamName,
|
||||
ResourceUID: ldapUpstreamResourceUID,
|
||||
URL: ldapUpstreamURL,
|
||||
PerformRefreshGroups: goodGroups,
|
||||
}),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
@@ -1987,9 +2042,10 @@ func TestRefreshGrant(t *testing.T) {
|
||||
{
|
||||
name: "upstream active directory refresh happy path",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithActiveDirectory(&oidctestutil.TestUpstreamLDAPIdentityProvider{
|
||||
Name: activeDirectoryUpstreamName,
|
||||
ResourceUID: activeDirectoryUpstreamResourceUID,
|
||||
URL: ldapUpstreamURL,
|
||||
Name: activeDirectoryUpstreamName,
|
||||
ResourceUID: activeDirectoryUpstreamResourceUID,
|
||||
URL: ldapUpstreamURL,
|
||||
PerformRefreshGroups: goodGroups,
|
||||
}),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
|
||||
@@ -101,6 +101,7 @@ type TestUpstreamLDAPIdentityProvider struct {
|
||||
performRefreshCallCount int
|
||||
performRefreshArgs []*PerformRefreshArgs
|
||||
PerformRefreshErr error
|
||||
PerformRefreshGroups []string
|
||||
}
|
||||
|
||||
var _ provider.UpstreamLDAPIdentityProviderI = &TestUpstreamLDAPIdentityProvider{}
|
||||
@@ -121,7 +122,7 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
|
||||
return u.URL
|
||||
}
|
||||
|
||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) ([]string, error) {
|
||||
if u.performRefreshArgs == nil {
|
||||
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
|
||||
}
|
||||
@@ -133,9 +134,9 @@ func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, s
|
||||
ExpectedSubject: storedRefreshAttributes.Subject,
|
||||
})
|
||||
if u.PerformRefreshErr != nil {
|
||||
return u.PerformRefreshErr
|
||||
return nil, u.PerformRefreshErr
|
||||
}
|
||||
return nil
|
||||
return u.PerformRefreshGroups, nil
|
||||
}
|
||||
|
||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefreshCallCount() int {
|
||||
|
||||
@@ -170,61 +170,11 @@ func (p *Provider) GetConfig() ProviderConfig {
|
||||
return p.c
|
||||
}
|
||||
|
||||
func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
||||
func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) ([]string, error) {
|
||||
t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()})
|
||||
defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches
|
||||
userDN := storedRefreshAttributes.DN
|
||||
|
||||
searchResult, err := p.performRefresh(ctx, userDN)
|
||||
if err != nil {
|
||||
p.traceRefreshFailure(t, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// if any more or less than one entry, error.
|
||||
// we don't need to worry about logging this because we know it's a dn.
|
||||
if len(searchResult.Entries) != 1 {
|
||||
return fmt.Errorf(`searching for user %q resulted in %d search results, but expected 1 result`,
|
||||
userDN, len(searchResult.Entries),
|
||||
)
|
||||
}
|
||||
|
||||
userEntry := searchResult.Entries[0]
|
||||
if len(userEntry.DN) == 0 {
|
||||
return fmt.Errorf(`searching for user with original DN %q resulted in search result without DN`, userDN)
|
||||
}
|
||||
|
||||
newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if newUsername != storedRefreshAttributes.Username {
|
||||
return fmt.Errorf(`searching for user %q returned a different username than the previous value. expected: %q, actual: %q`,
|
||||
userDN, storedRefreshAttributes.Username, newUsername,
|
||||
)
|
||||
}
|
||||
|
||||
newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL())
|
||||
if newSubject != storedRefreshAttributes.Subject {
|
||||
return fmt.Errorf(`searching for user %q produced a different subject than the previous value. expected: %q, actual: %q`, userDN, storedRefreshAttributes.Subject, newSubject)
|
||||
}
|
||||
for attribute, validateFunc := range p.c.RefreshAttributeChecks {
|
||||
err = validateFunc(userEntry, storedRefreshAttributes)
|
||||
if err != nil {
|
||||
return fmt.Errorf(`validation for attribute %q failed during upstream refresh: %w`, attribute, err)
|
||||
}
|
||||
}
|
||||
// we checked that the user still exists and their information is the same, so just return.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) performRefresh(ctx context.Context, userDN string) (*ldap.SearchResult, error) {
|
||||
search := p.refreshUserSearchRequest(userDN)
|
||||
|
||||
conn, err := p.dial(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err)
|
||||
@@ -236,6 +186,65 @@ func (p *Provider) performRefresh(ctx context.Context, userDN string) (*ldap.Sea
|
||||
return nil, fmt.Errorf(`error binding as %q before user search: %w`, p.c.BindUsername, err)
|
||||
}
|
||||
|
||||
searchResult, err := p.performUserRefresh(conn, userDN)
|
||||
if err != nil {
|
||||
p.traceRefreshFailure(t, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if any more or less than one entry, error.
|
||||
// we don't need to worry about logging this because we know it's a dn.
|
||||
if len(searchResult.Entries) != 1 {
|
||||
return nil, fmt.Errorf(`searching for user %q resulted in %d search results, but expected 1 result`,
|
||||
userDN, len(searchResult.Entries),
|
||||
)
|
||||
}
|
||||
|
||||
userEntry := searchResult.Entries[0]
|
||||
if len(userEntry.DN) == 0 {
|
||||
return nil, fmt.Errorf(`searching for user with original DN %q resulted in search result without DN`, userDN)
|
||||
}
|
||||
|
||||
newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if newUsername != storedRefreshAttributes.Username {
|
||||
return nil, fmt.Errorf(`searching for user %q returned a different username than the previous value. expected: %q, actual: %q`,
|
||||
userDN, storedRefreshAttributes.Username, newUsername,
|
||||
)
|
||||
}
|
||||
|
||||
newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL())
|
||||
if newSubject != storedRefreshAttributes.Subject {
|
||||
return nil, fmt.Errorf(`searching for user %q produced a different subject than the previous value. expected: %q, actual: %q`, userDN, storedRefreshAttributes.Subject, newSubject)
|
||||
}
|
||||
for attribute, validateFunc := range p.c.RefreshAttributeChecks {
|
||||
err = validateFunc(userEntry, storedRefreshAttributes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(`validation for attribute %q failed during upstream refresh: %w`, attribute, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If we have group search configured, search for groups to update the value.
|
||||
if len(p.c.GroupSearch.Base) > 0 {
|
||||
mappedGroupNames, err := p.searchGroupsForUserDN(conn, userDN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Strings(mappedGroupNames)
|
||||
return mappedGroupNames, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *Provider) performUserRefresh(conn Conn, userDN string) (*ldap.SearchResult, error) {
|
||||
search := p.refreshUserSearchRequest(userDN)
|
||||
|
||||
searchResult, err := conn.Search(search)
|
||||
|
||||
if err != nil {
|
||||
@@ -455,7 +464,7 @@ func (p *Provider) searchGroupsForUserDN(conn Conn, userDN string) ([]string, er
|
||||
groupAttributeName = distinguishedNameAttributeName
|
||||
}
|
||||
|
||||
var groups []string
|
||||
groups := []string{}
|
||||
entries:
|
||||
for _, groupEntry := range searchResult.Entries {
|
||||
if len(groupEntry.DN) == 0 {
|
||||
|
||||
@@ -1100,6 +1100,18 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
Controls: nil, // don't need paging because we set the SizeLimit so small
|
||||
}
|
||||
|
||||
expectedGroupSearch := &ldap.SearchRequest{
|
||||
BaseDN: testGroupSearchBase,
|
||||
Scope: ldap.ScopeWholeSubtree,
|
||||
DerefAliases: ldap.NeverDerefAliases,
|
||||
SizeLimit: 0, // unlimited size because we will search with paging
|
||||
TimeLimit: 90,
|
||||
TypesOnly: false,
|
||||
Filter: testGroupSearchFilterInterpolated,
|
||||
Attributes: []string{testGroupSearchGroupNameAttribute},
|
||||
Controls: nil, // nil because ldap.SearchWithPaging() will set the appropriate controls for us
|
||||
}
|
||||
|
||||
happyPathUserSearchResult := &ldap.SearchResult{
|
||||
Entries: []*ldap.Entry{
|
||||
{
|
||||
@@ -1146,6 +1158,7 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
setupMocks func(conn *mockldapconn.MockConn)
|
||||
dialError error
|
||||
wantErr string
|
||||
wantGroups []string
|
||||
}{
|
||||
{
|
||||
name: "happy path where searching the dn returns a single entry",
|
||||
@@ -1156,6 +1169,89 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
conn.EXPECT().Close().Times(1)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "happy path where group search returns groups",
|
||||
providerConfig: &ProviderConfig{
|
||||
Name: "some-provider-name",
|
||||
Host: testHost,
|
||||
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
|
||||
ConnectionProtocol: TLS,
|
||||
BindUsername: testBindUsername,
|
||||
BindPassword: testBindPassword,
|
||||
UserSearch: UserSearchConfig{
|
||||
Base: testUserSearchBase,
|
||||
UIDAttribute: testUserSearchUIDAttribute,
|
||||
UsernameAttribute: testUserSearchUsernameAttribute,
|
||||
},
|
||||
GroupSearch: GroupSearchConfig{
|
||||
Base: testGroupSearchBase,
|
||||
Filter: testGroupSearchFilter,
|
||||
GroupNameAttribute: testGroupSearchGroupNameAttribute,
|
||||
},
|
||||
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
|
||||
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
|
||||
},
|
||||
},
|
||||
setupMocks: func(conn *mockldapconn.MockConn) {
|
||||
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
|
||||
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
|
||||
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(&ldap.SearchResult{
|
||||
Entries: []*ldap.Entry{
|
||||
{
|
||||
DN: testGroupSearchResultDNValue1,
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue1}),
|
||||
},
|
||||
},
|
||||
{
|
||||
DN: testGroupSearchResultDNValue2,
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue2}),
|
||||
},
|
||||
},
|
||||
},
|
||||
Referrals: []string{}, // note that we are not following referrals at this time
|
||||
Controls: []ldap.Control{},
|
||||
}, nil).Times(1)
|
||||
conn.EXPECT().Close().Times(1)
|
||||
},
|
||||
wantGroups: []string{testGroupSearchResultGroupNameAttributeValue1, testGroupSearchResultGroupNameAttributeValue2},
|
||||
},
|
||||
{
|
||||
name: "happy path where group search returns no groups",
|
||||
providerConfig: &ProviderConfig{
|
||||
Name: "some-provider-name",
|
||||
Host: testHost,
|
||||
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
|
||||
ConnectionProtocol: TLS,
|
||||
BindUsername: testBindUsername,
|
||||
BindPassword: testBindPassword,
|
||||
UserSearch: UserSearchConfig{
|
||||
Base: testUserSearchBase,
|
||||
UIDAttribute: testUserSearchUIDAttribute,
|
||||
UsernameAttribute: testUserSearchUsernameAttribute,
|
||||
},
|
||||
GroupSearch: GroupSearchConfig{
|
||||
Base: testGroupSearchBase,
|
||||
Filter: testGroupSearchFilter,
|
||||
GroupNameAttribute: testGroupSearchGroupNameAttribute,
|
||||
},
|
||||
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
|
||||
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
|
||||
},
|
||||
},
|
||||
setupMocks: func(conn *mockldapconn.MockConn) {
|
||||
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
|
||||
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
|
||||
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(&ldap.SearchResult{
|
||||
Entries: []*ldap.Entry{},
|
||||
Referrals: []string{}, // note that we are not following referrals at this time
|
||||
Controls: []ldap.Control{},
|
||||
}, nil).Times(1)
|
||||
conn.EXPECT().Close().Times(1)
|
||||
},
|
||||
wantGroups: []string{},
|
||||
},
|
||||
{
|
||||
name: "error where dial fails",
|
||||
providerConfig: providerConfig,
|
||||
@@ -1421,6 +1517,37 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
},
|
||||
wantErr: "validation for attribute \"pwdLastSet\" failed during upstream refresh: value for attribute \"pwdLastSet\" has changed since initial value at login",
|
||||
},
|
||||
{
|
||||
name: "group search returns an error",
|
||||
providerConfig: &ProviderConfig{
|
||||
Name: "some-provider-name",
|
||||
Host: testHost,
|
||||
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
|
||||
ConnectionProtocol: TLS,
|
||||
BindUsername: testBindUsername,
|
||||
BindPassword: testBindPassword,
|
||||
UserSearch: UserSearchConfig{
|
||||
Base: testUserSearchBase,
|
||||
UIDAttribute: testUserSearchUIDAttribute,
|
||||
UsernameAttribute: testUserSearchUsernameAttribute,
|
||||
},
|
||||
GroupSearch: GroupSearchConfig{
|
||||
Base: testGroupSearchBase,
|
||||
Filter: testGroupSearchFilter,
|
||||
GroupNameAttribute: testGroupSearchGroupNameAttribute,
|
||||
},
|
||||
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
|
||||
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
|
||||
},
|
||||
},
|
||||
setupMocks: func(conn *mockldapconn.MockConn) {
|
||||
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
|
||||
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
|
||||
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(nil, errors.New("some search error")).Times(1)
|
||||
conn.EXPECT().Close().Times(1)
|
||||
},
|
||||
wantErr: "error searching for group memberships for user with DN \"some-upstream-user-dn\": some search error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -1435,9 +1562,9 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
dialWasAttempted := false
|
||||
providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
|
||||
tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
|
||||
dialWasAttempted = true
|
||||
require.Equal(t, providerConfig.Host, addr.Endpoint())
|
||||
require.Equal(t, tt.providerConfig.Host, addr.Endpoint())
|
||||
if tt.dialError != nil {
|
||||
return nil, tt.dialError
|
||||
}
|
||||
@@ -1446,9 +1573,9 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
})
|
||||
|
||||
initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000"))
|
||||
ldapProvider := New(*providerConfig)
|
||||
ldapProvider := New(*tt.providerConfig)
|
||||
subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU"
|
||||
err := ldapProvider.PerformRefresh(context.Background(), provider.StoredRefreshAttributes{
|
||||
groups, err := ldapProvider.PerformRefresh(context.Background(), provider.StoredRefreshAttributes{
|
||||
Username: testUserSearchResultUsernameAttributeValue,
|
||||
Subject: subject,
|
||||
DN: testUserSearchResultDNValue,
|
||||
@@ -1461,6 +1588,7 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, true, dialWasAttempted)
|
||||
require.Equal(t, tt.wantGroups, groups)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user