switch from Update() to Patch()

Signed-off-by: Steve Kriss <steve@heptio.com>
This commit is contained in:
Steve Kriss
2017-12-11 14:10:52 -08:00
parent 6d5eeb21f5
commit 4aea9b9a2c
9 changed files with 547 additions and 283 deletions

View File

@@ -19,6 +19,7 @@ package controller
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@@ -29,8 +30,10 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/clock" "k8s.io/apimachinery/pkg/util/clock"
kuberrs "k8s.io/apimachinery/pkg/util/errors" kuberrs "k8s.io/apimachinery/pkg/util/errors"
"k8s.io/apimachinery/pkg/util/strategicpatch"
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/tools/cache" "k8s.io/client-go/tools/cache"
"k8s.io/client-go/util/workqueue" "k8s.io/client-go/util/workqueue"
@@ -60,7 +63,7 @@ type backupController struct {
syncHandler func(backupName string) error syncHandler func(backupName string) error
queue workqueue.RateLimitingInterface queue workqueue.RateLimitingInterface
clock clock.Clock clock clock.Clock
logger *logrus.Logger logger logrus.FieldLogger
pluginManager plugin.Manager pluginManager plugin.Manager
} }
@@ -71,7 +74,7 @@ func NewBackupController(
backupService cloudprovider.BackupService, backupService cloudprovider.BackupService,
bucket string, bucket string,
pvProviderExists bool, pvProviderExists bool,
logger *logrus.Logger, logger logrus.FieldLogger,
pluginManager plugin.Manager, pluginManager plugin.Manager,
) Interface { ) Interface {
c := &backupController{ c := &backupController{
@@ -223,6 +226,8 @@ func (controller *backupController) processBackup(key string) error {
} }
logContext.Debug("Cloning backup") logContext.Debug("Cloning backup")
// store ref to original for creating patch
original := backup
// don't modify items in the cache // don't modify items in the cache
backup = backup.DeepCopy() backup = backup.DeepCopy()
@@ -242,11 +247,13 @@ func (controller *backupController) processBackup(key string) error {
} }
// update status // update status
updatedBackup, err := controller.client.Backups(ns).Update(backup) updatedBackup, err := patchBackup(original, backup, controller.client)
if err != nil { if err != nil {
return errors.Wrapf(err, "error updating Backup status to %s", backup.Status.Phase) return errors.Wrapf(err, "error updating Backup status to %s", backup.Status.Phase)
} }
backup = updatedBackup // store ref to just-updated item for creating patch
original = updatedBackup
backup = updatedBackup.DeepCopy()
if backup.Status.Phase == api.BackupPhaseFailedValidation { if backup.Status.Phase == api.BackupPhaseFailedValidation {
return nil return nil
@@ -260,13 +267,37 @@ func (controller *backupController) processBackup(key string) error {
} }
logContext.Debug("Updating backup's final status") logContext.Debug("Updating backup's final status")
if _, err = controller.client.Backups(ns).Update(backup); err != nil { if _, err := patchBackup(original, backup, controller.client); err != nil {
logContext.WithError(err).Error("error updating backup's final status") logContext.WithError(err).Error("error updating backup's final status")
} }
return nil return nil
} }
func patchBackup(original, updated *api.Backup, client arkv1client.BackupsGetter) (*api.Backup, error) {
origBytes, err := json.Marshal(original)
if err != nil {
return nil, errors.Wrap(err, "error marshalling original backup")
}
updatedBytes, err := json.Marshal(updated)
if err != nil {
return nil, errors.Wrap(err, "error marshalling updated backup")
}
patchBytes, err := strategicpatch.CreateTwoWayMergePatch(origBytes, updatedBytes, api.Backup{})
if err != nil {
return nil, errors.Wrap(err, "error creating two-way merge patch for backup")
}
res, err := client.Backups(api.DefaultNamespace).Patch(original.Name, types.MergePatchType, patchBytes)
if err != nil {
return nil, errors.Wrap(err, "error patching backup")
}
return res, nil
}
func (controller *backupController) getValidationErrors(itm *api.Backup) []string { func (controller *backupController) getValidationErrors(itm *api.Backup) []string {
var validationErrors []string var validationErrors []string

View File

@@ -17,6 +17,7 @@ limitations under the License.
package controller package controller
import ( import (
"encoding/json"
"io" "io"
"testing" "testing"
"time" "time"
@@ -25,7 +26,6 @@ import (
"k8s.io/apimachinery/pkg/util/clock" "k8s.io/apimachinery/pkg/util/clock"
core "k8s.io/client-go/testing" core "k8s.io/client-go/testing"
testlogger "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -34,10 +34,10 @@ import (
"github.com/heptio/ark/pkg/backup" "github.com/heptio/ark/pkg/backup"
"github.com/heptio/ark/pkg/cloudprovider" "github.com/heptio/ark/pkg/cloudprovider"
"github.com/heptio/ark/pkg/generated/clientset/versioned/fake" "github.com/heptio/ark/pkg/generated/clientset/versioned/fake"
"github.com/heptio/ark/pkg/generated/clientset/versioned/scheme"
informers "github.com/heptio/ark/pkg/generated/informers/externalversions" informers "github.com/heptio/ark/pkg/generated/informers/externalversions"
"github.com/heptio/ark/pkg/restore" "github.com/heptio/ark/pkg/restore"
. "github.com/heptio/ark/pkg/util/test" "github.com/heptio/ark/pkg/util/collections"
arktest "github.com/heptio/ark/pkg/util/test"
) )
type fakeBackupper struct { type fakeBackupper struct {
@@ -56,7 +56,7 @@ func TestProcessBackup(t *testing.T) {
expectError bool expectError bool
expectedIncludes []string expectedIncludes []string
expectedExcludes []string expectedExcludes []string
backup *TestBackup backup *arktest.TestBackup
expectBackup bool expectBackup bool
allowSnapshots bool allowSnapshots bool
}{ }{
@@ -73,49 +73,49 @@ func TestProcessBackup(t *testing.T) {
{ {
name: "do not process phase FailedValidation", name: "do not process phase FailedValidation",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseFailedValidation), backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseFailedValidation),
expectBackup: false, expectBackup: false,
}, },
{ {
name: "do not process phase InProgress", name: "do not process phase InProgress",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseInProgress), backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseInProgress),
expectBackup: false, expectBackup: false,
}, },
{ {
name: "do not process phase Completed", name: "do not process phase Completed",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseCompleted), backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseCompleted),
expectBackup: false, expectBackup: false,
}, },
{ {
name: "do not process phase Failed", name: "do not process phase Failed",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseFailed), backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseFailed),
expectBackup: false, expectBackup: false,
}, },
{ {
name: "do not process phase other", name: "do not process phase other",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase("arg"), backup: arktest.NewTestBackup().WithName("backup1").WithPhase("arg"),
expectBackup: false, expectBackup: false,
}, },
{ {
name: "invalid included/excluded resources fails validation", name: "invalid included/excluded resources fails validation",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedResources("foo").WithExcludedResources("foo"), backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedResources("foo").WithExcludedResources("foo"),
expectBackup: false, expectBackup: false,
}, },
{ {
name: "invalid included/excluded namespaces fails validation", name: "invalid included/excluded namespaces fails validation",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedNamespaces("foo").WithExcludedNamespaces("foo"), backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedNamespaces("foo").WithExcludedNamespaces("foo"),
expectBackup: false, expectBackup: false,
}, },
{ {
name: "make sure specified included and excluded resources are honored", name: "make sure specified included and excluded resources are honored",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedResources("i", "j").WithExcludedResources("k", "l"), backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedResources("i", "j").WithExcludedResources("k", "l"),
expectedIncludes: []string{"i", "j"}, expectedIncludes: []string{"i", "j"},
expectedExcludes: []string{"k", "l"}, expectedExcludes: []string{"k", "l"},
expectBackup: true, expectBackup: true,
@@ -123,25 +123,25 @@ func TestProcessBackup(t *testing.T) {
{ {
name: "if includednamespaces are specified, don't default to *", name: "if includednamespaces are specified, don't default to *",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedNamespaces("ns-1"), backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedNamespaces("ns-1"),
expectBackup: true, expectBackup: true,
}, },
{ {
name: "ttl", name: "ttl",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithTTL(10 * time.Minute), backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithTTL(10 * time.Minute),
expectBackup: true, expectBackup: true,
}, },
{ {
name: "backup with SnapshotVolumes when allowSnapshots=false fails validation", name: "backup with SnapshotVolumes when allowSnapshots=false fails validation",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithSnapshotVolumes(true), backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithSnapshotVolumes(true),
expectBackup: false, expectBackup: false,
}, },
{ {
name: "backup with SnapshotVolumes when allowSnapshots=true gets executed", name: "backup with SnapshotVolumes when allowSnapshots=true gets executed",
key: "heptio-ark/backup1", key: "heptio-ark/backup1",
backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithSnapshotVolumes(true), backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithSnapshotVolumes(true),
allowSnapshots: true, allowSnapshots: true,
expectBackup: true, expectBackup: true,
}, },
@@ -152,9 +152,9 @@ func TestProcessBackup(t *testing.T) {
var ( var (
client = fake.NewSimpleClientset() client = fake.NewSimpleClientset()
backupper = &fakeBackupper{} backupper = &fakeBackupper{}
cloudBackups = &BackupService{} cloudBackups = &arktest.BackupService{}
sharedInformers = informers.NewSharedInformerFactory(client, 0) sharedInformers = informers.NewSharedInformerFactory(client, 0)
logger, _ = testlogger.NewNullLogger() logger = arktest.NewLogger()
pluginManager = &MockManager{} pluginManager = &MockManager{}
) )
@@ -182,9 +182,7 @@ func TestProcessBackup(t *testing.T) {
} }
// set up a Backup object to represent what we expect to be passed to backupper.Backup() // set up a Backup object to represent what we expect to be passed to backupper.Backup()
copy, err := scheme.Scheme.Copy(test.backup.Backup) backup := test.backup.DeepCopy()
assert.NoError(t, err, "copy error")
backup := copy.(*v1.Backup)
backup.Spec.IncludedResources = test.expectedIncludes backup.Spec.IncludedResources = test.expectedIncludes
backup.Spec.ExcludedResources = test.expectedExcludes backup.Spec.ExcludedResources = test.expectedExcludes
backup.Spec.IncludedNamespaces = test.backup.Spec.IncludedNamespaces backup.Spec.IncludedNamespaces = test.backup.Spec.IncludedNamespaces
@@ -200,16 +198,35 @@ func TestProcessBackup(t *testing.T) {
pluginManager.On("CloseBackupItemActions", backup.Name).Return(nil) pluginManager.On("CloseBackupItemActions", backup.Name).Return(nil)
} }
// this is necessary so the Update() call returns the appropriate object // this is necessary so the Patch() call returns the appropriate object
client.PrependReactor("update", "backups", func(action core.Action) (bool, runtime.Object, error) { client.PrependReactor("patch", "backups", func(action core.Action) (bool, runtime.Object, error) {
obj := action.(core.UpdateAction).GetObject() if test.backup == nil {
// need to deep copy so we can test the backup state for each call to update return true, nil, nil
copy, err := scheme.Scheme.DeepCopy(obj) }
if err != nil {
patch := action.(core.PatchAction).GetPatch()
patchMap := make(map[string]interface{})
if err := json.Unmarshal(patch, &patchMap); err != nil {
t.Logf("error unmarshalling patch: %s\n", err)
return false, nil, err return false, nil, err
} }
ret := copy.(runtime.Object)
return true, ret, nil phase, err := collections.GetString(patchMap, "status.phase")
if err != nil {
t.Logf("error getting status.phase: %s\n", err)
return false, nil, err
}
res := test.backup.DeepCopy()
// these are the fields that we expect to be set by
// the controller
res.Status.Version = 1
res.Status.Expiration.Time = expiration
res.Status.Phase = v1.BackupPhase(phase)
return true, res, nil
}) })
// method under test // method under test
@@ -227,41 +244,41 @@ func TestProcessBackup(t *testing.T) {
return return
} }
expectedActions := []core.Action{ actions := client.Actions()
core.NewUpdateAction( require.Equal(t, 2, len(actions))
v1.SchemeGroupVersion.WithResource("backups"),
v1.DefaultNamespace,
NewTestBackup().
WithName(test.backup.Name).
WithPhase(v1.BackupPhaseInProgress).
WithIncludedResources(test.expectedIncludes...).
WithExcludedResources(test.expectedExcludes...).
WithIncludedNamespaces(test.backup.Spec.IncludedNamespaces...).
WithTTL(test.backup.Spec.TTL.Duration).
WithSnapshotVolumesPointer(test.backup.Spec.SnapshotVolumes).
WithExpiration(expiration).
WithVersion(1).
Backup,
),
core.NewUpdateAction( // validate Patch call 1 (setting version, expiration, and phase)
v1.SchemeGroupVersion.WithResource("backups"), patchAction, ok := actions[0].(core.PatchAction)
v1.DefaultNamespace, require.True(t, ok, "action is not a PatchAction")
NewTestBackup().
WithName(test.backup.Name). patch := make(map[string]interface{})
WithPhase(v1.BackupPhaseCompleted). require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch")
WithIncludedResources(test.expectedIncludes...).
WithExcludedResources(test.expectedExcludes...). assert.Equal(t, 1, len(patch), "patch has wrong number of keys")
WithIncludedNamespaces(test.backup.Spec.IncludedNamespaces...).
WithTTL(test.backup.Spec.TTL.Duration). expectedStatusKeys := 2
WithSnapshotVolumesPointer(test.backup.Spec.SnapshotVolumes). if test.backup.Spec.TTL.Duration > 0 {
WithExpiration(expiration). assert.True(t, collections.HasKeyAndVal(patch, "status.expiration", expiration.UTC().Format(time.RFC3339)), "patch's status.expiration does not match")
WithVersion(1). expectedStatusKeys = 3
Backup,
),
} }
assert.Equal(t, expectedActions, client.Actions()) assert.True(t, collections.HasKeyAndVal(patch, "status.version", float64(1)))
assert.True(t, collections.HasKeyAndVal(patch, "status.phase", string(v1.BackupPhaseInProgress)), "patch's status.phase does not match")
res, _ := collections.GetMap(patch, "status")
assert.Equal(t, expectedStatusKeys, len(res), "patch's status has the wrong number of keys")
// validate Patch call 2 (setting phase)
patchAction, ok = actions[1].(core.PatchAction)
require.True(t, ok, "action is not a PatchAction")
require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch")
assert.Equal(t, 1, len(patch), "patch has wrong number of keys")
res, _ = collections.GetMap(patch, "status")
assert.Equal(t, 1, len(res), "patch's status has the wrong number of keys")
assert.True(t, collections.HasKeyAndVal(patch, "status.phase", string(v1.BackupPhaseCompleted)), "patch's status.phase does not match")
}) })
} }
} }

