mirror of
https://github.com/vmware-tanzu/pinniped.git
synced 2026-01-03 11:45:45 +00:00
Improve unit tests in tokenclient_test.go
Also fix a linter error and rename some new files.
This commit is contained in:
@@ -5,7 +5,6 @@ package tokenclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -58,43 +57,50 @@ func New(
|
||||
return client
|
||||
}
|
||||
|
||||
func (tokenClient TokenClient) Start(ctx context.Context) {
|
||||
func (tc TokenClient) Start(ctx context.Context) {
|
||||
sleeper := make(chan time.Time, 1)
|
||||
|
||||
// Make sure that the <-sleeper below gets run once immediately.
|
||||
sleeper <- time.Now()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
tokenClient.logger.Info("TokenClient was cancelled and is stopping")
|
||||
tc.logger.Info("TokenClient was cancelled and is stopping")
|
||||
return
|
||||
case <-sleeper:
|
||||
var tokenTTL time.Duration
|
||||
|
||||
err := backoff.WithContext(ctx, &backoff.InfiniteBackoff{
|
||||
Duration: 10 * time.Millisecond,
|
||||
MaxDuration: 5 * time.Second,
|
||||
MaxDuration: 10 * time.Second,
|
||||
Factor: 2.0,
|
||||
}, func(ctx context.Context) (bool, error) {
|
||||
var (
|
||||
err error
|
||||
token string
|
||||
)
|
||||
token, tokenTTL, err = tokenClient.fetchToken(ctx)
|
||||
|
||||
token, tokenTTL, err = tc.fetchToken(ctx)
|
||||
if err != nil {
|
||||
tokenClient.logger.Warning(fmt.Sprintf("Could not fetch token: %s\n", err))
|
||||
// We got an error. Swallow it and ask for retry.
|
||||
// We got an error. Log it, swallow it, and ask for retry by returning false.
|
||||
tc.logger.Error("TokenClient could not fetch short-lived service account token (will retry)", err,
|
||||
"serviceAccountName", tc.serviceAccountName)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
tokenClient.whatToDoWithToken(token, tokenTTL)
|
||||
// We got a token. Stop backing off.
|
||||
// We got a new token, so invoke the callback.
|
||||
tc.whatToDoWithToken(token, tokenTTL)
|
||||
// Stop backing off.
|
||||
return true, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// We were cancelled during our WithContext. We know it was not due to some other
|
||||
// error because our last argument to WithContext above never returns any errors.
|
||||
return
|
||||
}
|
||||
|
||||
// Schedule ourselves to wake up in the future.
|
||||
time.AfterFunc(tokenTTL*4/5, func() {
|
||||
sleeper <- time.Now()
|
||||
@@ -103,19 +109,15 @@ func (tokenClient TokenClient) Start(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (tokenClient TokenClient) fetchToken(ctx context.Context) (token string, ttl time.Duration, _ error) {
|
||||
tokenClient.logger.Debug(fmt.Sprintf("refreshing cache at time=%s\n", tokenClient.clock.Now().Format(time.RFC3339)))
|
||||
|
||||
tokenRequestInput := &authenticationv1.TokenRequest{
|
||||
Spec: authenticationv1.TokenRequestSpec{
|
||||
ExpirationSeconds: &tokenClient.expirationSeconds,
|
||||
func (tc TokenClient) fetchToken(ctx context.Context) (token string, ttl time.Duration, _ error) {
|
||||
tc.logger.Debug("TokenClient calling CreateToken to fetch a short-lived service account token")
|
||||
tokenResponse, err := tc.serviceAccountClient.CreateToken(ctx,
|
||||
tc.serviceAccountName,
|
||||
&authenticationv1.TokenRequest{
|
||||
Spec: authenticationv1.TokenRequestSpec{
|
||||
ExpirationSeconds: &tc.expirationSeconds,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokenResponse, err := tokenClient.serviceAccountClient.CreateToken(
|
||||
ctx,
|
||||
tokenClient.serviceAccountName,
|
||||
tokenRequestInput,
|
||||
metav1.CreateOptions{},
|
||||
)
|
||||
|
||||
@@ -124,10 +126,10 @@ func (tokenClient TokenClient) fetchToken(ctx context.Context) (token string, tt
|
||||
}
|
||||
|
||||
if tokenResponse == nil {
|
||||
return "", 0, errors.New("tokenRequest is nil after request")
|
||||
return "", 0, errors.New("got nil CreateToken response")
|
||||
}
|
||||
|
||||
return tokenResponse.Status.Token,
|
||||
tokenResponse.Status.ExpirationTimestamp.Sub(tokenClient.clock.Now()),
|
||||
tokenResponse.Status.ExpirationTimestamp.Sub(tc.clock.Now()),
|
||||
nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -24,8 +24,7 @@ import (
|
||||
"go.pinniped.dev/internal/plog"
|
||||
)
|
||||
|
||||
//nolint:gochecknoglobals // just some test helper stuff here
|
||||
var (
|
||||
const (
|
||||
verb = "create"
|
||||
resource = "serviceaccounts/token"
|
||||
)
|
||||
@@ -93,6 +92,8 @@ func TestNew(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := New(
|
||||
tt.args.serviceAccountName,
|
||||
tt.args.serviceAccountClient,
|
||||
@@ -159,7 +160,7 @@ func TestFetchToken(t *testing.T) {
|
||||
expirationSeconds: 333,
|
||||
serviceAccountName: "service-account-name",
|
||||
expected: expected{
|
||||
errMessage: "tokenRequest is nil after request",
|
||||
errMessage: "got nil CreateToken response",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -167,6 +168,8 @@ func TestFetchToken(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClock := clocktesting.NewFakeClock(mockTime.Time)
|
||||
var log bytes.Buffer
|
||||
|
||||
@@ -206,68 +209,196 @@ func TestFetchToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStart_HappyPath(t *testing.T) {
|
||||
mockClient := fake.NewSimpleClientset()
|
||||
now := time.Now()
|
||||
var log bytes.Buffer
|
||||
func TestStart(t *testing.T) {
|
||||
type apiResponse struct {
|
||||
token string
|
||||
ttl time.Duration // how much in the future from the time of token response to set the expiration date
|
||||
err error
|
||||
}
|
||||
|
||||
type receivedToken struct {
|
||||
token string
|
||||
ttl time.Duration
|
||||
ttl time.Duration // expected ttl, within a fudge factor
|
||||
}
|
||||
|
||||
var receivedTokens []receivedToken
|
||||
|
||||
tokenClient := New(
|
||||
"service-account-name",
|
||||
mockClient.CoreV1().ServiceAccounts("any-namespace-works"),
|
||||
func(token string, ttl time.Duration) {
|
||||
t.Logf("received token %q with ttl %q", token, ttl)
|
||||
receivedTokens = append(receivedTokens, receivedToken{
|
||||
token: token,
|
||||
ttl: ttl,
|
||||
})
|
||||
},
|
||||
plog.TestLogger(t, &log),
|
||||
)
|
||||
|
||||
type reactionResponse struct {
|
||||
status authenticationv1.TokenRequestStatus
|
||||
err error
|
||||
type wanted struct {
|
||||
receivedTokens []receivedToken
|
||||
timeFudgeFactor time.Duration
|
||||
approxTimesBetweenAPIInvocations []time.Duration
|
||||
}
|
||||
|
||||
var reactionResponses []reactionResponse
|
||||
|
||||
for i := int64(0); i < 1000; i++ {
|
||||
ttl := time.Duration((1 + i) * 50 * int64(time.Millisecond))
|
||||
reactionResponses = append(reactionResponses, reactionResponse{
|
||||
status: authenticationv1.TokenRequestStatus{
|
||||
Token: fmt.Sprintf("token-%d-ttl-%s", i, ttl),
|
||||
ExpirationTimestamp: metav1.Time{Time: now.Add(ttl)},
|
||||
tests := []struct {
|
||||
name string
|
||||
apiResponses []apiResponse
|
||||
want *wanted
|
||||
}{
|
||||
{
|
||||
name: "several successful token requests",
|
||||
apiResponses: []apiResponse{
|
||||
{token: "t1", ttl: 200 * time.Millisecond},
|
||||
{token: "t2", ttl: 400 * time.Millisecond},
|
||||
{token: "t3", ttl: 300 * time.Millisecond},
|
||||
{token: "t4", ttl: time.Hour},
|
||||
},
|
||||
want: &wanted{
|
||||
timeFudgeFactor: 30 * time.Millisecond, // lots of fudge for busy CI workers
|
||||
receivedTokens: []receivedToken{
|
||||
{token: "t1", ttl: 200 * time.Millisecond},
|
||||
{token: "t2", ttl: 400 * time.Millisecond},
|
||||
{token: "t3", ttl: 300 * time.Millisecond},
|
||||
{token: "t4", ttl: time.Hour},
|
||||
},
|
||||
approxTimesBetweenAPIInvocations: []time.Duration{
|
||||
160 * time.Millisecond, // time between getting t1 and t2 (80% of t1's ttl)
|
||||
320 * time.Millisecond, // time between getting t2 and t3 (80% of t2's ttl)
|
||||
240 * time.Millisecond, // time between getting t4 and t4 (80% of t3's ttl)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "some errors in the middle",
|
||||
apiResponses: []apiResponse{
|
||||
{token: "t1", ttl: 100 * time.Millisecond},
|
||||
{token: "t2", ttl: 200 * time.Millisecond},
|
||||
{err: errors.New("err1")},
|
||||
{err: errors.New("err2")},
|
||||
{err: errors.New("err3")},
|
||||
{err: errors.New("err4")},
|
||||
{err: errors.New("err5")},
|
||||
{err: errors.New("err6")},
|
||||
{err: errors.New("err7")},
|
||||
{token: "t3", ttl: 100 * time.Millisecond},
|
||||
{token: "t4", ttl: time.Hour},
|
||||
},
|
||||
want: &wanted{
|
||||
timeFudgeFactor: 30 * time.Millisecond, // lots of fudge for busy CI workers
|
||||
receivedTokens: []receivedToken{
|
||||
{token: "t1", ttl: 100 * time.Millisecond},
|
||||
{token: "t2", ttl: 200 * time.Millisecond},
|
||||
{token: "t3", ttl: 100 * time.Millisecond},
|
||||
{token: "t4", ttl: time.Hour},
|
||||
},
|
||||
approxTimesBetweenAPIInvocations: []time.Duration{
|
||||
80 * time.Millisecond, // time between getting t1 and t2 (80% of t1's ttl)
|
||||
160 * time.Millisecond, // time between getting t2 and err1 (80% of t2's ttl)
|
||||
10 * time.Millisecond, // time between getting err1 and err2 (1st step of exponential backoff)
|
||||
20 * time.Millisecond, // time between getting err2 and err3 (2nd step of exponential backoff)
|
||||
40 * time.Millisecond, // time between getting err3 and err4 (3rd step of exponential backoff)
|
||||
80 * time.Millisecond, // time between getting err4 and err5 (4th step of exponential backoff)
|
||||
160 * time.Millisecond, // time between getting err5 and err6 (5th step of exponential backoff)
|
||||
320 * time.Millisecond, // time between getting err6 and err7 (6th step of exponential backoff)
|
||||
640 * time.Millisecond, // time between getting err7 and t3 (7th step of exponential backoff)
|
||||
80 * time.Millisecond, // time between getting t3 and t4 (80% of t3's ttl)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "getting errors before successfully fetching the first token",
|
||||
apiResponses: []apiResponse{
|
||||
{err: errors.New("err1")},
|
||||
{err: errors.New("err2")},
|
||||
{err: errors.New("err3")},
|
||||
{err: errors.New("err4")},
|
||||
{token: "t1", ttl: 100 * time.Millisecond},
|
||||
{token: "t2", ttl: 200 * time.Millisecond},
|
||||
{token: "t3", ttl: time.Hour},
|
||||
},
|
||||
want: &wanted{
|
||||
timeFudgeFactor: 30 * time.Millisecond, // lots of fudge for busy CI workers
|
||||
receivedTokens: []receivedToken{
|
||||
{token: "t1", ttl: 100 * time.Millisecond},
|
||||
{token: "t2", ttl: 200 * time.Millisecond},
|
||||
{token: "t3", ttl: time.Hour},
|
||||
},
|
||||
approxTimesBetweenAPIInvocations: []time.Duration{
|
||||
10 * time.Millisecond, // time between getting err1 and err2 (1st step of exponential backoff)
|
||||
20 * time.Millisecond, // time between getting err2 and err3 (2nd step of exponential backoff)
|
||||
40 * time.Millisecond, // time between getting err3 and err4 (3rd step of exponential backoff)
|
||||
80 * time.Millisecond, // time between getting err4 and t1 (4th step of exponential backoff)
|
||||
80 * time.Millisecond, // time between getting t1 and t2 (80% of t1's ttl)
|
||||
160 * time.Millisecond, // time between getting t2 and t3 (80% of t2's ttl)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mockClient := fake.NewSimpleClientset()
|
||||
var logs bytes.Buffer
|
||||
|
||||
var mutex sync.Mutex
|
||||
// These variables are accessed by the reactor and by the callback function in the goroutine which is
|
||||
// running Start() below. But they are also accessed by this test's main goroutine to make assertions later.
|
||||
// Protect them with a mutex to make the data race detector happy.
|
||||
var receivedTokens []receivedToken
|
||||
var reactorCallTimestamps []time.Time
|
||||
reactorCallCount := 0
|
||||
|
||||
subject := New(
|
||||
"service-account-name",
|
||||
mockClient.CoreV1().ServiceAccounts("any-namespace-works"),
|
||||
func(token string, ttl time.Duration) {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
t.Logf("received token %q with ttl %q", token, ttl)
|
||||
receivedTokens = append(receivedTokens, receivedToken{token: token, ttl: ttl})
|
||||
},
|
||||
plog.TestLogger(t, &logs),
|
||||
)
|
||||
|
||||
mockClient.PrependReactor(verb, resource, func(action coretesting.Action) (handled bool, ret runtime.Object, err error) {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
require.Less(t, reactorCallCount, len(tt.apiResponses),
|
||||
"more TokenRequests were made than fake reactor responses were prepared in the test setup")
|
||||
response := &authenticationv1.TokenRequest{Status: authenticationv1.TokenRequestStatus{
|
||||
Token: tt.apiResponses[reactorCallCount].token,
|
||||
ExpirationTimestamp: metav1.NewTime(time.Now().Add(tt.apiResponses[reactorCallCount].ttl)),
|
||||
}}
|
||||
responseErr := tt.apiResponses[reactorCallCount].err
|
||||
reactorCallCount++
|
||||
reactorCallTimestamps = append(reactorCallTimestamps, time.Now())
|
||||
t.Logf("fake CreateToken API returning response %q at time %s", response.Status, time.Now())
|
||||
return true, response, responseErr
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
time.AfterFunc(4*time.Second, cancel) // cancel the context after a few seconds
|
||||
go subject.Start(ctx) // Start() should only return after the context is cancelled
|
||||
<-ctx.Done()
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
// Should have used up all the reactor responses from the test table.
|
||||
require.Equal(t, reactorCallCount, len(tt.apiResponses))
|
||||
|
||||
// Should have got the expected callbacks for new tokens.
|
||||
require.Equal(t, len(tt.want.receivedTokens), len(receivedTokens))
|
||||
for i := range tt.want.receivedTokens {
|
||||
require.Equal(t, tt.want.receivedTokens[i].token, receivedTokens[i].token)
|
||||
require.InDelta(t,
|
||||
float64(tt.want.receivedTokens[i].ttl), float64(receivedTokens[i].ttl),
|
||||
float64(tt.want.timeFudgeFactor),
|
||||
)
|
||||
}
|
||||
|
||||
// Should have observed the appropriate amount of elapsed time in between each call to the CreateToken API.
|
||||
require.Equal(t, reactorCallCount-1, len(tt.want.approxTimesBetweenAPIInvocations), "wrong number of expected time deltas in test setup")
|
||||
for i := range reactorCallTimestamps {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
actualDelta := reactorCallTimestamps[i].Sub(reactorCallTimestamps[i-1])
|
||||
require.InDeltaf(t,
|
||||
tt.want.approxTimesBetweenAPIInvocations[i-1], actualDelta,
|
||||
float64(tt.want.timeFudgeFactor),
|
||||
"for API invocation %d", i,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
callCount := 0
|
||||
mockClient.PrependReactor(verb, resource, func(action coretesting.Action) (handled bool, ret runtime.Object, err error) {
|
||||
i := callCount
|
||||
callCount++
|
||||
response := &authenticationv1.TokenRequest{
|
||||
Status: reactionResponses[i].status,
|
||||
}
|
||||
return true, response, reactionResponses[i].err
|
||||
})
|
||||
|
||||
defer func() {
|
||||
expected := int((10 * time.Second) / (50 * time.Millisecond))
|
||||
require.GreaterOrEqual(t, len(receivedTokens), expected*9/10)
|
||||
require.LessOrEqual(t, len(receivedTokens), expected*11/10)
|
||||
//require.Equal(t, "some expected logs", log.String())
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
time.AfterFunc(10*time.Second, cancel)
|
||||
go tokenClient.Start(ctx)
|
||||
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user