diff --git a/auth/iam_ldap.go b/auth/iam_ldap.go index 1be77fb..06b602f 100644 --- a/auth/iam_ldap.go +++ b/auth/iam_ldap.go @@ -18,6 +18,7 @@ import ( "fmt" "strconv" "strings" + "sync" "github.com/go-ldap/ldap/v3" ) @@ -32,6 +33,10 @@ type LdapIAMService struct { groupIdAtr string userIdAtr string rootAcc Account + url string + bindDN string + pass string + mu sync.Mutex } var _ IAMService = &LdapIAMService{} @@ -60,9 +65,45 @@ func NewLDAPService(rootAcc Account, url, bindDN, pass, queryBase, accAtr, secAt userIdAtr: userIdAtr, groupIdAtr: groupIdAtr, rootAcc: rootAcc, + url: url, + bindDN: bindDN, + pass: pass, }, nil } +func (ld *LdapIAMService) reconnect() error { + ld.conn.Close() + + conn, err := ldap.DialURL(ld.url) + if err != nil { + return fmt.Errorf("failed to reconnect to LDAP server: %w", err) + } + + err = conn.Bind(ld.bindDN, ld.pass) + if err != nil { + conn.Close() + return fmt.Errorf("failed to bind to LDAP server on reconnect: %w", err) + } + ld.conn = conn + return nil +} + +func (ld *LdapIAMService) execute(f func(*ldap.Conn) error) error { + ld.mu.Lock() + defer ld.mu.Unlock() + + err := f(ld.conn) + if err != nil { + if e, ok := err.(*ldap.Error); ok && e.ResultCode == ldap.ErrorNetwork { + if reconnErr := ld.reconnect(); reconnErr != nil { + return reconnErr + } + return f(ld.conn) + } + } + return err +} + func (ld *LdapIAMService) CreateAccount(account Account) error { if ld.rootAcc.Access == account.Access { return ErrUserExists @@ -75,7 +116,9 @@ func (ld *LdapIAMService) CreateAccount(account Account) error { userEntry.Attribute(ld.groupIdAtr, []string{fmt.Sprint(account.GroupID)}) userEntry.Attribute(ld.userIdAtr, []string{fmt.Sprint(account.UserID)}) - err := ld.conn.Add(userEntry) + err := ld.execute(func(c *ldap.Conn) error { + return c.Add(userEntry) + }) if err != nil { return fmt.Errorf("error adding an entry: %w", err) } @@ -87,6 +130,7 @@ func (ld *LdapIAMService) GetUserAccount(access string) (Account, error) { if access == ld.rootAcc.Access { return ld.rootAcc, nil } + var result *ldap.SearchResult searchRequest := ldap.NewSearchRequest( ld.queryBase, ldap.ScopeWholeSubtree, @@ -99,7 +143,11 @@ func (ld *LdapIAMService) GetUserAccount(access string) (Account, error) { nil, ) - result, err := ld.conn.Search(searchRequest) + err := ld.execute(func(c *ldap.Conn) error { + var err error + result, err = c.Search(searchRequest) + return err + }) if err != nil { return Account{}, err } @@ -143,7 +191,9 @@ func (ld *LdapIAMService) UpdateUserAccount(access string, props MutableProps) e req.Replace(ld.roleAtr, []string{string(props.Role)}) } - err := ld.conn.Modify(req) + err := ld.execute(func(c *ldap.Conn) error { + return c.Modify(req) + }) //TODO: Handle non existing user case if err != nil { return err @@ -154,7 +204,9 @@ func (ld *LdapIAMService) UpdateUserAccount(access string, props MutableProps) e func (ld *LdapIAMService) DeleteUserAccount(access string) error { delReq := ldap.NewDelRequest(fmt.Sprintf("%v=%v, %v", ld.accessAtr, access, ld.queryBase), nil) - err := ld.conn.Del(delReq) + err := ld.execute(func(c *ldap.Conn) error { + return c.Del(delReq) + }) if err != nil { return err } @@ -167,6 +219,7 @@ func (ld *LdapIAMService) ListUserAccounts() ([]Account, error) { for _, el := range ld.objClasses { searchFilter += fmt.Sprintf("(objectClass=%v)", el) } + var resp *ldap.SearchResult searchRequest := ldap.NewSearchRequest( ld.queryBase, ldap.ScopeWholeSubtree, @@ -179,7 +232,11 @@ func (ld *LdapIAMService) ListUserAccounts() ([]Account, error) { nil, ) - resp, err := ld.conn.Search(searchRequest) + err := ld.execute(func(c *ldap.Conn) error { + var err error + resp, err = c.Search(searchRequest) + return err + }) if err != nil { return nil, err } @@ -210,5 +267,7 @@ func (ld *LdapIAMService) ListUserAccounts() ([]Account, error) { // Shutdown graceful termination of service func (ld *LdapIAMService) Shutdown() error { + ld.mu.Lock() + defer ld.mu.Unlock() return ld.conn.Close() }