View File

@@ -18,6 +18,7 @@ package controller
import ( import (
"context" "context"
"encoding/json"
"sync" "sync"
"time" "time"
@@ -27,7 +28,9 @@ import (
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/clock" "k8s.io/apimachinery/pkg/util/clock"
"k8s.io/apimachinery/pkg/util/strategicpatch"
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/tools/cache" "k8s.io/client-go/tools/cache"
"k8s.io/client-go/util/workqueue" "k8s.io/client-go/util/workqueue"
@@ -220,7 +223,7 @@ func (c *downloadRequestController) generatePreSignedURL(downloadRequest *v1.Dow
update.Status.Phase = v1.DownloadRequestPhaseProcessed update.Status.Phase = v1.DownloadRequestPhaseProcessed
update.Status.Expiration = metav1.NewTime(c.clock.Now().Add(signedURLTTL)) update.Status.Expiration = metav1.NewTime(c.clock.Now().Add(signedURLTTL))
_, err = c.downloadRequestClient.DownloadRequests(update.Namespace).Update(update) _, err = patchDownloadRequest(downloadRequest, update, c.downloadRequestClient)
return errors.WithStack(err) return errors.WithStack(err)
} }
@@ -256,3 +259,27 @@ func (c *downloadRequestController) resync() {
c.queue.Add(key) c.queue.Add(key)
} }
} }
func patchDownloadRequest(original, updated *v1.DownloadRequest, client arkv1client.DownloadRequestsGetter) (*v1.DownloadRequest, error) {
origBytes, err := json.Marshal(original)
if err != nil {
return nil, errors.Wrap(err, "error marshalling original download request")
}
updatedBytes, err := json.Marshal(updated)
if err != nil {
return nil, errors.Wrap(err, "error marshalling updated download request")
}
patchBytes, err := strategicpatch.CreateTwoWayMergePatch(origBytes, updatedBytes, v1.DownloadRequest{})
if err != nil {
return nil, errors.Wrap(err, "error creating two-way merge patch for download request")
}
res, err := client.DownloadRequests(v1.DefaultNamespace).Patch(original.Name, types.MergePatchType, patchBytes)
if err != nil {
return nil, errors.Wrap(err, "error patching download request")
}
return res, nil
}

