From 08ccf821f909faabef615cd796262465a2a90eca Mon Sep 17 00:00:00 2001 From: Ben McClelland Date: Tue, 15 Jul 2025 10:54:38 -0700 Subject: [PATCH] fix: refresh expired iam vault tokens when needed The IAM vault client stores an access token once authenticated, but this token will expire after a certain amount of time set by the server generating the token. Once this token is expired or revoked, it can no longer be use by the vault client. So the client should try to refresh the token with any errors indicating expired or revoked tokens. Fixes #976 --- auth/iam_vault.go | 160 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 113 insertions(+), 47 deletions(-) diff --git a/auth/iam_vault.go b/auth/iam_vault.go index 0c958520..cfdde4b7 100644 --- a/auth/iam_vault.go +++ b/auth/iam_vault.go @@ -19,6 +19,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -26,21 +27,25 @@ import ( "github.com/hashicorp/vault-client-go/schema" ) +const requestTimeout = 10 * time.Second + type VaultIAMService struct { client *vault.Client authReqOpts []vault.RequestOption kvReqOpts []vault.RequestOption secretStoragePath string rootAcc Account + creds schema.AppRoleLoginRequest } var _ IAMService = &VaultIAMService{} -func NewVaultIAMService(rootAcc Account, endpoint, secretStoragePath, authMethod, mountPath, rootToken, roleID, roleSecret, serverCert, clientCert, clientCertKey string) (IAMService, error) { +func NewVaultIAMService(rootAcc Account, endpoint, secretStoragePath, + authMethod, mountPath, rootToken, roleID, roleSecret, serverCert, + clientCert, clientCertKey string) (IAMService, error) { opts := []vault.ClientOption{ vault.WithAddress(endpoint), - // set request timeout to 10 secs - vault.WithRequestTimeout(10 * time.Second), + vault.WithRequestTimeout(requestTimeout), } if serverCert != "" { tls := vault.TLSConfiguration{} @@ -75,6 +80,11 @@ func NewVaultIAMService(rootAcc Account, endpoint, secretStoragePath, authMethod kvReqOpts = append(kvReqOpts, vault.WithMountPath(mountPath)) } + creds := schema.AppRoleLoginRequest{ + RoleId: roleID, + SecretId: roleSecret, + } + // Authentication switch { case rootToken != "": @@ -87,12 +97,8 @@ func NewVaultIAMService(rootAcc Account, endpoint, secretStoragePath, authMethod return nil, fmt.Errorf("role id and role secret must both be specified") } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - resp, err := client.Auth.AppRoleLogin(ctx, schema.AppRoleLoginRequest{ - RoleId: roleID, - SecretId: roleSecret, - }, authReqOpts...) - cancel() + resp, err := client.Auth.AppRoleLogin(context.Background(), + creds, authReqOpts...) if err != nil { return nil, fmt.Errorf("approle authentication failure: %w", err) } @@ -110,30 +116,73 @@ func NewVaultIAMService(rootAcc Account, endpoint, secretStoragePath, authMethod kvReqOpts: kvReqOpts, secretStoragePath: secretStoragePath, rootAcc: rootAcc, + creds: creds, }, nil } +func (vt *VaultIAMService) reAuthIfNeeded(err error) error { + if err == nil { + return nil + } + + // Vault returns 403 for expired/revoked tokens + // pass all other errors back unchanged + if !vault.IsErrorStatus(err, http.StatusForbidden) { + return err + } + + resp, authErr := vt.client.Auth.AppRoleLogin(context.Background(), + vt.creds, vt.authReqOpts...) + if authErr != nil { + return fmt.Errorf("vault re-authentication failure: %w", authErr) + } + if err := vt.client.SetToken(resp.Auth.ClientToken); err != nil { + return fmt.Errorf("vault re-authentication set token failure: %w", err) + } + + return nil +} + func (vt *VaultIAMService) CreateAccount(account Account) error { if vt.rootAcc.Access == account.Access { return ErrUserExists } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - _, err := vt.client.Secrets.KvV2Write(ctx, vt.secretStoragePath+"/"+account.Access, schema.KvV2WriteRequest{ - Data: map[string]any{ - account.Access: account, - }, - Options: map[string]interface{}{ - "cas": 0, - }, - }, vt.kvReqOpts...) - cancel() + _, err := vt.client.Secrets.KvV2Write(context.Background(), + vt.secretStoragePath+"/"+account.Access, schema.KvV2WriteRequest{ + Data: map[string]any{ + account.Access: account, + }, + Options: map[string]any{ + "cas": 0, + }, + }, vt.kvReqOpts...) if err != nil { if strings.Contains(err.Error(), "check-and-set") { return ErrUserExists } - return err - } + reauthErr := vt.reAuthIfNeeded(err) + if reauthErr != nil { + return reauthErr + } + // retry once after re-auth + _, err = vt.client.Secrets.KvV2Write(context.Background(), + vt.secretStoragePath+"/"+account.Access, schema.KvV2WriteRequest{ + Data: map[string]any{ + account.Access: account, + }, + Options: map[string]any{ + "cas": 0, + }, + }, vt.kvReqOpts...) + if err != nil { + if strings.Contains(err.Error(), "check-and-set") { + return ErrUserExists + } + return err + } + return nil + } return nil } @@ -141,66 +190,84 @@ func (vt *VaultIAMService) GetUserAccount(access string) (Account, error) { if vt.rootAcc.Access == access { return vt.rootAcc, nil } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - resp, err := vt.client.Secrets.KvV2Read(ctx, vt.secretStoragePath+"/"+access, vt.kvReqOpts...) - cancel() + resp, err := vt.client.Secrets.KvV2Read(context.Background(), + vt.secretStoragePath+"/"+access, vt.kvReqOpts...) if err != nil { - return Account{}, err + reauthErr := vt.reAuthIfNeeded(err) + if reauthErr != nil { + return Account{}, reauthErr + } + // retry once after re-auth + resp, err = vt.client.Secrets.KvV2Read(context.Background(), + vt.secretStoragePath+"/"+access, vt.kvReqOpts...) + if err != nil { + return Account{}, err + } } - acc, err := parseVaultUserAccount(resp.Data.Data, access) if err != nil { return Account{}, err } - return acc, nil } func (vt *VaultIAMService) UpdateUserAccount(access string, props MutableProps) error { - //TODO: We need something like a transaction here ? acc, err := vt.GetUserAccount(access) if err != nil { return err } - updateAcc(&acc, props) - err = vt.DeleteUserAccount(access) if err != nil { return err } - err = vt.CreateAccount(acc) if err != nil { return err } - return nil } func (vt *VaultIAMService) DeleteUserAccount(access string) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - _, err := vt.client.Secrets.KvV2DeleteMetadataAndAllVersions(ctx, vt.secretStoragePath+"/"+access, vt.kvReqOpts...) - cancel() + _, err := vt.client.Secrets.KvV2DeleteMetadataAndAllVersions(context.Background(), + vt.secretStoragePath+"/"+access, vt.kvReqOpts...) if err != nil { - return err + reauthErr := vt.reAuthIfNeeded(err) + if reauthErr != nil { + return reauthErr + } + // retry once after re-auth + _, err = vt.client.Secrets.KvV2DeleteMetadataAndAllVersions(context.Background(), + vt.secretStoragePath+"/"+access, vt.kvReqOpts...) + if err != nil { + return err + } } return nil } func (vt *VaultIAMService) ListUserAccounts() ([]Account, error) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - resp, err := vt.client.Secrets.KvV2List(ctx, vt.secretStoragePath, vt.kvReqOpts...) - cancel() + resp, err := vt.client.Secrets.KvV2List(context.Background(), + vt.secretStoragePath, vt.kvReqOpts...) if err != nil { - if vault.IsErrorStatus(err, 404) { - return []Account{}, nil + reauthErr := vt.reAuthIfNeeded(err) + if reauthErr != nil { + if vault.IsErrorStatus(err, http.StatusNotFound) { + return []Account{}, nil + } + return nil, reauthErr + } + // retry once after re-auth + resp, err = vt.client.Secrets.KvV2List(context.Background(), + vt.secretStoragePath, vt.kvReqOpts...) + if err != nil { + if vault.IsErrorStatus(err, http.StatusNotFound) { + return []Account{}, nil + } + return nil, err } - return nil, err } - accs := []Account{} - for _, acss := range resp.Data.Keys { acc, err := vt.GetUserAccount(acss) if err != nil { @@ -208,7 +275,6 @@ func (vt *VaultIAMService) ListUserAccounts() ([]Account, error) { } accs = append(accs, acc) } - return accs, nil } @@ -219,8 +285,8 @@ func (vt *VaultIAMService) Shutdown() error { var errInvalidUser error = errors.New("invalid user account entry in secrets engine") -func parseVaultUserAccount(data map[string]interface{}, access string) (acc Account, err error) { - usrAcc, ok := data[access].(map[string]interface{}) +func parseVaultUserAccount(data map[string]any, access string) (acc Account, err error) { + usrAcc, ok := data[access].(map[string]any) if !ok { return acc, errInvalidUser }