diff --git a/pkg/controller/backup_controller.go b/pkg/controller/backup_controller.go index 4cc22507d..3ddc542c6 100644 --- a/pkg/controller/backup_controller.go +++ b/pkg/controller/backup_controller.go @@ -19,6 +19,7 @@ package controller import ( "bytes" "context" + "encoding/json" "fmt" "io/ioutil" "os" @@ -29,8 +30,10 @@ import ( "github.com/sirupsen/logrus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/clock" kuberrs "k8s.io/apimachinery/pkg/util/errors" + "k8s.io/apimachinery/pkg/util/strategicpatch" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/tools/cache" "k8s.io/client-go/util/workqueue" @@ -60,7 +63,7 @@ type backupController struct { syncHandler func(backupName string) error queue workqueue.RateLimitingInterface clock clock.Clock - logger *logrus.Logger + logger logrus.FieldLogger pluginManager plugin.Manager } @@ -71,7 +74,7 @@ func NewBackupController( backupService cloudprovider.BackupService, bucket string, pvProviderExists bool, - logger *logrus.Logger, + logger logrus.FieldLogger, pluginManager plugin.Manager, ) Interface { c := &backupController{ @@ -223,6 +226,8 @@ func (controller *backupController) processBackup(key string) error { } logContext.Debug("Cloning backup") + // store ref to original for creating patch + original := backup // don't modify items in the cache backup = backup.DeepCopy() @@ -242,11 +247,13 @@ func (controller *backupController) processBackup(key string) error { } // update status - updatedBackup, err := controller.client.Backups(ns).Update(backup) + updatedBackup, err := patchBackup(original, backup, controller.client) if err != nil { return errors.Wrapf(err, "error updating Backup status to %s", backup.Status.Phase) } - backup = updatedBackup + // store ref to just-updated item for creating patch + original = updatedBackup + backup = updatedBackup.DeepCopy() if backup.Status.Phase == api.BackupPhaseFailedValidation { return nil @@ -260,13 +267,37 @@ func (controller *backupController) processBackup(key string) error { } logContext.Debug("Updating backup's final status") - if _, err = controller.client.Backups(ns).Update(backup); err != nil { + if _, err := patchBackup(original, backup, controller.client); err != nil { logContext.WithError(err).Error("error updating backup's final status") } return nil } +func patchBackup(original, updated *api.Backup, client arkv1client.BackupsGetter) (*api.Backup, error) { + origBytes, err := json.Marshal(original) + if err != nil { + return nil, errors.Wrap(err, "error marshalling original backup") + } + + updatedBytes, err := json.Marshal(updated) + if err != nil { + return nil, errors.Wrap(err, "error marshalling updated backup") + } + + patchBytes, err := strategicpatch.CreateTwoWayMergePatch(origBytes, updatedBytes, api.Backup{}) + if err != nil { + return nil, errors.Wrap(err, "error creating two-way merge patch for backup") + } + + res, err := client.Backups(api.DefaultNamespace).Patch(original.Name, types.MergePatchType, patchBytes) + if err != nil { + return nil, errors.Wrap(err, "error patching backup") + } + + return res, nil +} + func (controller *backupController) getValidationErrors(itm *api.Backup) []string { var validationErrors []string diff --git a/pkg/controller/backup_controller_test.go b/pkg/controller/backup_controller_test.go index 5a455eba8..f0c20d387 100644 --- a/pkg/controller/backup_controller_test.go +++ b/pkg/controller/backup_controller_test.go @@ -17,6 +17,7 @@ limitations under the License. package controller import ( + "encoding/json" "io" "testing" "time" @@ -25,7 +26,6 @@ import ( "k8s.io/apimachinery/pkg/util/clock" core "k8s.io/client-go/testing" - testlogger "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -34,10 +34,10 @@ import ( "github.com/heptio/ark/pkg/backup" "github.com/heptio/ark/pkg/cloudprovider" "github.com/heptio/ark/pkg/generated/clientset/versioned/fake" - "github.com/heptio/ark/pkg/generated/clientset/versioned/scheme" informers "github.com/heptio/ark/pkg/generated/informers/externalversions" "github.com/heptio/ark/pkg/restore" - . "github.com/heptio/ark/pkg/util/test" + "github.com/heptio/ark/pkg/util/collections" + arktest "github.com/heptio/ark/pkg/util/test" ) type fakeBackupper struct { @@ -56,7 +56,7 @@ func TestProcessBackup(t *testing.T) { expectError bool expectedIncludes []string expectedExcludes []string - backup *TestBackup + backup *arktest.TestBackup expectBackup bool allowSnapshots bool }{ @@ -73,49 +73,49 @@ func TestProcessBackup(t *testing.T) { { name: "do not process phase FailedValidation", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseFailedValidation), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseFailedValidation), expectBackup: false, }, { name: "do not process phase InProgress", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseInProgress), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseInProgress), expectBackup: false, }, { name: "do not process phase Completed", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseCompleted), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseCompleted), expectBackup: false, }, { name: "do not process phase Failed", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseFailed), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseFailed), expectBackup: false, }, { name: "do not process phase other", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase("arg"), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase("arg"), expectBackup: false, }, { name: "invalid included/excluded resources fails validation", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedResources("foo").WithExcludedResources("foo"), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedResources("foo").WithExcludedResources("foo"), expectBackup: false, }, { name: "invalid included/excluded namespaces fails validation", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedNamespaces("foo").WithExcludedNamespaces("foo"), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedNamespaces("foo").WithExcludedNamespaces("foo"), expectBackup: false, }, { name: "make sure specified included and excluded resources are honored", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedResources("i", "j").WithExcludedResources("k", "l"), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedResources("i", "j").WithExcludedResources("k", "l"), expectedIncludes: []string{"i", "j"}, expectedExcludes: []string{"k", "l"}, expectBackup: true, @@ -123,25 +123,25 @@ func TestProcessBackup(t *testing.T) { { name: "if includednamespaces are specified, don't default to *", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedNamespaces("ns-1"), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithIncludedNamespaces("ns-1"), expectBackup: true, }, { name: "ttl", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithTTL(10 * time.Minute), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithTTL(10 * time.Minute), expectBackup: true, }, { name: "backup with SnapshotVolumes when allowSnapshots=false fails validation", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithSnapshotVolumes(true), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithSnapshotVolumes(true), expectBackup: false, }, { name: "backup with SnapshotVolumes when allowSnapshots=true gets executed", key: "heptio-ark/backup1", - backup: NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithSnapshotVolumes(true), + backup: arktest.NewTestBackup().WithName("backup1").WithPhase(v1.BackupPhaseNew).WithSnapshotVolumes(true), allowSnapshots: true, expectBackup: true, }, @@ -152,9 +152,9 @@ func TestProcessBackup(t *testing.T) { var ( client = fake.NewSimpleClientset() backupper = &fakeBackupper{} - cloudBackups = &BackupService{} + cloudBackups = &arktest.BackupService{} sharedInformers = informers.NewSharedInformerFactory(client, 0) - logger, _ = testlogger.NewNullLogger() + logger = arktest.NewLogger() pluginManager = &MockManager{} ) @@ -182,9 +182,7 @@ func TestProcessBackup(t *testing.T) { } // set up a Backup object to represent what we expect to be passed to backupper.Backup() - copy, err := scheme.Scheme.Copy(test.backup.Backup) - assert.NoError(t, err, "copy error") - backup := copy.(*v1.Backup) + backup := test.backup.DeepCopy() backup.Spec.IncludedResources = test.expectedIncludes backup.Spec.ExcludedResources = test.expectedExcludes backup.Spec.IncludedNamespaces = test.backup.Spec.IncludedNamespaces @@ -200,16 +198,35 @@ func TestProcessBackup(t *testing.T) { pluginManager.On("CloseBackupItemActions", backup.Name).Return(nil) } - // this is necessary so the Update() call returns the appropriate object - client.PrependReactor("update", "backups", func(action core.Action) (bool, runtime.Object, error) { - obj := action.(core.UpdateAction).GetObject() - // need to deep copy so we can test the backup state for each call to update - copy, err := scheme.Scheme.DeepCopy(obj) - if err != nil { + // this is necessary so the Patch() call returns the appropriate object + client.PrependReactor("patch", "backups", func(action core.Action) (bool, runtime.Object, error) { + if test.backup == nil { + return true, nil, nil + } + + patch := action.(core.PatchAction).GetPatch() + patchMap := make(map[string]interface{}) + + if err := json.Unmarshal(patch, &patchMap); err != nil { + t.Logf("error unmarshalling patch: %s\n", err) return false, nil, err } - ret := copy.(runtime.Object) - return true, ret, nil + + phase, err := collections.GetString(patchMap, "status.phase") + if err != nil { + t.Logf("error getting status.phase: %s\n", err) + return false, nil, err + } + + res := test.backup.DeepCopy() + + // these are the fields that we expect to be set by + // the controller + res.Status.Version = 1 + res.Status.Expiration.Time = expiration + res.Status.Phase = v1.BackupPhase(phase) + + return true, res, nil }) // method under test @@ -227,41 +244,41 @@ func TestProcessBackup(t *testing.T) { return } - expectedActions := []core.Action{ - core.NewUpdateAction( - v1.SchemeGroupVersion.WithResource("backups"), - v1.DefaultNamespace, - NewTestBackup(). - WithName(test.backup.Name). - WithPhase(v1.BackupPhaseInProgress). - WithIncludedResources(test.expectedIncludes...). - WithExcludedResources(test.expectedExcludes...). - WithIncludedNamespaces(test.backup.Spec.IncludedNamespaces...). - WithTTL(test.backup.Spec.TTL.Duration). - WithSnapshotVolumesPointer(test.backup.Spec.SnapshotVolumes). - WithExpiration(expiration). - WithVersion(1). - Backup, - ), + actions := client.Actions() + require.Equal(t, 2, len(actions)) - core.NewUpdateAction( - v1.SchemeGroupVersion.WithResource("backups"), - v1.DefaultNamespace, - NewTestBackup(). - WithName(test.backup.Name). - WithPhase(v1.BackupPhaseCompleted). - WithIncludedResources(test.expectedIncludes...). - WithExcludedResources(test.expectedExcludes...). - WithIncludedNamespaces(test.backup.Spec.IncludedNamespaces...). - WithTTL(test.backup.Spec.TTL.Duration). - WithSnapshotVolumesPointer(test.backup.Spec.SnapshotVolumes). - WithExpiration(expiration). - WithVersion(1). - Backup, - ), + // validate Patch call 1 (setting version, expiration, and phase) + patchAction, ok := actions[0].(core.PatchAction) + require.True(t, ok, "action is not a PatchAction") + + patch := make(map[string]interface{}) + require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch") + + assert.Equal(t, 1, len(patch), "patch has wrong number of keys") + + expectedStatusKeys := 2 + if test.backup.Spec.TTL.Duration > 0 { + assert.True(t, collections.HasKeyAndVal(patch, "status.expiration", expiration.UTC().Format(time.RFC3339)), "patch's status.expiration does not match") + expectedStatusKeys = 3 } - assert.Equal(t, expectedActions, client.Actions()) + assert.True(t, collections.HasKeyAndVal(patch, "status.version", float64(1))) + assert.True(t, collections.HasKeyAndVal(patch, "status.phase", string(v1.BackupPhaseInProgress)), "patch's status.phase does not match") + + res, _ := collections.GetMap(patch, "status") + assert.Equal(t, expectedStatusKeys, len(res), "patch's status has the wrong number of keys") + + // validate Patch call 2 (setting phase) + patchAction, ok = actions[1].(core.PatchAction) + require.True(t, ok, "action is not a PatchAction") + + require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch") + + assert.Equal(t, 1, len(patch), "patch has wrong number of keys") + + res, _ = collections.GetMap(patch, "status") + assert.Equal(t, 1, len(res), "patch's status has the wrong number of keys") + assert.True(t, collections.HasKeyAndVal(patch, "status.phase", string(v1.BackupPhaseCompleted)), "patch's status.phase does not match") }) } } diff --git a/pkg/controller/download_request_controller.go b/pkg/controller/download_request_controller.go index 731bffc46..e5358ded8 100644 --- a/pkg/controller/download_request_controller.go +++ b/pkg/controller/download_request_controller.go @@ -18,6 +18,7 @@ package controller import ( "context" + "encoding/json" "sync" "time" @@ -27,7 +28,9 @@ import ( apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/clock" + "k8s.io/apimachinery/pkg/util/strategicpatch" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/tools/cache" "k8s.io/client-go/util/workqueue" @@ -220,7 +223,7 @@ func (c *downloadRequestController) generatePreSignedURL(downloadRequest *v1.Dow update.Status.Phase = v1.DownloadRequestPhaseProcessed update.Status.Expiration = metav1.NewTime(c.clock.Now().Add(signedURLTTL)) - _, err = c.downloadRequestClient.DownloadRequests(update.Namespace).Update(update) + _, err = patchDownloadRequest(downloadRequest, update, c.downloadRequestClient) return errors.WithStack(err) } @@ -256,3 +259,27 @@ func (c *downloadRequestController) resync() { c.queue.Add(key) } } + +func patchDownloadRequest(original, updated *v1.DownloadRequest, client arkv1client.DownloadRequestsGetter) (*v1.DownloadRequest, error) { + origBytes, err := json.Marshal(original) + if err != nil { + return nil, errors.Wrap(err, "error marshalling original download request") + } + + updatedBytes, err := json.Marshal(updated) + if err != nil { + return nil, errors.Wrap(err, "error marshalling updated download request") + } + + patchBytes, err := strategicpatch.CreateTwoWayMergePatch(origBytes, updatedBytes, v1.DownloadRequest{}) + if err != nil { + return nil, errors.Wrap(err, "error creating two-way merge patch for download request") + } + + res, err := client.DownloadRequests(v1.DefaultNamespace).Patch(original.Name, types.MergePatchType, patchBytes) + if err != nil { + return nil, errors.Wrap(err, "error patching download request") + } + + return res, nil +} diff --git a/pkg/controller/download_request_controller_test.go b/pkg/controller/download_request_controller_test.go index a4738fb07..12a975067 100644 --- a/pkg/controller/download_request_controller_test.go +++ b/pkg/controller/download_request_controller_test.go @@ -17,6 +17,7 @@ limitations under the License. package controller import ( + "encoding/json" "testing" "time" @@ -25,12 +26,12 @@ import ( "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" core "k8s.io/client-go/testing" "github.com/heptio/ark/pkg/apis/ark/v1" "github.com/heptio/ark/pkg/generated/clientset/versioned/fake" informers "github.com/heptio/ark/pkg/generated/informers/externalversions" + "github.com/heptio/ark/pkg/util/collections" "github.com/heptio/ark/pkg/util/test" ) @@ -111,37 +112,28 @@ func TestProcessDownloadRequest(t *testing.T) { logger, ).(*downloadRequestController) + var downloadRequest *v1.DownloadRequest + if tc.expectedPhase == v1.DownloadRequestPhaseProcessed { target := v1.DownloadTarget{ Kind: tc.targetKind, Name: tc.targetName, } - downloadRequestsInformer.Informer().GetStore().Add( - &v1.DownloadRequest{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: v1.DefaultNamespace, - Name: "dr1", - }, - Spec: v1.DownloadRequestSpec{ - Target: target, - }, + downloadRequest = &v1.DownloadRequest{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: v1.DefaultNamespace, + Name: "dr1", }, - ) + Spec: v1.DownloadRequestSpec{ + Target: target, + }, + } + downloadRequestsInformer.Informer().GetStore().Add(downloadRequest) backupService.On("CreateSignedURL", target, "bucket", 10*time.Minute).Return("signedURL", nil) } - var updatedRequest *v1.DownloadRequest - - client.PrependReactor("update", "downloadrequests", func(action core.Action) (bool, runtime.Object, error) { - obj := action.(core.UpdateAction).GetObject() - r, ok := obj.(*v1.DownloadRequest) - require.True(t, ok) - updatedRequest = r - return true, obj, nil - }) - // method under test err := c.processDownloadRequest(tc.key) @@ -152,16 +144,37 @@ func TestProcessDownloadRequest(t *testing.T) { require.NoError(t, err) - var ( - updatedPhase v1.DownloadRequestPhase - updatedURL string - ) - if updatedRequest != nil { - updatedPhase = updatedRequest.Status.Phase - updatedURL = updatedRequest.Status.DownloadURL + actions := client.Actions() + + // if we don't expect a phase update, this means + // we don't expect any actions to take place + if tc.expectedPhase == "" { + require.Equal(t, 0, len(actions)) + return } - assert.Equal(t, tc.expectedPhase, updatedPhase) - assert.Equal(t, tc.expectedURL, updatedURL) + + // otherwise, we should get exactly 1 patch + require.Equal(t, 1, len(actions)) + patchAction, ok := actions[0].(core.PatchAction) + require.True(t, ok, "action is not a PatchAction") + + patch := make(map[string]interface{}) + require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch") + + // check the URL + assert.True(t, collections.HasKeyAndVal(patch, "status.downloadURL", tc.expectedURL), "patch's status.downloadURL does not match") + + // check the Phase + assert.True(t, collections.HasKeyAndVal(patch, "status.phase", string(tc.expectedPhase)), "patch's status.phase does not match") + + // check that Expiration exists + // TODO pass a fake clock to the controller and verify + // the expiration value + assert.True(t, collections.Exists(patch, "status.expiration"), "patch's status.expiration does not exist") + + // we expect 3 total updates. + res, _ := collections.GetMap(patch, "status") + assert.Equal(t, 3, len(res), "patch's status has the wrong number of keys") }) } } diff --git a/pkg/controller/restore_controller.go b/pkg/controller/restore_controller.go index 3202e7691..47731d1a0 100644 --- a/pkg/controller/restore_controller.go +++ b/pkg/controller/restore_controller.go @@ -31,7 +31,9 @@ import ( "github.com/sirupsen/logrus" apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/strategicpatch" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/tools/cache" "k8s.io/client-go/util/workqueue" @@ -230,6 +232,8 @@ func (controller *restoreController) processRestore(key string) error { } logContext.Debug("Cloning Restore") + // store ref to original for creating patch + original := restore // don't modify items in the cache restore = restore.DeepCopy() @@ -248,11 +252,13 @@ func (controller *restoreController) processRestore(key string) error { } // update status - updatedRestore, err := controller.restoreClient.Restores(ns).Update(restore) + updatedRestore, err := patchRestore(original, restore, controller.restoreClient) if err != nil { return errors.Wrapf(err, "error updating Restore phase to %s", restore.Status.Phase) } - restore = updatedRestore + // store ref to just-updated item for creating patch + original = updatedRestore + restore = updatedRestore.DeepCopy() if restore.Status.Phase == api.RestorePhaseFailedValidation { return nil @@ -276,7 +282,7 @@ func (controller *restoreController) processRestore(key string) error { restore.Status.Phase = api.RestorePhaseCompleted logContext.Debug("Updating Restore final status") - if _, err = controller.restoreClient.Restores(ns).Update(restore); err != nil { + if _, err = patchRestore(original, restore, controller.restoreClient); err != nil { logContext.WithError(errors.WithStack(err)).Info("Error updating Restore final status") } @@ -472,3 +478,27 @@ func downloadToTempFile(backupName string, backupService cloudprovider.BackupSer return file, nil } + +func patchRestore(original, updated *api.Restore, client arkv1client.RestoresGetter) (*api.Restore, error) { + origBytes, err := json.Marshal(original) + if err != nil { + return nil, errors.Wrap(err, "error marshalling original restore") + } + + updatedBytes, err := json.Marshal(updated) + if err != nil { + return nil, errors.Wrap(err, "error marshalling updated restore") + } + + patchBytes, err := strategicpatch.CreateTwoWayMergePatch(origBytes, updatedBytes, api.Restore{}) + if err != nil { + return nil, errors.Wrap(err, "error creating two-way merge patch for restore") + } + + res, err := client.Restores(api.DefaultNamespace).Patch(original.Name, types.MergePatchType, patchBytes) + if err != nil { + return nil, errors.Wrap(err, "error patching restore") + } + + return res, nil +} diff --git a/pkg/controller/restore_controller_test.go b/pkg/controller/restore_controller_test.go index c1f291e88..2d0ac3f9c 100644 --- a/pkg/controller/restore_controller_test.go +++ b/pkg/controller/restore_controller_test.go @@ -18,6 +18,7 @@ package controller import ( "bytes" + "encoding/json" "errors" "io" "io/ioutil" @@ -25,9 +26,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/runtime" - "k8s.io/client-go/kubernetes/scheme" core "k8s.io/client-go/testing" "k8s.io/client-go/tools/cache" @@ -35,6 +36,7 @@ import ( "github.com/heptio/ark/pkg/generated/clientset/versioned/fake" informers "github.com/heptio/ark/pkg/generated/informers/externalversions" "github.com/heptio/ark/pkg/restore" + "github.com/heptio/ark/pkg/util/collections" arktest "github.com/heptio/ark/pkg/util/test" ) @@ -120,7 +122,9 @@ func TestProcessRestore(t *testing.T) { restorerError error allowRestoreSnapshots bool expectedErr bool - expectedRestoreUpdates []*api.Restore + expectedPhase string + expectedValidationErrors []string + expectedRestoreErrors int expectedRestorerCall *api.Restore backupServiceGetBackupError error uploadLogError error @@ -151,73 +155,53 @@ func TestProcessRestore(t *testing.T) { expectedErr: false, }, { - name: "restore with both namespace in both includedNamespaces and excludedNamespaces fails validation", - restore: NewRestore("foo", "bar", "backup-1", "another-1", "*", api.RestorePhaseNew).WithExcludedNamespace("another-1").Restore, - backup: arktest.NewTestBackup().WithName("backup-1").Backup, - expectedErr: false, - expectedRestoreUpdates: []*api.Restore{ - NewRestore("foo", "bar", "backup-1", "another-1", "*", api.RestorePhaseFailedValidation).WithExcludedNamespace("another-1"). - WithValidationError("Invalid included/excluded namespace lists: excludes list cannot contain an item in the includes list: another-1"). - Restore, - }, + name: "restore with both namespace in both includedNamespaces and excludedNamespaces fails validation", + restore: NewRestore("foo", "bar", "backup-1", "another-1", "*", api.RestorePhaseNew).WithExcludedNamespace("another-1").Restore, + backup: arktest.NewTestBackup().WithName("backup-1").Backup, + expectedErr: false, + expectedPhase: string(api.RestorePhaseFailedValidation), + expectedValidationErrors: []string{"Invalid included/excluded namespace lists: excludes list cannot contain an item in the includes list: another-1"}, }, { - name: "restore with resource in both includedResources and excludedResources fails validation", - restore: NewRestore("foo", "bar", "backup-1", "*", "a-resource", api.RestorePhaseNew).WithExcludedResource("a-resource").Restore, - backup: arktest.NewTestBackup().WithName("backup-1").Backup, - expectedErr: false, - expectedRestoreUpdates: []*api.Restore{ - NewRestore("foo", "bar", "backup-1", "*", "a-resource", api.RestorePhaseFailedValidation).WithExcludedResource("a-resource"). - WithValidationError("Invalid included/excluded resource lists: excludes list cannot contain an item in the includes list: a-resource"). - Restore, - }, + name: "restore with resource in both includedResources and excludedResources fails validation", + restore: NewRestore("foo", "bar", "backup-1", "*", "a-resource", api.RestorePhaseNew).WithExcludedResource("a-resource").Restore, + backup: arktest.NewTestBackup().WithName("backup-1").Backup, + expectedErr: false, + expectedPhase: string(api.RestorePhaseFailedValidation), + expectedValidationErrors: []string{"Invalid included/excluded resource lists: excludes list cannot contain an item in the includes list: a-resource"}, }, { - name: "new restore with empty backup name fails validation", - restore: NewRestore("foo", "bar", "", "ns-1", "", api.RestorePhaseNew).Restore, - expectedErr: false, - expectedRestoreUpdates: []*api.Restore{ - NewRestore("foo", "bar", "", "ns-1", "", api.RestorePhaseFailedValidation). - WithValidationError("BackupName must be non-empty and correspond to the name of a backup in object storage."). - Restore, - }, + name: "new restore with empty backup name fails validation", + restore: NewRestore("foo", "bar", "", "ns-1", "", api.RestorePhaseNew).Restore, + expectedErr: false, + expectedPhase: string(api.RestorePhaseFailedValidation), + expectedValidationErrors: []string{"BackupName must be non-empty and correspond to the name of a backup in object storage."}, }, { - name: "restore with non-existent backup name fails", - restore: arktest.NewTestRestore("foo", "bar", api.RestorePhaseNew).WithBackup("backup-1").WithIncludedNamespace("ns-1").Restore, - expectedErr: false, - expectedRestoreUpdates: []*api.Restore{ - NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore, - NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseCompleted). - WithErrors(1). - Restore, - }, + name: "restore with non-existent backup name fails", + restore: arktest.NewTestRestore("foo", "bar", api.RestorePhaseNew).WithBackup("backup-1").WithIncludedNamespace("ns-1").Restore, + expectedErr: false, + expectedPhase: string(api.RestorePhaseInProgress), + expectedRestoreErrors: 1, backupServiceGetBackupError: errors.New("no backup here"), }, { - name: "restorer throwing an error causes the restore to fail", - restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).Restore, - backup: arktest.NewTestBackup().WithName("backup-1").Backup, - restorerError: errors.New("blarg"), - expectedErr: false, - expectedRestoreUpdates: []*api.Restore{ - NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore, - NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseCompleted). - WithErrors(1). - Restore, - }, - expectedRestorerCall: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore, + name: "restorer throwing an error causes the restore to fail", + restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).Restore, + backup: arktest.NewTestBackup().WithName("backup-1").Backup, + restorerError: errors.New("blarg"), + expectedErr: false, + expectedPhase: string(api.RestorePhaseInProgress), + expectedRestoreErrors: 1, + expectedRestorerCall: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore, }, { - name: "valid restore gets executed", - restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).Restore, - backup: arktest.NewTestBackup().WithName("backup-1").Backup, - expectedErr: false, - expectedRestoreUpdates: []*api.Restore{ - NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore, - NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseCompleted).Restore, - }, + name: "valid restore gets executed", + restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).Restore, + backup: arktest.NewTestBackup().WithName("backup-1").Backup, + expectedErr: false, + expectedPhase: string(api.RestorePhaseInProgress), expectedRestorerCall: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).Restore, }, { @@ -226,34 +210,26 @@ func TestProcessRestore(t *testing.T) { backup: arktest.NewTestBackup().WithName("backup-1").Backup, allowRestoreSnapshots: true, expectedErr: false, - expectedRestoreUpdates: []*api.Restore{ - NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).WithRestorePVs(true).Restore, - NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseCompleted).WithRestorePVs(true).Restore, - }, - expectedRestorerCall: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).WithRestorePVs(true).Restore, + expectedPhase: string(api.RestorePhaseInProgress), + expectedRestorerCall: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseInProgress).WithRestorePVs(true).Restore, }, { - name: "restore with RestorePVs=true fails validation when allowRestoreSnapshots=false", - restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).WithRestorePVs(true).Restore, - backup: arktest.NewTestBackup().WithName("backup-1").Backup, - expectedErr: false, - expectedRestoreUpdates: []*api.Restore{ - NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseFailedValidation). - WithRestorePVs(true). - WithValidationError("Server is not configured for PV snapshot restores"). - Restore, - }, + name: "restore with RestorePVs=true fails validation when allowRestoreSnapshots=false", + restore: NewRestore("foo", "bar", "backup-1", "ns-1", "", api.RestorePhaseNew).WithRestorePVs(true).Restore, + backup: arktest.NewTestBackup().WithName("backup-1").Backup, + expectedErr: false, + expectedPhase: string(api.RestorePhaseFailedValidation), + expectedValidationErrors: []string{"Server is not configured for PV snapshot restores"}, }, { - name: "restoration of nodes is not supported", - restore: NewRestore("foo", "bar", "backup-1", "ns-1", "nodes", api.RestorePhaseNew).Restore, - backup: arktest.NewTestBackup().WithName("backup-1").Backup, - expectedErr: false, - expectedRestoreUpdates: []*api.Restore{ - NewRestore("foo", "bar", "backup-1", "ns-1", "nodes", api.RestorePhaseFailedValidation). - WithValidationError("nodes are a non-restorable resource"). - WithValidationError("Invalid included/excluded resource lists: excludes list cannot contain an item in the includes list: nodes"). - Restore, + name: "restoration of nodes is not supported", + restore: NewRestore("foo", "bar", "backup-1", "ns-1", "nodes", api.RestorePhaseNew).Restore, + backup: arktest.NewTestBackup().WithName("backup-1").Backup, + expectedErr: false, + expectedPhase: string(api.RestorePhaseFailedValidation), + expectedValidationErrors: []string{ + "nodes are a non-restorable resource", + "Invalid included/excluded resource lists: excludes list cannot contain an item in the includes list: nodes", }, }, } @@ -288,16 +264,34 @@ func TestProcessRestore(t *testing.T) { if test.restore != nil { sharedInformers.Ark().V1().Restores().Informer().GetStore().Add(test.restore) - // this is necessary so the Update() call returns the appropriate object - client.PrependReactor("update", "restores", func(action core.Action) (bool, runtime.Object, error) { - obj := action.(core.UpdateAction).GetObject() - // need to deep copy so we can test the backup state for each call to update - copy, err := scheme.Scheme.DeepCopy(obj) - if err != nil { + // this is necessary so the Patch() call returns the appropriate object + client.PrependReactor("patch", "restores", func(action core.Action) (bool, runtime.Object, error) { + if test.restore == nil { + return true, nil, nil + } + + patch := action.(core.PatchAction).GetPatch() + patchMap := make(map[string]interface{}) + + if err := json.Unmarshal(patch, &patchMap); err != nil { + t.Logf("error unmarshalling patch: %s\n", err) return false, nil, err } - ret := copy.(runtime.Object) - return true, ret, nil + + phase, err := collections.GetString(patchMap, "status.phase") + if err != nil { + t.Logf("error getting status.phase: %s\n", err) + return false, nil, err + } + + res := test.restore.DeepCopy() + + // these are the fields that we expect to be set by + // the controller + + res.Status.Phase = api.RestorePhase(phase) + + return true, res, nil }) } @@ -346,32 +340,75 @@ func TestProcessRestore(t *testing.T) { assert.Equal(t, test.expectedErr, err != nil, "got error %v", err) - if test.expectedRestoreUpdates != nil { - var expectedActions []core.Action + actions := client.Actions() - for _, upd := range test.expectedRestoreUpdates { - action := core.NewUpdateAction( - api.SchemeGroupVersion.WithResource("restores"), - upd.Namespace, - upd) - - expectedActions = append(expectedActions, action) - } - - assert.Equal(t, expectedActions, client.Actions()) + if test.expectedPhase == "" { + require.Equal(t, 0, len(actions), "len(actions) should be zero") + return } + // validate Patch call 1 (setting phase, validation errs) + require.True(t, len(actions) > 0, "len(actions) is too small") + + patchAction, ok := actions[0].(core.PatchAction) + require.True(t, ok, "action is not a PatchAction") + + patch := make(map[string]interface{}) + require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch") + + expectedStatusKeys := 1 + + assert.True(t, collections.HasKeyAndVal(patch, "status.phase", test.expectedPhase), "patch's status.phase does not match") + + if len(test.expectedValidationErrors) > 0 { + errs, err := collections.GetSlice(patch, "status.validationErrors") + require.NoError(t, err, "error getting patch's status.validationErrors") + + var errStrings []string + for _, err := range errs { + errStrings = append(errStrings, err.(string)) + } + + assert.Equal(t, test.expectedValidationErrors, errStrings, "patch's status.validationErrors does not match") + + expectedStatusKeys++ + } + + res, _ := collections.GetMap(patch, "status") + assert.Equal(t, expectedStatusKeys, len(res), "patch's status has the wrong number of keys") + + // if we don't expect a restore, validate it wasn't called and exit the test if test.expectedRestorerCall == nil { assert.Empty(t, restorer.Calls) assert.Zero(t, restorer.calledWithArg) - } else { - assert.Equal(t, 1, len(restorer.Calls)) - - // explicitly capturing the argument passed to Restore myself because - // I want to validate the called arg as of the time of calling, but - // the mock stores the pointer, which gets modified after - assert.Equal(t, *test.expectedRestorerCall, restorer.calledWithArg) + return } + assert.Equal(t, 1, len(restorer.Calls)) + + // validate Patch call 2 (setting phase) + patchAction, ok = actions[1].(core.PatchAction) + require.True(t, ok, "action is not a PatchAction") + + require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch") + + assert.Equal(t, 1, len(patch), "patch has wrong number of keys") + + res, _ = collections.GetMap(patch, "status") + expectedStatusKeys = 1 + + assert.True(t, collections.HasKeyAndVal(patch, "status.phase", string(api.RestorePhaseCompleted)), "patch's status.phase does not match") + + if test.expectedRestoreErrors != 0 { + assert.True(t, collections.HasKeyAndVal(patch, "status.errors", float64(test.expectedRestoreErrors)), "patch's status.errors does not match") + expectedStatusKeys++ + } + + assert.Equal(t, expectedStatusKeys, len(res), "patch's status has wrong number of keys") + + // explicitly capturing the argument passed to Restore myself because + // I want to validate the called arg as of the time of calling, but + // the mock stores the pointer, which gets modified after + assert.Equal(t, *test.expectedRestorerCall, restorer.calledWithArg) }) } } diff --git a/pkg/controller/schedule_controller.go b/pkg/controller/schedule_controller.go index 83c1249c7..30466eb02 100644 --- a/pkg/controller/schedule_controller.go +++ b/pkg/controller/schedule_controller.go @@ -18,6 +18,7 @@ package controller import ( "context" + "encoding/json" "fmt" "sync" "time" @@ -29,7 +30,9 @@ import ( apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/clock" + "k8s.io/apimachinery/pkg/util/strategicpatch" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/tools/cache" "k8s.io/client-go/util/workqueue" @@ -50,7 +53,7 @@ type scheduleController struct { queue workqueue.RateLimitingInterface syncPeriod time.Duration clock clock.Clock - logger *logrus.Logger + logger logrus.FieldLogger } func NewScheduleController( @@ -58,7 +61,7 @@ func NewScheduleController( backupsClient arkv1client.BackupsGetter, schedulesInformer informers.ScheduleInformer, syncPeriod time.Duration, - logger *logrus.Logger, + logger logrus.FieldLogger, ) *scheduleController { if syncPeriod < time.Minute { logger.WithField("syncPeriod", syncPeriod).Info("Provided schedule sync period is too short. Setting to 1 minute") @@ -230,6 +233,8 @@ func (controller *scheduleController) processSchedule(key string) error { } logContext.Debug("Cloning schedule") + // store ref to original for creating patch + original := schedule // don't modify items in the cache schedule = schedule.DeepCopy() @@ -247,7 +252,7 @@ func (controller *scheduleController) processSchedule(key string) error { // update status if it's changed if currentPhase != schedule.Status.Phase { - updatedSchedule, err := controller.schedulesClient.Schedules(ns).Update(schedule) + updatedSchedule, err := patchSchedule(original, schedule, controller.schedulesClient) if err != nil { return errors.Wrapf(err, "error updating Schedule phase to %s", schedule.Status.Phase) } @@ -266,7 +271,7 @@ func (controller *scheduleController) processSchedule(key string) error { return nil } -func parseCronSchedule(itm *api.Schedule, logger *logrus.Logger) (cron.Schedule, []string) { +func parseCronSchedule(itm *api.Schedule, logger logrus.FieldLogger) (cron.Schedule, []string) { var validationErrors []string var schedule cron.Schedule @@ -330,11 +335,12 @@ func (controller *scheduleController) submitBackupIfDue(item *api.Schedule, cron return errors.Wrap(err, "error creating Backup") } + original := item schedule := item.DeepCopy() schedule.Status.LastBackup = metav1.NewTime(now) - if _, err := controller.schedulesClient.Schedules(schedule.Namespace).Update(schedule); err != nil { + if _, err := patchSchedule(original, schedule, controller.schedulesClient); err != nil { return errors.Wrapf(err, "error updating Schedule's LastBackup time to %v", schedule.Status.LastBackup) } @@ -365,3 +371,27 @@ func getBackup(item *api.Schedule, timestamp time.Time) *api.Backup { return backup } + +func patchSchedule(original, updated *api.Schedule, client arkv1client.SchedulesGetter) (*api.Schedule, error) { + origBytes, err := json.Marshal(original) + if err != nil { + return nil, errors.Wrap(err, "error marshalling original schedule") + } + + updatedBytes, err := json.Marshal(updated) + if err != nil { + return nil, errors.Wrap(err, "error marshalling updated schedule") + } + + patchBytes, err := strategicpatch.CreateTwoWayMergePatch(origBytes, updatedBytes, api.Schedule{}) + if err != nil { + return nil, errors.Wrap(err, "error creating two-way merge patch for schedule") + } + + res, err := client.Schedules(api.DefaultNamespace).Patch(original.Name, types.MergePatchType, patchBytes) + if err != nil { + return nil, errors.Wrap(err, "error patching schedule") + } + + return res, nil +} diff --git a/pkg/controller/schedule_controller_test.go b/pkg/controller/schedule_controller_test.go index aee3e09cd..9c1133484 100644 --- a/pkg/controller/schedule_controller_test.go +++ b/pkg/controller/schedule_controller_test.go @@ -17,6 +17,8 @@ limitations under the License. package controller import ( + "encoding/json" + "fmt" "testing" "time" @@ -28,26 +30,27 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/clock" - "k8s.io/client-go/kubernetes/scheme" core "k8s.io/client-go/testing" "k8s.io/client-go/tools/cache" api "github.com/heptio/ark/pkg/apis/ark/v1" "github.com/heptio/ark/pkg/generated/clientset/versioned/fake" informers "github.com/heptio/ark/pkg/generated/informers/externalversions" - . "github.com/heptio/ark/pkg/util/test" + "github.com/heptio/ark/pkg/util/collections" + arktest "github.com/heptio/ark/pkg/util/test" ) func TestProcessSchedule(t *testing.T) { tests := []struct { - name string - scheduleKey string - schedule *api.Schedule - fakeClockTime string - expectedErr bool - expectedSchedulePhaseUpdate *api.Schedule - expectedScheduleLastBackupUpdate *api.Schedule - expectedBackupCreate *api.Backup + name string + scheduleKey string + schedule *api.Schedule + fakeClockTime string + expectedErr bool + expectedPhase string + expectedValidationError string + expectedBackupCreate *api.Backup + expectedLastBackup string }{ { name: "invalid key returns error", @@ -61,70 +64,64 @@ func TestProcessSchedule(t *testing.T) { }, { name: "schedule with phase FailedValidation does not get processed", - schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseFailedValidation).Schedule, + schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseFailedValidation).Schedule, expectedErr: false, }, { - name: "schedule with phase New gets validated and failed if invalid", - schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseNew).Schedule, - expectedErr: false, - expectedSchedulePhaseUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseFailedValidation). - WithValidationError("Schedule must be a non-empty valid Cron expression").Schedule, + name: "schedule with phase New gets validated and failed if invalid", + schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseNew).Schedule, + expectedErr: false, + expectedPhase: string(api.SchedulePhaseFailedValidation), + expectedValidationError: "Schedule must be a non-empty valid Cron expression", }, { - name: "schedule with phase gets validated and failed if invalid", - schedule: NewTestSchedule("ns", "name").Schedule, - expectedErr: false, - expectedSchedulePhaseUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseFailedValidation). - WithValidationError("Schedule must be a non-empty valid Cron expression").Schedule, + name: "schedule with phase gets validated and failed if invalid", + schedule: arktest.NewTestSchedule("ns", "name").Schedule, + expectedErr: false, + expectedPhase: string(api.SchedulePhaseFailedValidation), + expectedValidationError: "Schedule must be a non-empty valid Cron expression", }, { - name: "schedule with phase Enabled gets re-validated and failed if invalid", - schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled).Schedule, - expectedErr: false, - expectedSchedulePhaseUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseFailedValidation). - WithValidationError("Schedule must be a non-empty valid Cron expression").Schedule, + name: "schedule with phase Enabled gets re-validated and failed if invalid", + schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled).Schedule, + expectedErr: false, + expectedPhase: string(api.SchedulePhaseFailedValidation), + expectedValidationError: "Schedule must be a non-empty valid Cron expression", }, { - name: "schedule with phase New gets validated and triggers a backup", - schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseNew).WithCronSchedule("@every 5m").Schedule, - fakeClockTime: "2017-01-01 12:00:00", - expectedErr: false, - expectedSchedulePhaseUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled).WithCronSchedule("@every 5m").Schedule, - expectedBackupCreate: NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup, - expectedScheduleLastBackupUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled). - WithCronSchedule("@every 5m").WithLastBackupTime("2017-01-01 12:00:00").Schedule, + name: "schedule with phase New gets validated and triggers a backup", + schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseNew).WithCronSchedule("@every 5m").Schedule, + fakeClockTime: "2017-01-01 12:00:00", + expectedErr: false, + expectedPhase: string(api.SchedulePhaseEnabled), + expectedBackupCreate: arktest.NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup, + expectedLastBackup: "2017-01-01 12:00:00", }, { name: "schedule with phase Enabled gets re-validated and triggers a backup if valid", - schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled).WithCronSchedule("@every 5m").Schedule, + schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled).WithCronSchedule("@every 5m").Schedule, fakeClockTime: "2017-01-01 12:00:00", expectedErr: false, - expectedBackupCreate: NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup, - expectedScheduleLastBackupUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled). - WithCronSchedule("@every 5m").WithLastBackupTime("2017-01-01 12:00:00").Schedule, + expectedBackupCreate: arktest.NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup, + expectedLastBackup: "2017-01-01 12:00:00", }, { name: "schedule that's already run gets LastBackup updated", - schedule: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled). + schedule: arktest.NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled). WithCronSchedule("@every 5m").WithLastBackupTime("2000-01-01 00:00:00").Schedule, fakeClockTime: "2017-01-01 12:00:00", expectedErr: false, - expectedBackupCreate: NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup, - expectedScheduleLastBackupUpdate: NewTestSchedule("ns", "name").WithPhase(api.SchedulePhaseEnabled). - WithCronSchedule("@every 5m").WithLastBackupTime("2017-01-01 12:00:00").Schedule, + expectedBackupCreate: arktest.NewTestBackup().WithNamespace("ns").WithName("name-20170101120000").WithLabel("ark-schedule", "name").Backup, + expectedLastBackup: "2017-01-01 12:00:00", }, } - // flag.Set("logtostderr", "true") - // flag.Set("v", "4") - for _, test := range tests { t.Run(test.name, func(t *testing.T) { var ( client = fake.NewSimpleClientset() sharedInformers = informers.NewSharedInformerFactory(client, 0) - logger, _ = testlogger.NewNullLogger() + logger = arktest.NewLogger() ) c := NewScheduleController( @@ -148,16 +145,36 @@ func TestProcessSchedule(t *testing.T) { if test.schedule != nil { sharedInformers.Ark().V1().Schedules().Informer().GetStore().Add(test.schedule) - // this is necessary so the Update() call returns the appropriate object - client.PrependReactor("update", "schedules", func(action core.Action) (bool, runtime.Object, error) { - obj := action.(core.UpdateAction).GetObject() - // need to deep copy so we can test the schedule state for each call to update - copy, err := scheme.Scheme.DeepCopy(obj) - if err != nil { + // this is necessary so the Patch() call returns the appropriate object + client.PrependReactor("patch", "schedules", func(action core.Action) (bool, runtime.Object, error) { + var ( + patch = action.(core.PatchAction).GetPatch() + patchMap = make(map[string]interface{}) + res = test.schedule.DeepCopy() + ) + + if err := json.Unmarshal(patch, &patchMap); err != nil { + t.Logf("error unmarshalling patch: %s\n", err) return false, nil, err } - ret := copy.(runtime.Object) - return true, ret, nil + + // these are the fields that may be updated by the controller + phase, err := collections.GetString(patchMap, "status.phase") + if err == nil { + res.Status.Phase = api.SchedulePhase(phase) + } + + lastBackupStr, err := collections.GetString(patchMap, "status.lastBackup") + if err == nil { + parsed, err := time.Parse(time.RFC3339, lastBackupStr) + if err != nil { + t.Logf("error parsing status.lastBackup: %s\n", err) + return false, nil, err + } + res.Status.LastBackup = metav1.Time{Time: parsed} + } + + return true, res, nil }) } @@ -171,37 +188,88 @@ func TestProcessSchedule(t *testing.T) { assert.Equal(t, test.expectedErr, err != nil, "got error %v", err) - expectedActions := make([]core.Action, 0) + actions := client.Actions() + index := 0 - if upd := test.expectedSchedulePhaseUpdate; upd != nil { - action := core.NewUpdateAction( - api.SchemeGroupVersion.WithResource("schedules"), - upd.Namespace, - upd) - expectedActions = append(expectedActions, action) + if test.expectedPhase != "" { + require.True(t, len(actions) > index, "len(actions) is too small") + + patchAction, ok := actions[index].(core.PatchAction) + require.True(t, ok, "action is not a PatchAction") + + patch := make(map[string]interface{}) + require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch") + + assert.Equal(t, 1, len(patch), "patch has wrong number of keys") + + expectedStatusKeys := 1 + + assert.True(t, collections.HasKeyAndVal(patch, "status.phase", test.expectedPhase), "patch's status.phase does not match") + + if test.expectedValidationError != "" { + errs, err := collections.GetSlice(patch, "status.validationErrors") + require.NoError(t, err, "error getting patch's status.validationErrors") + + require.Equal(t, 1, len(errs)) + + assert.Equal(t, test.expectedValidationError, errs[0].(string), "patch's status.validationErrors does not match") + + expectedStatusKeys++ + } + + res, _ := collections.GetMap(patch, "status") + assert.Equal(t, expectedStatusKeys, len(res), "patch's status has the wrong number of keys") + + index++ } if created := test.expectedBackupCreate; created != nil { + require.True(t, len(actions) > index, "len(actions) is too small") + action := core.NewCreateAction( api.SchemeGroupVersion.WithResource("backups"), created.Namespace, created) - expectedActions = append(expectedActions, action) + + assert.Equal(t, action, actions[index]) + + index++ } - if upd := test.expectedScheduleLastBackupUpdate; upd != nil { - action := core.NewUpdateAction( - api.SchemeGroupVersion.WithResource("schedules"), - upd.Namespace, - upd) - expectedActions = append(expectedActions, action) - } + if test.expectedLastBackup != "" { + require.True(t, len(actions) > index, "len(actions) is too small") - assert.Equal(t, expectedActions, client.Actions()) + patchAction, ok := actions[index].(core.PatchAction) + require.True(t, ok, "action is not a PatchAction") + + patch := make(map[string]interface{}) + require.NoError(t, json.Unmarshal(patchAction.GetPatch(), &patch), "cannot unmarshal patch") + + assert.Equal(t, 1, len(patch), "patch has wrong number of keys") + + lastBackup, _ := collections.GetValue(patch, "status.lastBackup") + fmt.Println(lastBackup) + + assert.True( + t, + collections.HasKeyAndVal(patch, "status.lastBackup", parseTime(test.expectedLastBackup).UTC().Format(time.RFC3339)), + "patch's status.lastBackup does not match", + ) + + res, _ := collections.GetMap(patch, "status") + assert.Equal(t, 1, len(res), "patch's status has the wrong number of keys") + + index++ + } }) } } +func parseTime(timeString string) time.Time { + res, _ := time.Parse("2006-01-02 15:04:05", timeString) + return res +} + func TestGetNextRunTime(t *testing.T) { tests := []struct { name string diff --git a/pkg/util/collections/map_utils.go b/pkg/util/collections/map_utils.go index 8c4f68852..4d3f3dd8c 100644 --- a/pkg/util/collections/map_utils.go +++ b/pkg/util/collections/map_utils.go @@ -122,3 +122,14 @@ func Exists(root map[string]interface{}, path string) bool { _, err := GetValue(root, path) return err == nil } + +// HasKeyAndVal returns true if root[path] exists and the value +// contained is equal to val, or false otherwise. +func HasKeyAndVal(root map[string]interface{}, path string, val interface{}) bool { + valObj, err := GetValue(root, path) + if err != nil { + return false + } + + return valObj == val +}