View File

@@ -17,6 +17,7 @@ limitations under the License.
package controller package controller
import ( import (
"encoding/json"
"testing" "testing"
"time" "time"
@@ -25,12 +26,12 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
core "k8s.io/client-go/testing" core "k8s.io/client-go/testing"
"github.com/heptio/ark/pkg/apis/ark/v1" "github.com/heptio/ark/pkg/apis/ark/v1"
"github.com/heptio/ark/pkg/generated/clientset/versioned/fake" "github.com/heptio/ark/pkg/generated/clientset/versioned/fake"
informers "github.com/heptio/ark/pkg/generated/informers/externalversions" informers "github.com/heptio/ark/pkg/generated/informers/externalversions"
"github.com/heptio/ark/pkg/util/collections"
"github.com/heptio/ark/pkg/util/test" "github.com/heptio/ark/pkg/util/test"
) )
@@ -111,37 +112,28 @@ func TestProcessDownloadRequest(t *testing.T) {
logger, logger,
).(*downloadRequestController) ).(*downloadRequestController)
var downloadRequest *v1.DownloadRequest
if tc.expectedPhase == v1.DownloadRequestPhaseProcessed { if tc.expectedPhase == v1.DownloadRequestPhaseProcessed {
target := v1.DownloadTarget{ target := v1.DownloadTarget{
Kind: tc.targetKind, Kind: tc.targetKind,
Name: tc.targetName, Name: tc.targetName,
} }
downloadRequestsInformer.Informer().GetStore().Add( downloadRequest = &v1.DownloadRequest{
&v1.DownloadRequest{ ObjectMeta: metav1.ObjectMeta{
ObjectMeta: metav1.ObjectMeta{ Namespace: v1.DefaultNamespace,
Namespace: v1.DefaultNamespace, Name: "dr1",
Name: "dr1",
},
Spec: v1.DownloadRequestSpec{
Target: target,
},
}, },
) Spec: v1.DownloadRequestSpec{
Target: target,
},
}
downloadRequestsInformer.Informer().GetStore().Add(downloadRequest)
backupService.On("CreateSignedURL", target, "bucket", 10*time.Minute).Return("signedURL", nil) backupService.On("CreateSignedURL", target, "bucket", 10*time.Minute).Return("signedURL", nil)
} }
var updatedRequest *v1.DownloadRequest
client.PrependReactor("update", "downloadrequests", func(action core.Action) (bool, runtime.Object, error) {
obj := action.(core.UpdateAction).GetObject()
r, ok := obj.(*v1.DownloadRequest)
require.True(t, ok)
updatedRequest = r
return true, obj, nil
})
// method under test // method under test
err := c.processDownloadRequest(tc.key) err := c.processDownloadRequest(tc.key)
@@ -152,16 +144,37 @@ func TestProcessDownloadRequest(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
var ( actions := client.Actions()
updatedPhase v1.DownloadRequestPhase
updatedURL string // if we don't expect a phase update, this means
) // we don't expect any actions to take place
if updatedRequest != nil { if tc.expectedPhase == "" {
updatedPhase = updatedRequest.Status.Phase require.Equal(t, 0, len(actions))
updatedURL = updatedRequest.Status.DownloadURL return
} }
assert.Equal(t, tc.expectedPhase, updatedPhase)
assert.Equal(t, tc.expectedURL, updatedURL) // otherwise, we should get exactly 1 patch
require.Equal(t, 1, len(actions))
patchAction, ok := actions[0].(core.PatchAction)
require.True(t, ok, "action is not a PatchAction")
patch := make(map[string]interface{})
require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch")
// check the URL
assert.True(t, collections.HasKeyAndVal(patch, "status.downloadURL", tc.expectedURL), "patch's status.downloadURL does not match")
// check the Phase
assert.True(t, collections.HasKeyAndVal(patch, "status.phase", string(tc.expectedPhase)), "patch's status.phase does not match")
// check that Expiration exists
// TODO pass a fake clock to the controller and verify
// the expiration value
assert.True(t, collections.Exists(patch, "status.expiration"), "patch's status.expiration does not exist")
// we expect 3 total updates.
res, _ := collections.GetMap(patch, "status")
assert.Equal(t, 3, len(res), "patch's status has the wrong number of keys")
}) })
} }
} }

View File

@@ -31,7 +31,9 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/strategicpatch"
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/tools/cache" "k8s.io/client-go/tools/cache"
"k8s.io/client-go/util/workqueue" "k8s.io/client-go/util/workqueue"
@@ -230,6 +232,8 @@ func (controller *restoreController) processRestore(key string) error {
} }
logContext.Debug("Cloning Restore") logContext.Debug("Cloning Restore")
// store ref to original for creating patch
original := restore
// don't modify items in the cache // don't modify items in the cache
restore = restore.DeepCopy() restore = restore.DeepCopy()
@@ -248,11 +252,13 @@ func (controller *restoreController) processRestore(key string) error {
} }
// update status // update status
updatedRestore, err := controller.restoreClient.Restores(ns).Update(restore) updatedRestore, err := patchRestore(original, restore, controller.restoreClient)
if err != nil { if err != nil {
return errors.Wrapf(err, "error updating Restore phase to %s", restore.Status.Phase) return errors.Wrapf(err, "error updating Restore phase to %s", restore.Status.Phase)
} }
restore = updatedRestore // store ref to just-updated item for creating patch
original = updatedRestore
restore = updatedRestore.DeepCopy()
if restore.Status.Phase == api.RestorePhaseFailedValidation { if restore.Status.Phase == api.RestorePhaseFailedValidation {
return nil return nil
@@ -276,7 +282,7 @@ func (controller *restoreController) processRestore(key string) error {
restore.Status.Phase = api.RestorePhaseCompleted restore.Status.Phase = api.RestorePhaseCompleted
logContext.Debug("Updating Restore final status") logContext.Debug("Updating Restore final status")
if _, err = controller.restoreClient.Restores(ns).Update(restore); err != nil { if _, err = patchRestore(original, restore, controller.restoreClient); err != nil {
logContext.WithError(errors.WithStack(err)).Info("Error updating Restore final status") logContext.WithError(errors.WithStack(err)).Info("Error updating Restore final status")
} }
@@ -472,3 +478,27 @@ func downloadToTempFile(backupName string, backupService cloudprovider.BackupSer
return file, nil return file, nil
} }
func patchRestore(original, updated *api.Restore, client arkv1client.RestoresGetter) (*api.Restore, error) {
origBytes, err := json.Marshal(original)
if err != nil {
return nil, errors.Wrap(err, "error marshalling original restore")
}
updatedBytes, err := json.Marshal(updated)
if err != nil {
return nil, errors.Wrap(err, "error marshalling updated restore")
}
patchBytes, err := strategicpatch.CreateTwoWayMergePatch(origBytes, updatedBytes, api.Restore{})
if err != nil {
return nil, errors.Wrap(err, "error creating two-way merge patch for restore")
}
res, err := client.Restores(api.DefaultNamespace).Patch(original.Name, types.MergePatchType, patchBytes)
if err != nil {
return nil, errors.Wrap(err, "error patching restore")
}
return res, nil
}

View File

@@ -18,6 +18,7 @@ package controller
import ( import (
"bytes" "bytes"
"encoding/json"
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
@@ -25,9 +26,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes/scheme"
core "k8s.io/client-go/testing" core "k8s.io/client-go/testing"
"k8s.io/client-go/tools/cache" "k8s.io/client-go/tools/cache"
@@ -35,6 +36,7 @@ import (
"github.com/heptio/ark/pkg/generated/clientset/versioned/fake" "github.com/heptio/ark/pkg/generated/clientset/versioned/fake"
informers "github.com/heptio/ark/pkg/generated/informers/externalversions" informers "github.com/heptio/ark/pkg/generated/informers/externalversions"
"github.com/heptio/ark/pkg/restore" "github.com/heptio/ark/pkg/restore"
"github.com/heptio/ark/pkg/util/collections"
arktest "github.com/heptio/ark/pkg/util/test" arktest "github.com/heptio/ark/pkg/util/test"
) )
@@ -120,7 +122,9 @@ func TestProcessRestore(t *testing.T) {
restorerError error restorerError error
allowRestoreSnapshots bool allowRestoreSnapshots bool
expectedErr bool expectedErr bool
expectedRestoreUpdates []*api.Restore expectedPhase string
expectedValidationErrors []string
expectedRestoreErrors int
expectedRestorerCall *api.Restore expectedRestorerCall *api.Restore
backupServiceGetBackupError error backupServiceGetBackupError error
uploadLogError error uploadLogError error
@@ -151,73 +155,53 @@ func TestProcessRestore(t *testing.T) {
expectedErr: false, expectedErr: false,
}, },
{ {
name: "restore with both namespace in both includedNamespaces and excludedNamespaces fails validation", name: "restore with both namespace in both includedNamespaces and excludedNamespaces fails validation",
restore: NewRestore("foo", "bar", "backup-1", "another-1", "*", api.RestorePhaseNew).WithExcludedNamespace("another-1").Restore, restore: NewRestore("foo", "bar", "backup-1", "another-1", "*", api.RestorePhaseNew).WithExcludedNamespace("another-1").Restore,
backup: arktest.NewTestBackup().WithName("backup-1").Backup, backup: arktest.NewTestBackup().WithName("backup-1").Backup,
expectedErr: false, expectedErr: false,
expectedRestoreUpdates: []*api.Restore{ expectedPhase: string(api.RestorePhaseFailedValidation),
NewRestore("foo", "bar", "backup-1", "another-1", "*", api.RestorePhaseFailedValidation).WithExcludedNamespace("another-1"). expectedValidationErrors: []string{"Invalid included/excluded namespace lists: excludes list cannot contain an item in the includes list: another-1"},
WithValidationError("Invalid included/excluded namespace lists: excludes list cannot contain an item in the includes list: another-1").
Restore,
},
}, },
{ {
name: "restore with resource in both includedResources and excludedResources fails validation", name: "restore with resource in both includedResources and excludedResources fails validation",
restore: NewRestore("foo", "bar", "backup-1", "*", "a-resource", api.RestorePhaseNew).WithExcludedResource("a-resource").Restore, restore: NewRestore("foo", "bar", "backup-1", "*", "a-resource", api.RestorePhaseNew).WithExcludedResource("a-resource").Restore,
backup: arktest.NewTestBackup().WithName("backup-1").Backup, backup: arktest.NewTestBackup().WithName("backup-1").Backup,
expectedErr: false, expectedErr: false,
expectedRestoreUpdates: []*api.Restore{ expectedPhase: string(api.RestorePhaseFailedValidation),
NewRestore("foo", "bar", "backup-1", "*", "a-resource", api.RestorePhaseFailedValidation).WithExcludedResource("a-resource"). expectedValidationErrors: []string{"Invalid included/excluded resource lists: excludes list cannot contain an item in the includes list: a-resource"},
WithValidationError("Invalid included/excluded resource lists: excludes list cannot contain an item in the includes list: a-resource").
Restore,
},
}, },
{ {
name: "new restore with empty backup name fails validation", name: "new restore with empty backup name fails validation",
restore: NewRestore("foo", "bar", "", "ns-1", "", api.RestorePhaseNew).Restore, restore: NewRestore("foo", "bar", "", "ns-1", "", api.RestorePhaseNew).Restore,
expectedErr: false, expectedErr: false,
expectedRestoreUpdates: []*api.Restore{ expectedPhase: string(api.RestorePhaseFailedValidation),
NewRestore("foo", "bar", "", "ns-1", "", api.RestorePhaseFailedValidation). expectedValidationErrors: []string{"BackupName must be non-empty and correspond to the name of a backup in object storage."},
WithValidationError("BackupName must be non-empty and correspond to the name of a backup in object storage.").
Restore,
},
}, },
{ {
name: "restore with non-existent backup name fails", name: "restore with non-existent backup name fails",
restore: arktest.NewTestRestore("foo", "bar", api.RestorePhaseNew).WithBackup("backup-1").WithIncludedNamespace("ns-1").Restore, restore: arktest.NewTestRestore("foo", "bar", api.RestorePhaseNew).WithBackup("backup-1").WithIncludedNamespace("ns-1").Restore,
expectedErr: false, expectedErr: false,
expectedRestoreUpdates: []*api.Restore{ expectedPhase: string(api.RestorePhaseInProgress),
NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore, expectedRestoreErrors: 1,
NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseCompleted).
WithErrors(1).
Restore,
},
backupServiceGetBackupError: errors.New("no backup here"), backupServiceGetBackupError: errors.New("no backup here"),
}, },
{ {
name: "restorer throwing an error causes the restore to fail", name: "restorer throwing an error causes the restore to fail",
restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).Restore, restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).Restore,
backup: arktest.NewTestBackup().WithName("backup-1").Backup, backup: arktest.NewTestBackup().WithName("backup-1").Backup,
restorerError: errors.New("blarg"), restorerError: errors.New("blarg"),
expectedErr: false, expectedErr: false,
expectedRestoreUpdates: []*api.Restore{ expectedPhase: string(api.RestorePhaseInProgress),
NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore, expectedRestoreErrors: 1,
NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseCompleted). expectedRestorerCall: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore,
WithErrors(1).
Restore,
},
expectedRestorerCall: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore,
}, },
{ {
name: "valid restore gets executed", name: "valid restore gets executed",
restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).Restore, restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).Restore,
backup: arktest.NewTestBackup().WithName("backup-1").Backup, backup: arktest.NewTestBackup().WithName("backup-1").Backup,
expectedErr: false, expectedErr: false,
expectedRestoreUpdates: []*api.Restore{ expectedPhase: string(api.RestorePhaseInProgress),
NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore,
NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseCompleted).Restore,
},
expectedRestorerCall: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore, expectedRestorerCall: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore,
}, },
{ {
@@ -226,34 +210,26 @@ func TestProcessRestore(t *testing.T) {
backup: arktest.NewTestBackup().WithName("backup-1").Backup, backup: arktest.NewTestBackup().WithName("backup-1").Backup,
allowRestoreSnapshots: true, allowRestoreSnapshots: true,
expectedErr: false, expectedErr: false,
expectedRestoreUpdates: []*api.Restore{ expectedPhase: string(api.RestorePhaseInProgress),
NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).WithRestorePVs(true).Restore, expectedRestorerCall: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).WithRestorePVs(true).Restore,
NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseCompleted).WithRestorePVs(true).Restore,
},
expectedRestorerCall: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).WithRestorePVs(true).Restore,
}, },
{ {
name: "restore with RestorePVs=true fails validation when allowRestoreSnapshots=false", name: "restore with RestorePVs=true fails validation when allowRestoreSnapshots=false",
restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).WithRestorePVs(true).Restore, restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).WithRestorePVs(true).Restore,
backup: arktest.NewTestBackup().WithName("backup-1").Backup, backup: arktest.NewTestBackup().WithName("backup-1").Backup,
expectedErr: false, expectedErr: false,
expectedRestoreUpdates: []*api.Restore{ expectedPhase: string(api.RestorePhaseFailedValidation),
NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseFailedValidation). expectedValidationErrors: []string{"Server is not configured for PV snapshot restores"},
WithRestorePVs(true).
WithValidationError("Server is not configured for PV snapshot restores").
Restore,
},
}, },
{ {
name: "restoration of nodes is not supported", name: "restoration of nodes is not supported",
restore: NewRestore("foo", "bar", "backup-1", "ns-1", "nodes", api.RestorePhaseNew).Restore, restore: NewRestore("foo", "bar", "backup-1", "ns-1", "nodes", api.RestorePhaseNew).Restore,
backup: arktest.NewTestBackup().WithName("backup-1").Backup, backup: arktest.NewTestBackup().WithName("backup-1").Backup,
expectedErr: false, expectedErr: false,
expectedRestoreUpdates: []*api.Restore{ expectedPhase: string(api.RestorePhaseFailedValidation),
NewRestore("foo", "bar", "backup-1", "ns-1", "nodes", api.RestorePhaseFailedValidation). expectedValidationErrors: []string{
WithValidationError("nodes are a non-restorable resource"). "nodes are a non-restorable resource",
WithValidationError("Invalid included/excluded resource lists: excludes list cannot contain an item in the includes list: nodes"). "Invalid included/excluded resource lists: excludes list cannot contain an item in the includes list: nodes",
Restore,
}, },
}, },
} }
@@ -288,16 +264,34 @@ func TestProcessRestore(t *testing.T) {
if test.restore != nil { if test.restore != nil {
sharedInformers.Ark().V1().Restores().Informer().GetStore().Add(test.restore) sharedInformers.Ark().V1().Restores().Informer().GetStore().Add(test.restore)
// this is necessary so the Update() call returns the appropriate object // this is necessary so the Patch() call returns the appropriate object
client.PrependReactor("update", "restores", func(action core.Action) (bool, runtime.Object, error) { client.PrependReactor("patch", "restores", func(action core.Action) (bool, runtime.Object, error) {
obj := action.(core.UpdateAction).GetObject() if test.restore == nil {
// need to deep copy so we can test the backup state for each call to update return true, nil, nil
copy, err := scheme.Scheme.DeepCopy(obj) }
if err != nil {
patch := action.(core.PatchAction).GetPatch()
patchMap := make(map[string]interface{})
if err := json.Unmarshal(patch, &patchMap); err != nil {
t.Logf("error unmarshalling patch: %s\n", err)
return false, nil, err return false, nil, err
} }
ret := copy.(runtime.Object)
return true, ret, nil phase, err := collections.GetString(patchMap, "status.phase")
if err != nil {
t.Logf("error getting status.phase: %s\n", err)
return false, nil, err
}
res := test.restore.DeepCopy()
// these are the fields that we expect to be set by
// the controller
res.Status.Phase = api.RestorePhase(phase)
return true, res, nil
}) })
} }
@@ -346,32 +340,75 @@ func TestProcessRestore(t *testing.T) {
assert.Equal(t, test.expectedErr, err != nil, "got error %v", err) assert.Equal(t, test.expectedErr, err != nil, "got error %v", err)
if test.expectedRestoreUpdates != nil { actions := client.Actions()
var expectedActions []core.Action
for _, upd := range test.expectedRestoreUpdates { if test.expectedPhase == "" {
action := core.NewUpdateAction( require.Equal(t, 0, len(actions), "len(actions) should be zero")
api.SchemeGroupVersion.WithResource("restores"), return
upd.Namespace,
upd)
expectedActions = append(expectedActions, action)
}
assert.Equal(t, expectedActions, client.Actions())
} }
// validate Patch call 1 (setting phase, validation errs)
require.True(t, len(actions) > 0, "len(actions) is too small")
patchAction, ok := actions[0].(core.PatchAction)
require.True(t, ok, "action is not a PatchAction")
patch := make(map[string]interface{})
require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch")
expectedStatusKeys := 1
assert.True(t, collections.HasKeyAndVal(patch, "status.phase", test.expectedPhase), "patch's status.phase does not match")
if len(test.expectedValidationErrors) > 0 {
errs, err := collections.GetSlice(patch, "status.validationErrors")
require.NoError(t, err, "error getting patch's status.validationErrors")
var errStrings []string
for _, err := range errs {
errStrings = append(errStrings, err.(string))
}
assert.Equal(t, test.expectedValidationErrors, errStrings, "patch's status.validationErrors does not match")
expectedStatusKeys++
}
res, _ := collections.GetMap(patch, "status")
assert.Equal(t, expectedStatusKeys, len(res), "patch's status has the wrong number of keys")
// if we don't expect a restore, validate it wasn't called and exit the test
if test.expectedRestorerCall == nil { if test.expectedRestorerCall == nil {
assert.Empty(t, restorer.Calls) assert.Empty(t, restorer.Calls)
assert.Zero(t, restorer.calledWithArg) assert.Zero(t, restorer.calledWithArg)
} else { return
assert.Equal(t, 1, len(restorer.Calls))
// explicitly capturing the argument passed to Restore myself because
// I want to validate the called arg as of the time of calling, but
// the mock stores the pointer, which gets modified after
assert.Equal(t, *test.expectedRestorerCall, restorer.calledWithArg)
} }
assert.Equal(t, 1, len(restorer.Calls))
// validate Patch call 2 (setting phase)
patchAction, ok = actions[1].(core.PatchAction)
require.True(t, ok, "action is not a PatchAction")
require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch")
assert.Equal(t, 1, len(patch), "patch has wrong number of keys")
res, _ = collections.GetMap(patch, "status")
expectedStatusKeys = 1
assert.True(t, collections.HasKeyAndVal(patch, "status.phase", string(api.RestorePhaseCompleted)), "patch's status.phase does not match")
if test.expectedRestoreErrors != 0 {
assert.True(t, collections.HasKeyAndVal(patch, "status.errors", float64(test.expectedRestoreErrors)), "patch's status.errors does not match")
expectedStatusKeys++
}
assert.Equal(t, expectedStatusKeys, len(res), "patch's status has wrong number of keys")
// explicitly capturing the argument passed to Restore myself because
// I want to validate the called arg as of the time of calling, but
// the mock stores the pointer, which gets modified after
assert.Equal(t, *test.expectedRestorerCall, restorer.calledWithArg)
}) })
} }
} }

View File

@@ -18,6 +18,7 @@ package controller
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"sync" "sync"
"time" "time"
@@ -29,7 +30,9 @@ import (
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/clock" "k8s.io/apimachinery/pkg/util/clock"
"k8s.io/apimachinery/pkg/util/strategicpatch"
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/tools/cache" "k8s.io/client-go/tools/cache"
"k8s.io/client-go/util/workqueue" "k8s.io/client-go/util/workqueue"
@@ -50,7 +53,7 @@ type scheduleController struct {
queue workqueue.RateLimitingInterface queue workqueue.RateLimitingInterface
syncPeriod time.Duration syncPeriod time.Duration
clock clock.Clock clock clock.Clock
logger *logrus.Logger logger logrus.FieldLogger
} }
func NewScheduleController( func NewScheduleController(
@@ -58,7 +61,7 @@ func NewScheduleController(
backupsClient arkv1client.BackupsGetter, backupsClient arkv1client.BackupsGetter,
schedulesInformer informers.ScheduleInformer, schedulesInformer informers.ScheduleInformer,
syncPeriod time.Duration, syncPeriod time.Duration,
logger *logrus.Logger, logger logrus.FieldLogger,
) *scheduleController { ) *scheduleController {
if syncPeriod < time.Minute { if syncPeriod < time.Minute {
logger.WithField("syncPeriod", syncPeriod).Info("Provided schedule sync period is too short. Setting to 1 minute") logger.WithField("syncPeriod", syncPeriod).Info("Provided schedule sync period is too short. Setting to 1 minute")
@@ -230,6 +233,8 @@ func (controller *scheduleController) processSchedule(key string) error {
} }
logContext.Debug("Cloning schedule") logContext.Debug("Cloning schedule")
// store ref to original for creating patch
original := schedule
// don't modify items in the cache // don't modify items in the cache
schedule = schedule.DeepCopy() schedule = schedule.DeepCopy()
@@ -247,7 +252,7 @@ func (controller *scheduleController) processSchedule(key string) error {
// update status if it's changed // update status if it's changed
if currentPhase != schedule.Status.Phase { if currentPhase != schedule.Status.Phase {
updatedSchedule, err := controller.schedulesClient.Schedules(ns).Update(schedule) updatedSchedule, err := patchSchedule(original, schedule, controller.schedulesClient)
if err != nil { if err != nil {
return errors.Wrapf(err, "error updating Schedule phase to %s", schedule.Status.Phase) return errors.Wrapf(err, "error updating Schedule phase to %s", schedule.Status.Phase)
} }
@@ -266,7 +271,7 @@ func (controller *scheduleController) processSchedule(key string) error {
return nil return nil
} }
func parseCronSchedule(itm *api.Schedule, logger *logrus.Logger) (cron.Schedule, []string) { func parseCronSchedule(itm *api.Schedule, logger logrus.FieldLogger) (cron.Schedule, []string) {
var validationErrors []string var validationErrors []string
var schedule cron.Schedule var schedule cron.Schedule
@@ -330,11 +335,12 @@ func (controller *scheduleController) submitBackupIfDue(item *api.Schedule, cron
return errors.Wrap(err, "error creating Backup") return errors.Wrap(err, "error creating Backup")
} }
original := item
schedule := item.DeepCopy() schedule := item.DeepCopy()
schedule.Status.LastBackup = metav1.NewTime(now) schedule.Status.LastBackup = metav1.NewTime(now)
if _, err := controller.schedulesClient.Schedules(schedule.Namespace).Update(schedule); err != nil { if _, err := patchSchedule(original, schedule, controller.schedulesClient); err != nil {
return errors.Wrapf(err, "error updating Schedule's LastBackup time to %v", schedule.Status.LastBackup) return errors.Wrapf(err, "error updating Schedule's LastBackup time to %v", schedule.Status.LastBackup)
} }
@@ -365,3 +371,27 @@ func getBackup(item *api.Schedule, timestamp time.Time) *api.Backup {
return backup return backup
} }
func patchSchedule(original, updated *api.Schedule, client arkv1client.SchedulesGetter) (*api.Schedule, error) {
origBytes, err := json.Marshal(original)
if err != nil {
return nil, errors.Wrap(err, "error marshalling original schedule")
}
updatedBytes, err := json.Marshal(updated)
if err != nil {
return nil, errors.Wrap(err, "error marshalling updated schedule")
}
patchBytes, err := strategicpatch.CreateTwoWayMergePatch(origBytes, updatedBytes, api.Schedule{})
if err != nil {
return nil, errors.Wrap(err, "error creating two-way merge patch for schedule")
}
res, err := client.Schedules(api.DefaultNamespace).Patch(original.Name, types.MergePatchType, patchBytes)
if err != nil {
return nil, errors.Wrap(err, "error patching schedule")
}
return res, nil
}

View File

@@ -17,6 +17,8 @@ limitations under the License.
package controller package controller
import ( import (
"encoding/json"
"fmt"
"testing" "testing"
"time" "time"
@@ -28,26 +30,27 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/clock" "k8s.io/apimachinery/pkg/util/clock"
"k8s.io/client-go/kubernetes/scheme"
core "k8s.io/client-go/testing" core "k8s.io/client-go/testing"
"k8s.io/client-go/tools/cache" "k8s.io/client-go/tools/cache"
api "github.com/heptio/ark/pkg/apis/ark/v1" api "github.com/heptio/ark/pkg/apis/ark/v1"
"github.com/heptio/ark/pkg/generated/clientset/versioned/fake" "github.com/heptio/ark/pkg/generated/clientset/versioned/fake"
informers "github.com/heptio/ark/pkg/generated/informers/externalversions" informers "github.com/heptio/ark/pkg/generated/informers/externalversions"
. "github.com/heptio/ark/pkg/util/test" "github.com/heptio/ark/pkg/util/collections"
arktest "github.com/heptio/ark/pkg/util/test"
) )
func TestProcessSchedule(t *testing.T) { func TestProcessSchedule(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
scheduleKey string scheduleKey string
schedule *api.Schedule schedule *api.Schedule
fakeClockTime string fakeClockTime string
expectedErr bool expectedErr bool
expectedSchedulePhaseUpdate *api.Schedule expectedPhase string
expectedScheduleLastBackupUpdate *api.Schedule expectedValidationError string
expectedBackupCreate *api.Backup expectedBackupCreate *api.Backup
expectedLastBackup string
}{ }{
{ {
name: "invalid key returns error", name: "invalid key returns error",
@@ -61,70 +64,64 @@ func TestProcessSchedule(t *testing.T) {
}, },
{ {
name: "schedule with phase FailedValidation does not get processed", name: "schedule with phase FailedValidation does not get processed",
schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseFailedValidation).Schedule, schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseFailedValidation).Schedule,
expectedErr: false, expectedErr: false,
}, },
{ {
name: "schedule with phase New gets validated and failed if invalid", name: "schedule with phase New gets validated and failed if invalid",
schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseNew).Schedule, schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseNew).Schedule,
expectedErr: false, expectedErr: false,
expectedSchedulePhaseUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseFailedValidation). expectedPhase: string(api.SchedulePhaseFailedValidation),
WithValidationError("Schedule must be a non-empty valid Cron expression").Schedule, expectedValidationError: "Schedule must be a non-empty valid Cron expression",
}, },
{ {
name: "schedule with phase <blank> gets validated and failed if invalid", name: "schedule with phase <blank> gets validated and failed if invalid",
schedule: NewTestSchedule("ns", "name").Schedule, schedule: arktest.NewTestSchedule("ns", "name").Schedule,
expectedErr: false, expectedErr: false,
expectedSchedulePhaseUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseFailedValidation). expectedPhase: string(api.SchedulePhaseFailedValidation),
WithValidationError("Schedule must be a non-empty valid Cron expression").Schedule, expectedValidationError: "Schedule must be a non-empty valid Cron expression",
}, },
{ {
name: "schedule with phase Enabled gets re-validated and failed if invalid", name: "schedule with phase Enabled gets re-validated and failed if invalid",
schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled).Schedule, schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled).Schedule,
expectedErr: false, expectedErr: false,
expectedSchedulePhaseUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseFailedValidation). expectedPhase: string(api.SchedulePhaseFailedValidation),
WithValidationError("Schedule must be a non-empty valid Cron expression").Schedule, expectedValidationError: "Schedule must be a non-empty valid Cron expression",
}, },
{ {
name: "schedule with phase New gets validated and triggers a backup", name: "schedule with phase New gets validated and triggers a backup",
schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseNew).WithCronSchedule("@every 5m").Schedule, schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseNew).WithCronSchedule("@every 5m").Schedule,
fakeClockTime: "2017-01-01 12:00:00", fakeClockTime: "2017-01-01 12:00:00",
expectedErr: false, expectedErr: false,
expectedSchedulePhaseUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled).WithCronSchedule("@every 5m").Schedule, expectedPhase: string(api.SchedulePhaseEnabled),
expectedBackupCreate: NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup, expectedBackupCreate: arktest.NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup,
expectedScheduleLastBackupUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled). expectedLastBackup: "2017-01-01 12:00:00",
WithCronSchedule("@every 5m").WithLastBackupTime("2017-01-01 12:00:00").Schedule,
}, },
{ {
name: "schedule with phase Enabled gets re-validated and triggers a backup if valid", name: "schedule with phase Enabled gets re-validated and triggers a backup if valid",
schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled).WithCronSchedule("@every 5m").Schedule, schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled).WithCronSchedule("@every 5m").Schedule,
fakeClockTime: "2017-01-01 12:00:00", fakeClockTime: "2017-01-01 12:00:00",
expectedErr: false, expectedErr: false,
expectedBackupCreate: NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup, expectedBackupCreate: arktest.NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup,
expectedScheduleLastBackupUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled). expectedLastBackup: "2017-01-01 12:00:00",
WithCronSchedule("@every 5m").WithLastBackupTime("2017-01-01 12:00:00").Schedule,
}, },
{ {
name: "schedule that's already run gets LastBackup updated", name: "schedule that's already run gets LastBackup updated",
schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled). schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled).
WithCronSchedule("@every 5m").WithLastBackupTime("2000-01-01 00:00:00").Schedule, WithCronSchedule("@every 5m").WithLastBackupTime("2000-01-01 00:00:00").Schedule,
fakeClockTime: "2017-01-01 12:00:00", fakeClockTime: "2017-01-01 12:00:00",
expectedErr: false, expectedErr: false,
expectedBackupCreate: NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup, expectedBackupCreate: arktest.NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup,
expectedScheduleLastBackupUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled). expectedLastBackup: "2017-01-01 12:00:00",
WithCronSchedule("@every 5m").WithLastBackupTime("2017-01-01 12:00:00").Schedule,
}, },
} }
// flag.Set("logtostderr", "true")
// flag.Set("v", "4")
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
var ( var (
client = fake.NewSimpleClientset() client = fake.NewSimpleClientset()
sharedInformers = informers.NewSharedInformerFactory(client, 0) sharedInformers = informers.NewSharedInformerFactory(client, 0)
logger, _ = testlogger.NewNullLogger() logger = arktest.NewLogger()
) )
c := NewScheduleController( c := NewScheduleController(
@@ -148,16 +145,36 @@ func TestProcessSchedule(t *testing.T) {
if test.schedule != nil { if test.schedule != nil {
sharedInformers.Ark().V1().Schedules().Informer().GetStore().Add(test.schedule) sharedInformers.Ark().V1().Schedules().Informer().GetStore().Add(test.schedule)
// this is necessary so the Update() call returns the appropriate object // this is necessary so the Patch() call returns the appropriate object
client.PrependReactor("update", "schedules", func(action core.Action) (bool, runtime.Object, error) { client.PrependReactor("patch", "schedules", func(action core.Action) (bool, runtime.Object, error) {
obj := action.(core.UpdateAction).GetObject() var (
// need to deep copy so we can test the schedule state for each call to update patch = action.(core.PatchAction).GetPatch()
copy, err := scheme.Scheme.DeepCopy(obj) patchMap = make(map[string]interface{})
if err != nil { res = test.schedule.DeepCopy()
)
if err := json.Unmarshal(patch, &patchMap); err != nil {
t.Logf("error unmarshalling patch: %s\n", err)
return false, nil, err return false, nil, err
} }
ret := copy.(runtime.Object)
return true, ret, nil // these are the fields that may be updated by the controller
phase, err := collections.GetString(patchMap, "status.phase")
if err == nil {
res.Status.Phase = api.SchedulePhase(phase)
}
lastBackupStr, err := collections.GetString(patchMap, "status.lastBackup")
if err == nil {
parsed, err := time.Parse(time.RFC3339, lastBackupStr)
if err != nil {
t.Logf("error parsing status.lastBackup: %s\n", err)
return false, nil, err
}
res.Status.LastBackup = metav1.Time{Time: parsed}
}
return true, res, nil
}) })
} }
@@ -171,37 +188,88 @@ func TestProcessSchedule(t *testing.T) {
assert.Equal(t, test.expectedErr, err != nil, "got error %v", err) assert.Equal(t, test.expectedErr, err != nil, "got error %v", err)
expectedActions := make([]core.Action, 0) actions := client.Actions()
index := 0
if upd := test.expectedSchedulePhaseUpdate; upd != nil { if test.expectedPhase != "" {
action := core.NewUpdateAction( require.True(t, len(actions) > index, "len(actions) is too small")
api.SchemeGroupVersion.WithResource("schedules"),
upd.Namespace, patchAction, ok := actions[index].(core.PatchAction)
upd) require.True(t, ok, "action is not a PatchAction")
expectedActions = append(expectedActions, action)
patch := make(map[string]interface{})
require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch")
assert.Equal(t, 1, len(patch), "patch has wrong number of keys")
expectedStatusKeys := 1
assert.True(t, collections.HasKeyAndVal(patch, "status.phase", test.expectedPhase), "patch's status.phase does not match")
if test.expectedValidationError != "" {
errs, err := collections.GetSlice(patch, "status.validationErrors")
require.NoError(t, err, "error getting patch's status.validationErrors")
require.Equal(t, 1, len(errs))
assert.Equal(t, test.expectedValidationError, errs[0].(string), "patch's status.validationErrors does not match")
expectedStatusKeys++
}
res, _ := collections.GetMap(patch, "status")
assert.Equal(t, expectedStatusKeys, len(res), "patch's status has the wrong number of keys")
index++
} }
if created := test.expectedBackupCreate; created != nil { if created := test.expectedBackupCreate; created != nil {
require.True(t, len(actions) > index, "len(actions) is too small")
action := core.NewCreateAction( action := core.NewCreateAction(
api.SchemeGroupVersion.WithResource("backups"), api.SchemeGroupVersion.WithResource("backups"),
created.Namespace, created.Namespace,
created) created)
expectedActions = append(expectedActions, action)
assert.Equal(t, action, actions[index])
index++
} }
if upd := test.expectedScheduleLastBackupUpdate; upd != nil { if test.expectedLastBackup != "" {
action := core.NewUpdateAction( require.True(t, len(actions) > index, "len(actions) is too small")
api.SchemeGroupVersion.WithResource("schedules"),
upd.Namespace,
upd)
expectedActions = append(expectedActions, action)
}
assert.Equal(t, expectedActions, client.Actions()) patchAction, ok := actions[index].(core.PatchAction)
require.True(t, ok, "action is not a PatchAction")
patch := make(map[string]interface{})
require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch")
assert.Equal(t, 1, len(patch), "patch has wrong number of keys")
lastBackup, _ := collections.GetValue(patch, "status.lastBackup")
fmt.Println(lastBackup)
assert.True(
t,
collections.HasKeyAndVal(patch, "status.lastBackup", parseTime(test.expectedLastBackup).UTC().Format(time.RFC3339)),
"patch's status.lastBackup does not match",
)
res, _ := collections.GetMap(patch, "status")
assert.Equal(t, 1, len(res), "patch's status has the wrong number of keys")
index++
}
}) })
} }
} }
func parseTime(timeString string) time.Time {
res, _ := time.Parse("2006-01-02 15:04:05", timeString)
return res
}
func TestGetNextRunTime(t *testing.T) { func TestGetNextRunTime(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -122,3 +122,14 @@ func Exists(root map[string]interface{}, path string) bool {
_, err := GetValue(root, path) _, err := GetValue(root, path)
return err == nil return err == nil
} }
// HasKeyAndVal returns true if root[path] exists and the value
// contained is equal to val, or false otherwise.
func HasKeyAndVal(root map[string]interface{}, path string, val interface{}) bool {
valObj, err := GetValue(root, path)
if err != nil {
return false
}
return valObj == val
}