Improve unit tests in tokenclient_test.go

Also fix a linter error and rename some new files.
This commit is contained in:
Ryan Richard
2023-11-30 13:29:52 -08:00
parent c439cc03a2
commit 5f4645d505
5 changed files with 217 additions and 82 deletions

View File

@@ -5,7 +5,6 @@ package tokenclient
import (
"context"
"fmt"
"time"
"github.com/pkg/errors"
@@ -58,43 +57,50 @@ func New(
return client
}
func (tokenClient TokenClient) Start(ctx context.Context) {
func (tc TokenClient) Start(ctx context.Context) {
sleeper := make(chan time.Time, 1)
// Make sure that the <-sleeper below gets run once immediately.
sleeper <- time.Now()
for {
select {
case <-ctx.Done():
tokenClient.logger.Info("TokenClient was cancelled and is stopping")
tc.logger.Info("TokenClient was cancelled and is stopping")
return
case <-sleeper:
var tokenTTL time.Duration
err := backoff.WithContext(ctx, &backoff.InfiniteBackoff{
Duration: 10 * time.Millisecond,
MaxDuration: 5 * time.Second,
MaxDuration: 10 * time.Second,
Factor: 2.0,
}, func(ctx context.Context) (bool, error) {
var (
err error
token string
)
token, tokenTTL, err = tokenClient.fetchToken(ctx)
token, tokenTTL, err = tc.fetchToken(ctx)
if err != nil {
tokenClient.logger.Warning(fmt.Sprintf("Could not fetch token: %s\n", err))
// We got an error. Swallow it and ask for retry.
// We got an error. Log it, swallow it, and ask for retry by returning false.
tc.logger.Error("TokenClient could not fetch short-lived service account token (will retry)", err,
"serviceAccountName", tc.serviceAccountName)
return false, nil
}
tokenClient.whatToDoWithToken(token, tokenTTL)
// We got a token. Stop backing off.
// We got a new token, so invoke the callback.
tc.whatToDoWithToken(token, tokenTTL)
// Stop backing off.
return true, nil
})
if err != nil {
// We were cancelled during our WithContext. We know it was not due to some other
// error because our last argument to WithContext above never returns any errors.
return
}
// Schedule ourselves to wake up in the future.
time.AfterFunc(tokenTTL*4/5, func() {
sleeper <- time.Now()
@@ -103,19 +109,15 @@ func (tokenClient TokenClient) Start(ctx context.Context) {
}
}
func (tokenClient TokenClient) fetchToken(ctx context.Context) (token string, ttl time.Duration, _ error) {
tokenClient.logger.Debug(fmt.Sprintf("refreshing cache at time=%s\n", tokenClient.clock.Now().Format(time.RFC3339)))
tokenRequestInput := &authenticationv1.TokenRequest{
Spec: authenticationv1.TokenRequestSpec{
ExpirationSeconds: &tokenClient.expirationSeconds,
func (tc TokenClient) fetchToken(ctx context.Context) (token string, ttl time.Duration, _ error) {
tc.logger.Debug("TokenClient calling CreateToken to fetch a short-lived service account token")
tokenResponse, err := tc.serviceAccountClient.CreateToken(ctx,
tc.serviceAccountName,
&authenticationv1.TokenRequest{
Spec: authenticationv1.TokenRequestSpec{
ExpirationSeconds: &tc.expirationSeconds,
},
},
}
tokenResponse, err := tokenClient.serviceAccountClient.CreateToken(
ctx,
tokenClient.serviceAccountName,
tokenRequestInput,
metav1.CreateOptions{},
)
@@ -124,10 +126,10 @@ func (tokenClient TokenClient) fetchToken(ctx context.Context) (token string, tt
}
if tokenResponse == nil {
return "", 0, errors.New("tokenRequest is nil after request")
return "", 0, errors.New("got nil CreateToken response")
}
return tokenResponse.Status.Token,
tokenResponse.Status.ExpirationTimestamp.Sub(tokenClient.clock.Now()),
tokenResponse.Status.ExpirationTimestamp.Sub(tc.clock.Now()),
nil
}

View File

@@ -7,7 +7,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
@@ -24,8 +24,7 @@ import (
"go.pinniped.dev/internal/plog"
)
//nolint:gochecknoglobals // just some test helper stuff here
var (
const (
verb = "create"
resource = "serviceaccounts/token"
)
@@ -93,6 +92,8 @@ func TestNew(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
actual := New(
tt.args.serviceAccountName,
tt.args.serviceAccountClient,
@@ -159,7 +160,7 @@ func TestFetchToken(t *testing.T) {
expirationSeconds: 333,
serviceAccountName: "service-account-name",
expected: expected{
errMessage: "tokenRequest is nil after request",
errMessage: "got nil CreateToken response",
},
},
}
@@ -167,6 +168,8 @@ func TestFetchToken(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mockClock := clocktesting.NewFakeClock(mockTime.Time)
var log bytes.Buffer
@@ -206,68 +209,196 @@ func TestFetchToken(t *testing.T) {
}
}
func TestStart_HappyPath(t *testing.T) {
mockClient := fake.NewSimpleClientset()
now := time.Now()
var log bytes.Buffer
func TestStart(t *testing.T) {
type apiResponse struct {
token string
ttl time.Duration // how much in the future from the time of token response to set the expiration date
err error
}
type receivedToken struct {
token string
ttl time.Duration
ttl time.Duration // expected ttl, within a fudge factor
}
var receivedTokens []receivedToken
tokenClient := New(
"service-account-name",
mockClient.CoreV1().ServiceAccounts("any-namespace-works"),
func(token string, ttl time.Duration) {
t.Logf("received token %q with ttl %q", token, ttl)
receivedTokens = append(receivedTokens, receivedToken{
token: token,
ttl: ttl,
})
},
plog.TestLogger(t, &log),
)
type reactionResponse struct {
status authenticationv1.TokenRequestStatus
err error
type wanted struct {
receivedTokens []receivedToken
timeFudgeFactor time.Duration
approxTimesBetweenAPIInvocations []time.Duration
}
var reactionResponses []reactionResponse
for i := int64(0); i < 1000; i++ {
ttl := time.Duration((1 + i) * 50 * int64(time.Millisecond))
reactionResponses = append(reactionResponses, reactionResponse{
status: authenticationv1.TokenRequestStatus{
Token: fmt.Sprintf("token-%d-ttl-%s", i, ttl),
ExpirationTimestamp: metav1.Time{Time: now.Add(ttl)},
tests := []struct {
name string
apiResponses []apiResponse
want *wanted
}{
{
name: "several successful token requests",
apiResponses: []apiResponse{
{token: "t1", ttl: 200 * time.Millisecond},
{token: "t2", ttl: 400 * time.Millisecond},
{token: "t3", ttl: 300 * time.Millisecond},
{token: "t4", ttl: time.Hour},
},
want: &wanted{
timeFudgeFactor: 30 * time.Millisecond, // lots of fudge for busy CI workers
receivedTokens: []receivedToken{
{token: "t1", ttl: 200 * time.Millisecond},
{token: "t2", ttl: 400 * time.Millisecond},
{token: "t3", ttl: 300 * time.Millisecond},
{token: "t4", ttl: time.Hour},
},
approxTimesBetweenAPIInvocations: []time.Duration{
160 * time.Millisecond, // time between getting t1 and t2 (80% of t1's ttl)
320 * time.Millisecond, // time between getting t2 and t3 (80% of t2's ttl)
240 * time.Millisecond, // time between getting t4 and t4 (80% of t3's ttl)
},
},
},
{
name: "some errors in the middle",
apiResponses: []apiResponse{
{token: "t1", ttl: 100 * time.Millisecond},
{token: "t2", ttl: 200 * time.Millisecond},
{err: errors.New("err1")},
{err: errors.New("err2")},
{err: errors.New("err3")},
{err: errors.New("err4")},
{err: errors.New("err5")},
{err: errors.New("err6")},
{err: errors.New("err7")},
{token: "t3", ttl: 100 * time.Millisecond},
{token: "t4", ttl: time.Hour},
},
want: &wanted{
timeFudgeFactor: 30 * time.Millisecond, // lots of fudge for busy CI workers
receivedTokens: []receivedToken{
{token: "t1", ttl: 100 * time.Millisecond},
{token: "t2", ttl: 200 * time.Millisecond},
{token: "t3", ttl: 100 * time.Millisecond},
{token: "t4", ttl: time.Hour},
},
approxTimesBetweenAPIInvocations: []time.Duration{
80 * time.Millisecond, // time between getting t1 and t2 (80% of t1's ttl)
160 * time.Millisecond, // time between getting t2 and err1 (80% of t2's ttl)
10 * time.Millisecond, // time between getting err1 and err2 (1st step of exponential backoff)
20 * time.Millisecond, // time between getting err2 and err3 (2nd step of exponential backoff)
40 * time.Millisecond, // time between getting err3 and err4 (3rd step of exponential backoff)
80 * time.Millisecond, // time between getting err4 and err5 (4th step of exponential backoff)
160 * time.Millisecond, // time between getting err5 and err6 (5th step of exponential backoff)
320 * time.Millisecond, // time between getting err6 and err7 (6th step of exponential backoff)
640 * time.Millisecond, // time between getting err7 and t3 (7th step of exponential backoff)
80 * time.Millisecond, // time between getting t3 and t4 (80% of t3's ttl)
},
},
},
{
name: "getting errors before successfully fetching the first token",
apiResponses: []apiResponse{
{err: errors.New("err1")},
{err: errors.New("err2")},
{err: errors.New("err3")},
{err: errors.New("err4")},
{token: "t1", ttl: 100 * time.Millisecond},
{token: "t2", ttl: 200 * time.Millisecond},
{token: "t3", ttl: time.Hour},
},
want: &wanted{
timeFudgeFactor: 30 * time.Millisecond, // lots of fudge for busy CI workers
receivedTokens: []receivedToken{
{token: "t1", ttl: 100 * time.Millisecond},
{token: "t2", ttl: 200 * time.Millisecond},
{token: "t3", ttl: time.Hour},
},
approxTimesBetweenAPIInvocations: []time.Duration{
10 * time.Millisecond, // time between getting err1 and err2 (1st step of exponential backoff)
20 * time.Millisecond, // time between getting err2 and err3 (2nd step of exponential backoff)
40 * time.Millisecond, // time between getting err3 and err4 (3rd step of exponential backoff)
80 * time.Millisecond, // time between getting err4 and t1 (4th step of exponential backoff)
80 * time.Millisecond, // time between getting t1 and t2 (80% of t1's ttl)
160 * time.Millisecond, // time between getting t2 and t3 (80% of t2's ttl)
},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mockClient := fake.NewSimpleClientset()
var logs bytes.Buffer
var mutex sync.Mutex
// These variables are accessed by the reactor and by the callback function in the goroutine which is
// running Start() below. But they are also accessed by this test's main goroutine to make assertions later.
// Protect them with a mutex to make the data race detector happy.
var receivedTokens []receivedToken
var reactorCallTimestamps []time.Time
reactorCallCount := 0
subject := New(
"service-account-name",
mockClient.CoreV1().ServiceAccounts("any-namespace-works"),
func(token string, ttl time.Duration) {
mutex.Lock()
defer mutex.Unlock()
t.Logf("received token %q with ttl %q", token, ttl)
receivedTokens = append(receivedTokens, receivedToken{token: token, ttl: ttl})
},
plog.TestLogger(t, &logs),
)
mockClient.PrependReactor(verb, resource, func(action coretesting.Action) (handled bool, ret runtime.Object, err error) {
mutex.Lock()
defer mutex.Unlock()
require.Less(t, reactorCallCount, len(tt.apiResponses),
"more TokenRequests were made than fake reactor responses were prepared in the test setup")
response := &authenticationv1.TokenRequest{Status: authenticationv1.TokenRequestStatus{
Token: tt.apiResponses[reactorCallCount].token,
ExpirationTimestamp: metav1.NewTime(time.Now().Add(tt.apiResponses[reactorCallCount].ttl)),
}}
responseErr := tt.apiResponses[reactorCallCount].err
reactorCallCount++
reactorCallTimestamps = append(reactorCallTimestamps, time.Now())
t.Logf("fake CreateToken API returning response %q at time %s", response.Status, time.Now())
return true, response, responseErr
})
ctx, cancel := context.WithCancel(context.Background())
time.AfterFunc(4*time.Second, cancel) // cancel the context after a few seconds
go subject.Start(ctx) // Start() should only return after the context is cancelled
<-ctx.Done()
mutex.Lock()
defer mutex.Unlock()
// Should have used up all the reactor responses from the test table.
require.Equal(t, reactorCallCount, len(tt.apiResponses))
// Should have got the expected callbacks for new tokens.
require.Equal(t, len(tt.want.receivedTokens), len(receivedTokens))
for i := range tt.want.receivedTokens {
require.Equal(t, tt.want.receivedTokens[i].token, receivedTokens[i].token)
require.InDelta(t,
float64(tt.want.receivedTokens[i].ttl), float64(receivedTokens[i].ttl),
float64(tt.want.timeFudgeFactor),
)
}
// Should have observed the appropriate amount of elapsed time in between each call to the CreateToken API.
require.Equal(t, reactorCallCount-1, len(tt.want.approxTimesBetweenAPIInvocations), "wrong number of expected time deltas in test setup")
for i := range reactorCallTimestamps {
if i == 0 {
continue
}
actualDelta := reactorCallTimestamps[i].Sub(reactorCallTimestamps[i-1])
require.InDeltaf(t,
tt.want.approxTimesBetweenAPIInvocations[i-1], actualDelta,
float64(tt.want.timeFudgeFactor),
"for API invocation %d", i,
)
}
})
}
callCount := 0
mockClient.PrependReactor(verb, resource, func(action coretesting.Action) (handled bool, ret runtime.Object, err error) {
i := callCount
callCount++
response := &authenticationv1.TokenRequest{
Status: reactionResponses[i].status,
}
return true, response, reactionResponses[i].err
})
defer func() {
expected := int((10 * time.Second) / (50 * time.Millisecond))
require.GreaterOrEqual(t, len(receivedTokens), expected*9/10)
require.LessOrEqual(t, len(receivedTokens), expected*11/10)
//require.Equal(t, "some expected logs", log.String())
}()
ctx, cancel := context.WithCancel(context.Background())
time.AfterFunc(10*time.Second, cancel)
go tokenClient.Start(ctx)
<-ctx.Done()
}