cleanup: minor fixes to ldap exported functions and test

The buildSearchFilter function doesn't need to be exported, and
can use strings.Builder. Add a unit test to make sure this didn't
change any logic.

This will also use the debuglogger to enable debugging.
This commit is contained in:
Ben McClelland
2025-08-29 16:25:35 -07:00
parent 24b1c45db3
commit b358e385db
6 changed files with 107 additions and 25 deletions

View File

@@ -119,7 +119,6 @@ type Opts struct {
LDAPRoleAtr string LDAPRoleAtr string
LDAPUserIdAtr string LDAPUserIdAtr string
LDAPGroupIdAtr string LDAPGroupIdAtr string
LDAPDebug bool
VaultEndpointURL string VaultEndpointURL string
VaultSecretStoragePath string VaultSecretStoragePath string
VaultAuthMethod string VaultAuthMethod string
@@ -159,7 +158,7 @@ func New(o *Opts) (IAMService, error) {
case o.LDAPServerURL != "": case o.LDAPServerURL != "":
svc, err = NewLDAPService(o.RootAccount, o.LDAPServerURL, o.LDAPBindDN, o.LDAPPassword, svc, err = NewLDAPService(o.RootAccount, o.LDAPServerURL, o.LDAPBindDN, o.LDAPPassword,
o.LDAPQueryBase, o.LDAPAccessAtr, o.LDAPSecretAtr, o.LDAPRoleAtr, o.LDAPUserIdAtr, o.LDAPQueryBase, o.LDAPAccessAtr, o.LDAPSecretAtr, o.LDAPRoleAtr, o.LDAPUserIdAtr,
o.LDAPGroupIdAtr, o.LDAPObjClasses, o.LDAPDebug) o.LDAPGroupIdAtr, o.LDAPObjClasses)
fmt.Printf("initializing LDAP IAM with %q\n", o.LDAPServerURL) fmt.Printf("initializing LDAP IAM with %q\n", o.LDAPServerURL)
case o.S3Endpoint != "": case o.S3Endpoint != "":
svc, err = NewS3(o.RootAccount, o.S3Access, o.S3Secret, o.S3Region, o.S3Bucket, svc, err = NewS3(o.RootAccount, o.S3Access, o.S3Secret, o.S3Region, o.S3Bucket,

View File

@@ -22,6 +22,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
"github.com/versity/versitygw/debuglogger"
) )
type LdapIAMService struct { type LdapIAMService struct {
@@ -33,7 +34,6 @@ type LdapIAMService struct {
roleAtr string roleAtr string
groupIdAtr string groupIdAtr string
userIdAtr string userIdAtr string
debug bool
rootAcc Account rootAcc Account
url string url string
bindDN string bindDN string
@@ -43,7 +43,7 @@ type LdapIAMService struct {
var _ IAMService = &LdapIAMService{} var _ IAMService = &LdapIAMService{}
func NewLDAPService(rootAcc Account, url, bindDN, pass, queryBase, accAtr, secAtr, roleAtr, userIdAtr, groupIdAtr, objClasses string, debug bool) (IAMService, error) { func NewLDAPService(rootAcc Account, url, bindDN, pass, queryBase, accAtr, secAtr, roleAtr, userIdAtr, groupIdAtr, objClasses string) (IAMService, error) {
if url == "" || bindDN == "" || pass == "" || queryBase == "" || accAtr == "" || if url == "" || bindDN == "" || pass == "" || queryBase == "" || accAtr == "" ||
secAtr == "" || roleAtr == "" || userIdAtr == "" || groupIdAtr == "" || objClasses == "" { secAtr == "" || roleAtr == "" || userIdAtr == "" || groupIdAtr == "" || objClasses == "" {
return nil, fmt.Errorf("required parameters list not fully provided") return nil, fmt.Errorf("required parameters list not fully provided")
@@ -65,7 +65,6 @@ func NewLDAPService(rootAcc Account, url, bindDN, pass, queryBase, accAtr, secAt
secretAtr: secAtr, secretAtr: secAtr,
roleAtr: roleAtr, roleAtr: roleAtr,
userIdAtr: userIdAtr, userIdAtr: userIdAtr,
debug: debug,
groupIdAtr: groupIdAtr, groupIdAtr: groupIdAtr,
rootAcc: rootAcc, rootAcc: rootAcc,
url: url, url: url,
@@ -129,15 +128,15 @@ func (ld *LdapIAMService) CreateAccount(account Account) error {
return nil return nil
} }
func (ld *LdapIAMService) BuildSearchFilter(access string) string { func (ld *LdapIAMService) buildSearchFilter(access string) string {
searchFilter := "" var searchFilter strings.Builder
for _, el := range ld.objClasses { for _, el := range ld.objClasses {
searchFilter += fmt.Sprintf("(objectClass=%v)", el) searchFilter.WriteString(fmt.Sprintf("(objectClass=%v)", el))
} }
if access != "" { if access != "" {
searchFilter += fmt.Sprintf("(%v=%v)", ld.accessAtr, access) searchFilter.WriteString(fmt.Sprintf("(%v=%v)", ld.accessAtr, access))
} }
return fmt.Sprintf("(&%v)", searchFilter) return fmt.Sprintf("(&%v)", searchFilter.String())
} }
func (ld *LdapIAMService) GetUserAccount(access string) (Account, error) { func (ld *LdapIAMService) GetUserAccount(access string) (Account, error) {
@@ -152,13 +151,14 @@ func (ld *LdapIAMService) GetUserAccount(access string) (Account, error) {
0, 0,
0, 0,
false, false,
ld.BuildSearchFilter(access), ld.buildSearchFilter(access),
[]string{ld.accessAtr, ld.secretAtr, ld.roleAtr, ld.userIdAtr, ld.groupIdAtr}, []string{ld.accessAtr, ld.secretAtr, ld.roleAtr, ld.userIdAtr, ld.groupIdAtr},
nil, nil,
) )
if ld.debug { if debuglogger.IsIAMDebugEnabled() {
spew.Dump(searchRequest) debuglogger.IAMLogf("LDAP Search Request")
debuglogger.IAMLogf(spew.Sdump(searchRequest))
} }
err := ld.execute(func(c *ldap.Conn) error { err := ld.execute(func(c *ldap.Conn) error {
@@ -167,8 +167,9 @@ func (ld *LdapIAMService) GetUserAccount(access string) (Account, error) {
return err return err
}) })
if ld.debug { if debuglogger.IsIAMDebugEnabled() {
spew.Dump(result) debuglogger.IAMLogf("LDAP Search Result")
debuglogger.IAMLogf(spew.Sdump(result))
} }
if err != nil { if err != nil {
@@ -246,7 +247,7 @@ func (ld *LdapIAMService) ListUserAccounts() ([]Account, error) {
0, 0,
0, 0,
false, false,
ld.BuildSearchFilter(""), ld.buildSearchFilter(""),
[]string{ld.accessAtr, ld.secretAtr, ld.roleAtr, ld.groupIdAtr, ld.userIdAtr}, []string{ld.accessAtr, ld.secretAtr, ld.roleAtr, ld.groupIdAtr, ld.userIdAtr},
nil, nil,
) )

56
auth/iam_ldap_test.go Normal file
View File

@@ -0,0 +1,56 @@
package auth
import "testing"
func TestLdapIAMService_BuildSearchFilter(t *testing.T) {
tests := []struct {
name string
objClasses []string
accessAtr string
access string
expected string
}{
{
name: "single object class with access",
objClasses: []string{"inetOrgPerson"},
accessAtr: "uid",
access: "testuser",
expected: "(&(objectClass=inetOrgPerson)(uid=testuser))",
},
{
name: "single object class without access",
objClasses: []string{"inetOrgPerson"},
accessAtr: "uid",
access: "",
expected: "(&(objectClass=inetOrgPerson))",
},
{
name: "multiple object classes with access",
objClasses: []string{"inetOrgPerson", "organizationalPerson"},
accessAtr: "cn",
access: "john.doe",
expected: "(&(objectClass=inetOrgPerson)(objectClass=organizationalPerson)(cn=john.doe))",
},
{
name: "multiple object classes without access",
objClasses: []string{"inetOrgPerson", "organizationalPerson", "person"},
accessAtr: "cn",
access: "",
expected: "(&(objectClass=inetOrgPerson)(objectClass=organizationalPerson)(objectClass=person))",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ld := &LdapIAMService{
objClasses: tt.objClasses,
accessAtr: tt.accessAtr,
}
result := ld.buildSearchFilter(tt.access)
if result != tt.expected {
t.Errorf("BuildSearchFilter() = %v, want %v", result, tt.expected)
}
})
}
}

View File

@@ -64,7 +64,6 @@ var (
ldapQueryBase, ldapObjClasses string ldapQueryBase, ldapObjClasses string
ldapAccessAtr, ldapSecAtr, ldapRoleAtr string ldapAccessAtr, ldapSecAtr, ldapRoleAtr string
ldapUserIdAtr, ldapGroupIdAtr string ldapUserIdAtr, ldapGroupIdAtr string
ldapDebug bool
vaultEndpointURL, vaultSecretStoragePath string vaultEndpointURL, vaultSecretStoragePath string
vaultAuthMethod, vaultMountPath string vaultAuthMethod, vaultMountPath string
vaultRootToken, vaultRoleId string vaultRootToken, vaultRoleId string
@@ -83,6 +82,7 @@ var (
ipaHost, ipaVaultName string ipaHost, ipaVaultName string
ipaUser, ipaPassword string ipaUser, ipaPassword string
ipaInsecure, ipaDebug bool ipaInsecure, ipaDebug bool
iamDebug bool
) )
var ( var (
@@ -400,12 +400,6 @@ func initFlags() []cli.Flag {
EnvVars: []string{"VGW_IAM_LDAP_GROUP_ID_ATR"}, EnvVars: []string{"VGW_IAM_LDAP_GROUP_ID_ATR"},
Destination: &ldapGroupIdAtr, Destination: &ldapGroupIdAtr,
}, },
&cli.BoolFlag{
Name: "iam-ldap-debug",
Usage: "ldap server debug output",
EnvVars: []string{"VGW_IAM_LDAP_DEBUG"},
Destination: &ldapDebug,
},
&cli.StringFlag{ &cli.StringFlag{
Name: "iam-vault-endpoint-url", Name: "iam-vault-endpoint-url",
Usage: "vault server url", Usage: "vault server url",
@@ -529,6 +523,13 @@ func initFlags() []cli.Flag {
Value: 3600, Value: 3600,
Destination: &iamCachePrune, Destination: &iamCachePrune,
}, },
&cli.BoolFlag{
Name: "iam-debug",
Usage: "enable IAM debug output",
Value: false,
EnvVars: []string{"VGW_IAM_DEBUG"},
Destination: &iamDebug,
},
&cli.StringFlag{ &cli.StringFlag{
Name: "health", Name: "health",
Usage: `health check endpoint path. Health endpoint will be configured on GET http method: GET <health> Usage: `health check endpoint path. Health endpoint will be configured on GET http method: GET <health>
@@ -660,6 +661,10 @@ func runGateway(ctx context.Context, be backend.Backend) error {
debuglogger.SetDebugEnabled() debuglogger.SetDebugEnabled()
} }
if iamDebug {
debuglogger.SetIAMDebugEnabled()
}
iam, err := auth.New(&auth.Opts{ iam, err := auth.New(&auth.Opts{
RootAccount: auth.Account{ RootAccount: auth.Account{
Access: rootUserAccess, Access: rootUserAccess,
@@ -677,7 +682,6 @@ func runGateway(ctx context.Context, be backend.Backend) error {
LDAPRoleAtr: ldapRoleAtr, LDAPRoleAtr: ldapRoleAtr,
LDAPUserIdAtr: ldapUserIdAtr, LDAPUserIdAtr: ldapUserIdAtr,
LDAPGroupIdAtr: ldapGroupIdAtr, LDAPGroupIdAtr: ldapGroupIdAtr,
LDAPDebug: ldapDebug,
VaultEndpointURL: vaultEndpointURL, VaultEndpointURL: vaultEndpointURL,
VaultSecretStoragePath: vaultSecretStoragePath, VaultSecretStoragePath: vaultSecretStoragePath,
VaultAuthMethod: vaultAuthMethod, VaultAuthMethod: vaultAuthMethod,

View File

@@ -115,6 +115,28 @@ func Infof(format string, v ...any) {
fmt.Printf(string(green)+debugPrefix+format+reset+"\n", v...) fmt.Printf(string(green)+debugPrefix+format+reset+"\n", v...)
} }
var debugIAMEnabled atomic.Bool
// SetIAMDebugEnabled sets the IAM debug mode
func SetIAMDebugEnabled() {
debugIAMEnabled.Store(true)
}
// IsDebugEnabled returns true if debugging enabled
func IsIAMDebugEnabled() bool {
return debugEnabled.Load()
}
// IAMLogf is the same as 'fmt.Printf' with debug prefix,
// a color added and '\n' at the end
func IAMLogf(format string, v ...any) {
if !debugIAMEnabled.Load() {
return
}
debugPrefix := "[DEBUG]: "
fmt.Printf(string(yellow)+debugPrefix+format+reset+"\n", v...)
}
// PrintInsideHorizontalBorders prints the text inside horizontal // PrintInsideHorizontalBorders prints the text inside horizontal
// border and title in the center of upper border // border and title in the center of upper border
func PrintInsideHorizontalBorders(color Color, title, text string, width int) { func PrintInsideHorizontalBorders(color Color, title, text string, width int) {

2
go.mod
View File

@@ -12,6 +12,7 @@ require (
github.com/aws/aws-sdk-go-v2 v1.38.1 github.com/aws/aws-sdk-go-v2 v1.38.1
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.1 github.com/aws/aws-sdk-go-v2/service/s3 v1.87.1
github.com/aws/smithy-go v1.22.5 github.com/aws/smithy-go v1.22.5
github.com/davecgh/go-spew v1.1.1
github.com/go-ldap/ldap/v3 v3.4.11 github.com/go-ldap/ldap/v3 v3.4.11
github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/fiber/v2 v2.52.9
github.com/google/go-cmp v0.7.0 github.com/google/go-cmp v0.7.0
@@ -41,7 +42,6 @@ require (
github.com/aws/aws-sdk-go-v2/service/sso v1.28.2 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.28.2 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.2 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.2 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.38.0 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.38.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect