diff --git a/pkg/controller/pod_volume_backup_controller.go b/pkg/controller/pod_volume_backup_controller.go index 639d113b0..5776653cf 100644 --- a/pkg/controller/pod_volume_backup_controller.go +++ b/pkg/controller/pod_volume_backup_controller.go @@ -38,6 +38,7 @@ import ( listers "github.com/heptio/ark/pkg/generated/listers/ark/v1" "github.com/heptio/ark/pkg/restic" arkexec "github.com/heptio/ark/pkg/util/exec" + "github.com/heptio/ark/pkg/util/filesystem" "github.com/heptio/ark/pkg/util/kube" ) @@ -52,6 +53,7 @@ type podVolumeBackupController struct { nodeName string processBackupFunc func(*arkv1api.PodVolumeBackup) error + fileSystem filesystem.Interface } // NewPodVolumeBackupController creates a new pod volume backup controller. @@ -72,6 +74,8 @@ func NewPodVolumeBackupController( secretLister: corev1listers.NewSecretLister(secretInformer.GetIndexer()), pvcLister: pvcInformer.Lister(), nodeName: nodeName, + + fileSystem: filesystem.NewFileSystem(), } c.syncHandler = c.processQueueItem @@ -194,7 +198,7 @@ func (c *podVolumeBackupController) processBackup(req *arkv1api.PodVolumeBackup) log.WithField("path", path).Debugf("Found path matching glob") // temp creds - file, err := restic.TempCredentialsFile(c.secretLister, req.Namespace, req.Spec.Pod.Namespace) + file, err := restic.TempCredentialsFile(c.secretLister, req.Namespace, req.Spec.Pod.Namespace, c.fileSystem) if err != nil { log.WithError(err).Error("Error creating temp restic credentials file") return c.fail(req, errors.Wrap(err, "error creating temp restic credentials file").Error(), log) diff --git a/pkg/controller/pod_volume_restore_controller.go b/pkg/controller/pod_volume_restore_controller.go index 4a615fb60..78b467240 100644 --- a/pkg/controller/pod_volume_restore_controller.go +++ b/pkg/controller/pod_volume_restore_controller.go @@ -42,6 +42,7 @@ import ( "github.com/heptio/ark/pkg/restic" "github.com/heptio/ark/pkg/util/boolptr" arkexec "github.com/heptio/ark/pkg/util/exec" + "github.com/heptio/ark/pkg/util/filesystem" "github.com/heptio/ark/pkg/util/kube" ) @@ -56,6 +57,7 @@ type podVolumeRestoreController struct { nodeName string processRestoreFunc func(*arkv1api.PodVolumeRestore) error + fileSystem filesystem.Interface } // NewPodVolumeRestoreController creates a new pod volume restore controller. @@ -76,6 +78,8 @@ func NewPodVolumeRestoreController( secretLister: corev1listers.NewSecretLister(secretInformer.GetIndexer()), pvcLister: pvcInformer.Lister(), nodeName: nodeName, + + fileSystem: filesystem.NewFileSystem(), } c.syncHandler = c.processQueueItem @@ -270,7 +274,7 @@ func (c *podVolumeRestoreController) processRestore(req *arkv1api.PodVolumeResto return c.failRestore(req, errors.Wrap(err, "error getting volume directory name").Error(), log) } - credsFile, err := restic.TempCredentialsFile(c.secretLister, req.Namespace, req.Spec.Pod.Namespace) + credsFile, err := restic.TempCredentialsFile(c.secretLister, req.Namespace, req.Spec.Pod.Namespace, c.fileSystem) if err != nil { log.WithError(err).Error("Error creating temp restic credentials file") return c.failRestore(req, errors.Wrap(err, "error creating temp restic credentials file").Error(), log) diff --git a/pkg/restic/command_factory_test.go b/pkg/restic/command_factory_test.go new file mode 100644 index 000000000..de3b35fbf --- /dev/null +++ b/pkg/restic/command_factory_test.go @@ -0,0 +1,91 @@ +/* +Copyright 2018 the Heptio Ark contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package restic + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBackupCommand(t *testing.T) { + c := BackupCommand("repo-id", "password-file", "path", map[string]string{"foo": "bar", "c": "d"}) + + assert.Equal(t, "backup", c.Command) + assert.Equal(t, "repo-id", c.RepoIdentifier) + assert.Equal(t, "password-file", c.PasswordFile) + assert.Equal(t, "path", c.Dir) + assert.Equal(t, []string{"."}, c.Args) + + expected := []string{"--tag=foo=bar", "--tag=c=d", "--hostname=ark"} + sort.Strings(expected) + sort.Strings(c.ExtraFlags) + assert.Equal(t, expected, c.ExtraFlags) +} + +func TestRestoreCommand(t *testing.T) { + c := RestoreCommand("repo-id", "password-file", "snapshot-id", "target") + + assert.Equal(t, "restore", c.Command) + assert.Equal(t, "repo-id", c.RepoIdentifier) + assert.Equal(t, "password-file", c.PasswordFile) + assert.Equal(t, "target", c.Dir) + assert.Equal(t, []string{"snapshot-id"}, c.Args) + assert.Equal(t, []string{"--target=."}, c.ExtraFlags) +} + +func TestGetSnapshotCommand(t *testing.T) { + c := GetSnapshotCommand("repo-id", "password-file", map[string]string{"foo": "bar", "c": "d"}) + + assert.Equal(t, "snapshots", c.Command) + assert.Equal(t, "repo-id", c.RepoIdentifier) + assert.Equal(t, "password-file", c.PasswordFile) + + expected := []string{"--json", "--last", "--tag=foo=bar,c=d"} + sort.Strings(expected) + sort.Strings(c.ExtraFlags) + assert.Equal(t, expected, c.ExtraFlags) +} + +func TestInitCommand(t *testing.T) { + c := InitCommand("repo-id") + + assert.Equal(t, "init", c.Command) + assert.Equal(t, "repo-id", c.RepoIdentifier) +} + +func TestCheckCommand(t *testing.T) { + c := CheckCommand("repo-id") + + assert.Equal(t, "check", c.Command) + assert.Equal(t, "repo-id", c.RepoIdentifier) +} + +func TestPruneCommand(t *testing.T) { + c := PruneCommand("repo-id") + + assert.Equal(t, "prune", c.Command) + assert.Equal(t, "repo-id", c.RepoIdentifier) +} + +func TestForgetCommand(t *testing.T) { + c := ForgetCommand("repo-id", "snapshot-id") + + assert.Equal(t, "forget", c.Command) + assert.Equal(t, "repo-id", c.RepoIdentifier) + assert.Equal(t, []string{"snapshot-id"}, c.Args) +} diff --git a/pkg/restic/command_test.go b/pkg/restic/command_test.go new file mode 100644 index 000000000..bf2856c8d --- /dev/null +++ b/pkg/restic/command_test.go @@ -0,0 +1,106 @@ +/* +Copyright 2018 the Heptio Ark contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package restic + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRepoName(t *testing.T) { + c := &Command{RepoIdentifier: ""} + assert.Equal(t, "", c.RepoName()) + + c.RepoIdentifier = "s3:s3.amazonaws.com/bucket/prefix/repo" + assert.Equal(t, "repo", c.RepoName()) + + c.RepoIdentifier = "azure:bucket:/repo" + assert.Equal(t, "repo", c.RepoName()) + + c.RepoIdentifier = "gs:bucket:/prefix/repo" + assert.Equal(t, "repo", c.RepoName()) +} + +func TestStringSlice(t *testing.T) { + c := &Command{ + Command: "cmd", + RepoIdentifier: "repo-id", + PasswordFile: "/path/to/password-file", + Dir: "/some/pwd", + Args: []string{"arg-1", "arg-2"}, + ExtraFlags: []string{"--foo=bar"}, + } + + require.NoError(t, os.Unsetenv("ARK_SCRATCH_DIR")) + assert.Equal(t, []string{ + "restic", + "cmd", + "--repo=repo-id", + "--password-file=/path/to/password-file", + "arg-1", + "arg-2", + "--foo=bar", + }, c.StringSlice()) + + os.Setenv("ARK_SCRATCH_DIR", "/foo") + assert.Equal(t, []string{ + "restic", + "cmd", + "--repo=repo-id", + "--password-file=/path/to/password-file", + "--cache-dir=/foo/.cache/restic", + "arg-1", + "arg-2", + "--foo=bar", + }, c.StringSlice()) + + require.NoError(t, os.Unsetenv("ARK_SCRATCH_DIR")) +} + +func TestString(t *testing.T) { + c := &Command{ + Command: "cmd", + RepoIdentifier: "repo-id", + PasswordFile: "/path/to/password-file", + Dir: "/some/pwd", + Args: []string{"arg-1", "arg-2"}, + ExtraFlags: []string{"--foo=bar"}, + } + + require.NoError(t, os.Unsetenv("ARK_SCRATCH_DIR")) + assert.Equal(t, "restic cmd --repo=repo-id --password-file=/path/to/password-file arg-1 arg-2 --foo=bar", c.String()) +} + +func TestCmd(t *testing.T) { + c := &Command{ + Command: "cmd", + RepoIdentifier: "repo-id", + PasswordFile: "/path/to/password-file", + Dir: "/some/pwd", + Args: []string{"arg-1", "arg-2"}, + ExtraFlags: []string{"--foo=bar"}, + } + + require.NoError(t, os.Unsetenv("ARK_SCRATCH_DIR")) + execCmd := c.Cmd() + + assert.Equal(t, c.StringSlice(), execCmd.Args) + assert.Equal(t, c.Dir, execCmd.Dir) +} diff --git a/pkg/restic/common.go b/pkg/restic/common.go index 08b606e42..4c2932c2c 100644 --- a/pkg/restic/common.go +++ b/pkg/restic/common.go @@ -18,7 +18,6 @@ package restic import ( "fmt" - "io/ioutil" "strings" "time" @@ -30,6 +29,7 @@ import ( arkv1api "github.com/heptio/ark/pkg/apis/ark/v1" arkv1listers "github.com/heptio/ark/pkg/generated/listers/ark/v1" + "github.com/heptio/ark/pkg/util/filesystem" ) const ( @@ -144,7 +144,7 @@ func GetSnapshotsInBackup(backup *arkv1api.Backup, podVolumeBackupLister arkv1li // encryption key for the given repo and returns its path. The // caller should generally call os.Remove() to remove the file // when done with it. -func TempCredentialsFile(secretLister corev1listers.SecretLister, arkNamespace, repoName string) (string, error) { +func TempCredentialsFile(secretLister corev1listers.SecretLister, arkNamespace, repoName string, fs filesystem.Interface) (string, error) { secretGetter := NewListerSecretGetter(secretLister) // For now, all restic repos share the same key so we don't need the repoName to fetch it. @@ -156,7 +156,7 @@ func TempCredentialsFile(secretLister corev1listers.SecretLister, arkNamespace, return "", err } - file, err := ioutil.TempFile("", fmt.Sprintf("%s-%s", CredentialsSecretName, repoName)) + file, err := fs.TempFile("", fmt.Sprintf("%s-%s", CredentialsSecretName, repoName)) if err != nil { return "", errors.WithStack(err) } diff --git a/pkg/restic/common_test.go b/pkg/restic/common_test.go new file mode 100644 index 000000000..f5ea1d166 --- /dev/null +++ b/pkg/restic/common_test.go @@ -0,0 +1,366 @@ +/* +Copyright 2018 the Heptio Ark contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package restic + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/stretchr/testify/assert" + + corev1api "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + corev1listers "k8s.io/client-go/listers/core/v1" + "k8s.io/client-go/tools/cache" + + arkv1api "github.com/heptio/ark/pkg/apis/ark/v1" + "github.com/heptio/ark/pkg/generated/clientset/versioned/fake" + informers "github.com/heptio/ark/pkg/generated/informers/externalversions" + arktest "github.com/heptio/ark/pkg/util/test" +) + +func TestPodHasSnapshotAnnotation(t *testing.T) { + tests := []struct { + name string + annotations map[string]string + expected bool + }{ + { + name: "nil annotations", + annotations: nil, + expected: false, + }, + { + name: "empty annotations", + annotations: make(map[string]string), + expected: false, + }, + { + name: "non-empty map, no snapshot annotation", + annotations: map[string]string{"foo": "bar"}, + expected: false, + }, + { + name: "has snapshot annotation only, no suffix", + annotations: map[string]string{podAnnotationPrefix: "bar"}, + expected: true, + }, + { + name: "has snapshot annotation only, with suffix", + annotations: map[string]string{podAnnotationPrefix + "foo": "bar"}, + expected: true, + }, + { + name: "has snapshot annotation, with suffix", + annotations: map[string]string{"foo": "bar", podAnnotationPrefix + "foo": "bar"}, + expected: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pod := &corev1api.Pod{} + pod.Annotations = test.annotations + assert.Equal(t, test.expected, PodHasSnapshotAnnotation(pod)) + }) + } +} + +func TestGetPodSnapshotAnnotations(t *testing.T) { + tests := []struct { + name string + annotations map[string]string + expected map[string]string + }{ + { + name: "nil annotations", + annotations: nil, + expected: nil, + }, + { + name: "empty annotations", + annotations: make(map[string]string), + expected: nil, + }, + { + name: "non-empty map, no snapshot annotation", + annotations: map[string]string{"foo": "bar"}, + expected: nil, + }, + { + name: "has snapshot annotation only, no suffix", + annotations: map[string]string{podAnnotationPrefix: "bar"}, + expected: map[string]string{"": "bar"}, + }, + { + name: "has snapshot annotation only, with suffix", + annotations: map[string]string{podAnnotationPrefix + "foo": "bar"}, + expected: map[string]string{"foo": "bar"}, + }, + { + name: "has snapshot annotation, with suffix", + annotations: map[string]string{"x": "y", podAnnotationPrefix + "foo": "bar", podAnnotationPrefix + "abc": "123"}, + expected: map[string]string{"foo": "bar", "abc": "123"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pod := &corev1api.Pod{} + pod.Annotations = test.annotations + assert.Equal(t, test.expected, GetPodSnapshotAnnotations(pod)) + }) + } +} + +func TestSetPodSnapshotAnnotation(t *testing.T) { + tests := []struct { + name string + annotations map[string]string + volumeName string + snapshotID string + expected map[string]string + }{ + { + name: "set snapshot annotation on pod with no annotations", + annotations: nil, + volumeName: "foo", + snapshotID: "bar", + expected: map[string]string{podAnnotationPrefix + "foo": "bar"}, + }, + { + name: "set snapshot annotation on pod with existing annotations", + annotations: map[string]string{"existing": "annotation"}, + volumeName: "foo", + snapshotID: "bar", + expected: map[string]string{"existing": "annotation", podAnnotationPrefix + "foo": "bar"}, + }, + { + name: "snapshot annotation is overwritten if already exists", + annotations: map[string]string{podAnnotationPrefix + "foo": "existing"}, + volumeName: "foo", + snapshotID: "bar", + expected: map[string]string{podAnnotationPrefix + "foo": "bar"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pod := &corev1api.Pod{} + pod.Annotations = test.annotations + + SetPodSnapshotAnnotation(pod, test.volumeName, test.snapshotID) + assert.Equal(t, test.expected, pod.Annotations) + }) + } +} + +func TestGetVolumesToBackup(t *testing.T) { + tests := []struct { + name string + annotations map[string]string + expected []string + }{ + { + name: "nil annotations", + annotations: nil, + expected: nil, + }, + { + name: "no volumes to backup", + annotations: map[string]string{"foo": "bar"}, + expected: nil, + }, + { + name: "one volume to backup", + annotations: map[string]string{"foo": "bar", volumesToBackupAnnotation: "volume-1"}, + expected: []string{"volume-1"}, + }, + { + name: "multiple volumes to backup", + annotations: map[string]string{"foo": "bar", volumesToBackupAnnotation: "volume-1,volume-2,volume-3"}, + expected: []string{"volume-1", "volume-2", "volume-3"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pod := &corev1api.Pod{} + pod.Annotations = test.annotations + + res := GetVolumesToBackup(pod) + + // sort to ensure good compare of slices + sort.Strings(test.expected) + sort.Strings(res) + + assert.Equal(t, test.expected, res) + }) + } +} + +func TestGetSnapshotsInBackup(t *testing.T) { + tests := []struct { + name string + podVolumeBackups []arkv1api.PodVolumeBackup + expected []SnapshotIdentifier + }{ + { + name: "no pod volume backups", + podVolumeBackups: nil, + expected: nil, + }, + { + name: "no pod volume backups with matching label", + podVolumeBackups: []arkv1api.PodVolumeBackup{ + { + ObjectMeta: metav1.ObjectMeta{Name: "foo", Labels: map[string]string{arkv1api.BackupNameLabel: "non-matching-backup-1"}}, + Spec: arkv1api.PodVolumeBackupSpec{ + Pod: corev1api.ObjectReference{Name: "pod-1", Namespace: "ns-1"}, + }, + Status: arkv1api.PodVolumeBackupStatus{SnapshotID: "snap-1"}, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "bar", Labels: map[string]string{arkv1api.BackupNameLabel: "non-matching-backup-2"}}, + Spec: arkv1api.PodVolumeBackupSpec{ + Pod: corev1api.ObjectReference{Name: "pod-2", Namespace: "ns-2"}, + }, + Status: arkv1api.PodVolumeBackupStatus{SnapshotID: "snap-2"}, + }, + }, + expected: nil, + }, + { + name: "some pod volume backups with matching label", + podVolumeBackups: []arkv1api.PodVolumeBackup{ + { + ObjectMeta: metav1.ObjectMeta{Name: "foo", Labels: map[string]string{arkv1api.BackupNameLabel: "non-matching-backup-1"}}, + Spec: arkv1api.PodVolumeBackupSpec{ + Pod: corev1api.ObjectReference{Name: "pod-1", Namespace: "ns-1"}, + }, + Status: arkv1api.PodVolumeBackupStatus{SnapshotID: "snap-1"}, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "bar", Labels: map[string]string{arkv1api.BackupNameLabel: "non-matching-backup-2"}}, + Spec: arkv1api.PodVolumeBackupSpec{ + Pod: corev1api.ObjectReference{Name: "pod-2", Namespace: "ns-2"}, + }, + Status: arkv1api.PodVolumeBackupStatus{SnapshotID: "snap-2"}, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "completed-pvb", Labels: map[string]string{arkv1api.BackupNameLabel: "backup-1"}}, + Spec: arkv1api.PodVolumeBackupSpec{ + Pod: corev1api.ObjectReference{Name: "pod-1", Namespace: "ns-1"}, + }, + Status: arkv1api.PodVolumeBackupStatus{SnapshotID: "snap-3"}, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "completed-pvb-2", Labels: map[string]string{arkv1api.BackupNameLabel: "backup-1"}}, + Spec: arkv1api.PodVolumeBackupSpec{ + Pod: corev1api.ObjectReference{Name: "pod-1", Namespace: "ns-1"}, + }, + Status: arkv1api.PodVolumeBackupStatus{SnapshotID: "snap-4"}, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "incomplete-or-failed-pvb", Labels: map[string]string{arkv1api.BackupNameLabel: "backup-1"}}, + Spec: arkv1api.PodVolumeBackupSpec{ + Pod: corev1api.ObjectReference{Name: "pod-1", Namespace: "ns-2"}, + }, + Status: arkv1api.PodVolumeBackupStatus{SnapshotID: ""}, + }, + }, + expected: []SnapshotIdentifier{ + { + Repo: "ns-1", + SnapshotID: "snap-3", + }, + { + Repo: "ns-1", + SnapshotID: "snap-4", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var ( + client = fake.NewSimpleClientset() + sharedInformers = informers.NewSharedInformerFactory(client, 0) + pvbInformer = sharedInformers.Ark().V1().PodVolumeBackups() + arkBackup = &arkv1api.Backup{} + ) + + arkBackup.Name = "backup-1" + + for _, pvb := range test.podVolumeBackups { + require.NoError(t, pvbInformer.Informer().GetStore().Add(pvb.DeepCopy())) + } + + res, err := GetSnapshotsInBackup(arkBackup, pvbInformer.Lister()) + assert.NoError(t, err) + + // sort to ensure good compare of slices + less := func(snapshots []SnapshotIdentifier) func(i, j int) bool { + return func(i, j int) bool { + return snapshots[i].Repo < snapshots[j].Repo && + snapshots[i].SnapshotID < snapshots[j].SnapshotID + } + + } + sort.Slice(test.expected, less(test.expected)) + sort.Slice(res, less(res)) + + assert.Equal(t, test.expected, res) + }) + } +} + +func TestTempCredentialsFile(t *testing.T) { + var ( + secretInformer = cache.NewSharedIndexInformer(nil, new(corev1api.Secret), 0, cache.Indexers{cache.NamespaceIndex: cache.MetaNamespaceIndexFunc}) + secretLister = corev1listers.NewSecretLister(secretInformer.GetIndexer()) + fs = arktest.NewFakeFileSystem() + secret = &corev1api.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "heptio-ark", + Name: CredentialsSecretName, + }, + Data: map[string][]byte{ + CredentialsKey: []byte("passw0rd"), + }, + } + ) + + // secret not in lister: expect an error + fileName, err := TempCredentialsFile(secretLister, "heptio-ark", "default", fs) + assert.Error(t, err) + + // now add secret to lister + require.NoError(t, secretInformer.GetStore().Add(secret)) + + // secret in lister: expect temp file to be created with password + fileName, err = TempCredentialsFile(secretLister, "heptio-ark", "default", fs) + require.NoError(t, err) + + contents, err := fs.ReadFile(fileName) + require.NoError(t, err) + + assert.Equal(t, "passw0rd", string(contents)) +} diff --git a/pkg/restic/config.go b/pkg/restic/config.go index 0e8d9fccf..7a9dc9f11 100644 --- a/pkg/restic/config.go +++ b/pkg/restic/config.go @@ -32,6 +32,10 @@ const ( GCPBackend BackendType = "gcp" ) +// this func is assigned to a package-level variable so it can be +// replaced when unit-testing +var getAWSBucketRegion = aws.GetBucketRegion + // getRepoPrefix returns the prefix of the value of the --repo flag for // restic commands, i.e. everything except the "/". func getRepoPrefix(config arkv1api.ObjectStorageProviderConfig) string { @@ -55,7 +59,7 @@ func getRepoPrefix(config arkv1api.ObjectStorageProviderConfig) string { case config.Config["s3Url"] != "": url = config.Config["s3Url"] default: - region, err := aws.GetBucketRegion(bucket) + region, err := getAWSBucketRegion(bucket) if err != nil { url = "s3.amazonaws.com" break diff --git a/pkg/restic/config_test.go b/pkg/restic/config_test.go new file mode 100644 index 000000000..9130b717c --- /dev/null +++ b/pkg/restic/config_test.go @@ -0,0 +1,76 @@ +/* +Copyright 2018 the Heptio Ark contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package restic + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + + arkv1api "github.com/heptio/ark/pkg/apis/ark/v1" +) + +func TestGetRepoIdentifier(t *testing.T) { + // if getAWSBucketRegion returns an error, use default "s3.amazonaws.com/..." URL + getAWSBucketRegion = func(string) (string, error) { + return "", errors.New("no region found") + } + config := arkv1api.ObjectStorageProviderConfig{ + CloudProviderConfig: arkv1api.CloudProviderConfig{Name: "aws"}, + ResticLocation: "bucket/prefix", + } + assert.Equal(t, "s3:s3.amazonaws.com/bucket/prefix/repo-1", GetRepoIdentifier(config, "repo-1")) + + // stub implementation of getAWSBucketRegion + getAWSBucketRegion = func(string) (string, error) { + return "us-west-2", nil + } + + config = arkv1api.ObjectStorageProviderConfig{ + CloudProviderConfig: arkv1api.CloudProviderConfig{Name: "aws"}, + ResticLocation: "bucket", + } + assert.Equal(t, "s3:s3-us-west-2.amazonaws.com/bucket/repo-1", GetRepoIdentifier(config, "repo-1")) + + config = arkv1api.ObjectStorageProviderConfig{ + CloudProviderConfig: arkv1api.CloudProviderConfig{Name: "aws"}, + ResticLocation: "bucket/prefix", + } + assert.Equal(t, "s3:s3-us-west-2.amazonaws.com/bucket/prefix/repo-1", GetRepoIdentifier(config, "repo-1")) + + config = arkv1api.ObjectStorageProviderConfig{ + CloudProviderConfig: arkv1api.CloudProviderConfig{ + Name: "aws", + Config: map[string]string{"s3Url": "alternate-url"}, + }, + ResticLocation: "bucket/prefix", + } + assert.Equal(t, "s3:alternate-url/bucket/prefix/repo-1", GetRepoIdentifier(config, "repo-1")) + + config = arkv1api.ObjectStorageProviderConfig{ + CloudProviderConfig: arkv1api.CloudProviderConfig{Name: "azure"}, + ResticLocation: "bucket/prefix", + } + assert.Equal(t, "azure:bucket:/prefix/repo-1", GetRepoIdentifier(config, "repo-1")) + + config = arkv1api.ObjectStorageProviderConfig{ + CloudProviderConfig: arkv1api.CloudProviderConfig{Name: "gcp"}, + ResticLocation: "bucket-2/prefix-2", + } + assert.Equal(t, "gs:bucket-2:/prefix-2/repo-2", GetRepoIdentifier(config, "repo-2")) +} diff --git a/pkg/restic/repository_manager.go b/pkg/restic/repository_manager.go index fefa97be5..718e5c63a 100644 --- a/pkg/restic/repository_manager.go +++ b/pkg/restic/repository_manager.go @@ -34,6 +34,7 @@ import ( arkv1informers "github.com/heptio/ark/pkg/generated/informers/externalversions/ark/v1" arkv1listers "github.com/heptio/ark/pkg/generated/listers/ark/v1" arkexec "github.com/heptio/ark/pkg/util/exec" + "github.com/heptio/ark/pkg/util/filesystem" ) // RepositoryManager executes commands against restic repositories. @@ -79,6 +80,7 @@ type repositoryManager struct { log logrus.FieldLogger repoLocker *repoLocker repoEnsurer *repositoryEnsurer + fileSystem filesystem.Interface } // NewRepositoryManager constructs a RepositoryManager. @@ -101,6 +103,7 @@ func NewRepositoryManager( repoLocker: newRepoLocker(), repoEnsurer: newRepositoryEnsurer(repoInformer, repoClient, log), + fileSystem: filesystem.NewFileSystem(), } if !cache.WaitForCacheSync(ctx.Done(), secretsInformer.HasSynced) { @@ -198,7 +201,7 @@ func (rm *repositoryManager) Forget(ctx context.Context, snapshot SnapshotIdenti } func (rm *repositoryManager) exec(cmd *Command) error { - file, err := TempCredentialsFile(rm.secretsLister, rm.namespace, cmd.RepoName()) + file, err := TempCredentialsFile(rm.secretsLister, rm.namespace, cmd.RepoName(), rm.fileSystem) if err != nil { return err } diff --git a/pkg/util/filesystem/file_system.go b/pkg/util/filesystem/file_system.go index b1b9d857f..d8fbc472b 100644 --- a/pkg/util/filesystem/file_system.go +++ b/pkg/util/filesystem/file_system.go @@ -32,6 +32,13 @@ type Interface interface { ReadDir(dirname string) ([]os.FileInfo, error) ReadFile(filename string) ([]byte, error) DirExists(path string) (bool, error) + TempFile(dir, prefix string) (NameWriteCloser, error) +} + +type NameWriteCloser interface { + io.WriteCloser + + Name() string } func NewFileSystem() Interface { @@ -74,3 +81,7 @@ func (fs *osFileSystem) DirExists(path string) (bool, error) { } return false, err } + +func (fs *osFileSystem) TempFile(dir, prefix string) (NameWriteCloser, error) { + return ioutil.TempFile(dir, prefix) +} diff --git a/pkg/util/test/fake_file_system.go b/pkg/util/test/fake_file_system.go index 0603bba30..9717ef65b 100644 --- a/pkg/util/test/fake_file_system.go +++ b/pkg/util/test/fake_file_system.go @@ -5,6 +5,8 @@ import ( "os" "github.com/spf13/afero" + + "github.com/heptio/ark/pkg/util/filesystem" ) type FakeFileSystem struct { @@ -68,3 +70,7 @@ func (fs *FakeFileSystem) WithDirectories(path ...string) *FakeFileSystem { return fs } + +func (fs *FakeFileSystem) TempFile(dir, prefix string) (filesystem.NameWriteCloser, error) { + return afero.TempFile(fs.fs, dir, prefix) +}