diff --git a/auth/iam.go b/auth/iam.go index 1e66462..a702c65 100644 --- a/auth/iam.go +++ b/auth/iam.go @@ -119,7 +119,6 @@ type Opts struct { LDAPRoleAtr string LDAPUserIdAtr string LDAPGroupIdAtr string - LDAPDebug bool VaultEndpointURL string VaultSecretStoragePath string VaultAuthMethod string @@ -159,7 +158,7 @@ func New(o *Opts) (IAMService, error) { case o.LDAPServerURL != "": svc, err = NewLDAPService(o.RootAccount, o.LDAPServerURL, o.LDAPBindDN, o.LDAPPassword, o.LDAPQueryBase, o.LDAPAccessAtr, o.LDAPSecretAtr, o.LDAPRoleAtr, o.LDAPUserIdAtr, - o.LDAPGroupIdAtr, o.LDAPObjClasses, o.LDAPDebug) + o.LDAPGroupIdAtr, o.LDAPObjClasses) fmt.Printf("initializing LDAP IAM with %q\n", o.LDAPServerURL) case o.S3Endpoint != "": svc, err = NewS3(o.RootAccount, o.S3Access, o.S3Secret, o.S3Region, o.S3Bucket, diff --git a/auth/iam_ldap.go b/auth/iam_ldap.go index cdda803..c16d862 100644 --- a/auth/iam_ldap.go +++ b/auth/iam_ldap.go @@ -22,6 +22,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/go-ldap/ldap/v3" + "github.com/versity/versitygw/debuglogger" ) type LdapIAMService struct { @@ -33,7 +34,6 @@ type LdapIAMService struct { roleAtr string groupIdAtr string userIdAtr string - debug bool rootAcc Account url string bindDN string @@ -43,7 +43,7 @@ type LdapIAMService struct { var _ IAMService = &LdapIAMService{} -func NewLDAPService(rootAcc Account, url, bindDN, pass, queryBase, accAtr, secAtr, roleAtr, userIdAtr, groupIdAtr, objClasses string, debug bool) (IAMService, error) { +func NewLDAPService(rootAcc Account, url, bindDN, pass, queryBase, accAtr, secAtr, roleAtr, userIdAtr, groupIdAtr, objClasses string) (IAMService, error) { if url == "" || bindDN == "" || pass == "" || queryBase == "" || accAtr == "" || secAtr == "" || roleAtr == "" || userIdAtr == "" || groupIdAtr == "" || objClasses == "" { return nil, fmt.Errorf("required parameters list not fully provided") @@ -65,7 +65,6 @@ func NewLDAPService(rootAcc Account, url, bindDN, pass, queryBase, accAtr, secAt secretAtr: secAtr, roleAtr: roleAtr, userIdAtr: userIdAtr, - debug: debug, groupIdAtr: groupIdAtr, rootAcc: rootAcc, url: url, @@ -129,15 +128,15 @@ func (ld *LdapIAMService) CreateAccount(account Account) error { return nil } -func (ld *LdapIAMService) BuildSearchFilter(access string) string { - searchFilter := "" +func (ld *LdapIAMService) buildSearchFilter(access string) string { + var searchFilter strings.Builder for _, el := range ld.objClasses { - searchFilter += fmt.Sprintf("(objectClass=%v)", el) + searchFilter.WriteString(fmt.Sprintf("(objectClass=%v)", el)) } if access != "" { - searchFilter += fmt.Sprintf("(%v=%v)", ld.accessAtr, access) + searchFilter.WriteString(fmt.Sprintf("(%v=%v)", ld.accessAtr, access)) } - return fmt.Sprintf("(&%v)", searchFilter) + return fmt.Sprintf("(&%v)", searchFilter.String()) } func (ld *LdapIAMService) GetUserAccount(access string) (Account, error) { @@ -152,13 +151,14 @@ func (ld *LdapIAMService) GetUserAccount(access string) (Account, error) { 0, 0, false, - ld.BuildSearchFilter(access), + ld.buildSearchFilter(access), []string{ld.accessAtr, ld.secretAtr, ld.roleAtr, ld.userIdAtr, ld.groupIdAtr}, nil, ) - if ld.debug { - spew.Dump(searchRequest) + if debuglogger.IsIAMDebugEnabled() { + debuglogger.IAMLogf("LDAP Search Request") + debuglogger.IAMLogf(spew.Sdump(searchRequest)) } err := ld.execute(func(c *ldap.Conn) error { @@ -167,8 +167,9 @@ func (ld *LdapIAMService) GetUserAccount(access string) (Account, error) { return err }) - if ld.debug { - spew.Dump(result) + if debuglogger.IsIAMDebugEnabled() { + debuglogger.IAMLogf("LDAP Search Result") + debuglogger.IAMLogf(spew.Sdump(result)) } if err != nil { @@ -246,7 +247,7 @@ func (ld *LdapIAMService) ListUserAccounts() ([]Account, error) { 0, 0, false, - ld.BuildSearchFilter(""), + ld.buildSearchFilter(""), []string{ld.accessAtr, ld.secretAtr, ld.roleAtr, ld.groupIdAtr, ld.userIdAtr}, nil, ) diff --git a/auth/iam_ldap_test.go b/auth/iam_ldap_test.go new file mode 100644 index 0000000..7b77df1 --- /dev/null +++ b/auth/iam_ldap_test.go @@ -0,0 +1,56 @@ +package auth + +import "testing" + +func TestLdapIAMService_BuildSearchFilter(t *testing.T) { + tests := []struct { + name string + objClasses []string + accessAtr string + access string + expected string + }{ + { + name: "single object class with access", + objClasses: []string{"inetOrgPerson"}, + accessAtr: "uid", + access: "testuser", + expected: "(&(objectClass=inetOrgPerson)(uid=testuser))", + }, + { + name: "single object class without access", + objClasses: []string{"inetOrgPerson"}, + accessAtr: "uid", + access: "", + expected: "(&(objectClass=inetOrgPerson))", + }, + { + name: "multiple object classes with access", + objClasses: []string{"inetOrgPerson", "organizationalPerson"}, + accessAtr: "cn", + access: "john.doe", + expected: "(&(objectClass=inetOrgPerson)(objectClass=organizationalPerson)(cn=john.doe))", + }, + { + name: "multiple object classes without access", + objClasses: []string{"inetOrgPerson", "organizationalPerson", "person"}, + accessAtr: "cn", + access: "", + expected: "(&(objectClass=inetOrgPerson)(objectClass=organizationalPerson)(objectClass=person))", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ld := &LdapIAMService{ + objClasses: tt.objClasses, + accessAtr: tt.accessAtr, + } + + result := ld.buildSearchFilter(tt.access) + if result != tt.expected { + t.Errorf("BuildSearchFilter() = %v, want %v", result, tt.expected) + } + }) + } +} diff --git a/cmd/versitygw/main.go b/cmd/versitygw/main.go index 5adde6c..635408a 100644 --- a/cmd/versitygw/main.go +++ b/cmd/versitygw/main.go @@ -64,7 +64,6 @@ var ( ldapQueryBase, ldapObjClasses string ldapAccessAtr, ldapSecAtr, ldapRoleAtr string ldapUserIdAtr, ldapGroupIdAtr string - ldapDebug bool vaultEndpointURL, vaultSecretStoragePath string vaultAuthMethod, vaultMountPath string vaultRootToken, vaultRoleId string @@ -83,6 +82,7 @@ var ( ipaHost, ipaVaultName string ipaUser, ipaPassword string ipaInsecure, ipaDebug bool + iamDebug bool ) var ( @@ -400,12 +400,6 @@ func initFlags() []cli.Flag { EnvVars: []string{"VGW_IAM_LDAP_GROUP_ID_ATR"}, Destination: &ldapGroupIdAtr, }, - &cli.BoolFlag{ - Name: "iam-ldap-debug", - Usage: "ldap server debug output", - EnvVars: []string{"VGW_IAM_LDAP_DEBUG"}, - Destination: &ldapDebug, - }, &cli.StringFlag{ Name: "iam-vault-endpoint-url", Usage: "vault server url", @@ -529,6 +523,13 @@ func initFlags() []cli.Flag { Value: 3600, Destination: &iamCachePrune, }, + &cli.BoolFlag{ + Name: "iam-debug", + Usage: "enable IAM debug output", + Value: false, + EnvVars: []string{"VGW_IAM_DEBUG"}, + Destination: &iamDebug, + }, &cli.StringFlag{ Name: "health", Usage: `health check endpoint path. Health endpoint will be configured on GET http method: GET @@ -660,6 +661,10 @@ func runGateway(ctx context.Context, be backend.Backend) error { debuglogger.SetDebugEnabled() } + if iamDebug { + debuglogger.SetIAMDebugEnabled() + } + iam, err := auth.New(&auth.Opts{ RootAccount: auth.Account{ Access: rootUserAccess, @@ -677,7 +682,6 @@ func runGateway(ctx context.Context, be backend.Backend) error { LDAPRoleAtr: ldapRoleAtr, LDAPUserIdAtr: ldapUserIdAtr, LDAPGroupIdAtr: ldapGroupIdAtr, - LDAPDebug: ldapDebug, VaultEndpointURL: vaultEndpointURL, VaultSecretStoragePath: vaultSecretStoragePath, VaultAuthMethod: vaultAuthMethod, diff --git a/debuglogger/logger.go b/debuglogger/logger.go index e849856..c38d493 100644 --- a/debuglogger/logger.go +++ b/debuglogger/logger.go @@ -115,6 +115,28 @@ func Infof(format string, v ...any) { fmt.Printf(string(green)+debugPrefix+format+reset+"\n", v...) } +var debugIAMEnabled atomic.Bool + +// SetIAMDebugEnabled sets the IAM debug mode +func SetIAMDebugEnabled() { + debugIAMEnabled.Store(true) +} + +// IsDebugEnabled returns true if debugging enabled +func IsIAMDebugEnabled() bool { + return debugEnabled.Load() +} + +// IAMLogf is the same as 'fmt.Printf' with debug prefix, +// a color added and '\n' at the end +func IAMLogf(format string, v ...any) { + if !debugIAMEnabled.Load() { + return + } + debugPrefix := "[DEBUG]: " + fmt.Printf(string(yellow)+debugPrefix+format+reset+"\n", v...) +} + // PrintInsideHorizontalBorders prints the text inside horizontal // border and title in the center of upper border func PrintInsideHorizontalBorders(color Color, title, text string, width int) { diff --git a/go.mod b/go.mod index 1540460..5ce411a 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2 v1.38.1 github.com/aws/aws-sdk-go-v2/service/s3 v1.87.1 github.com/aws/smithy-go v1.22.5 + github.com/davecgh/go-spew v1.1.1 github.com/go-ldap/ldap/v3 v3.4.11 github.com/gofiber/fiber/v2 v2.52.9 github.com/google/go-cmp v0.7.0 @@ -41,7 +42,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.28.2 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.33.2 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.38.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect