diff --git a/pkg/cloudprovider/azure/block_store.go b/pkg/cloudprovider/azure/block_store.go index 591fe3cee..b99ff1301 100644 --- a/pkg/cloudprovider/azure/block_store.go +++ b/pkg/cloudprovider/azure/block_store.go @@ -20,7 +20,7 @@ import ( "context" "fmt" "os" - "strings" + "regexp" "time" "github.com/Azure/azure-sdk-for-go/arm/disk" @@ -29,6 +29,7 @@ import ( "github.com/Azure/go-autorest/autorest/azure" "github.com/pkg/errors" "github.com/satori/uuid" + "k8s.io/apimachinery/pkg/runtime" "github.com/heptio/ark/pkg/cloudprovider" @@ -36,15 +37,16 @@ import ( ) const ( - azureClientIDKey string = "AZURE_CLIENT_ID" - azureClientSecretKey string = "AZURE_CLIENT_SECRET" - azureSubscriptionIDKey string = "AZURE_SUBSCRIPTION_ID" - azureTenantIDKey string = "AZURE_TENANT_ID" - azureStorageAccountIDKey string = "AZURE_STORAGE_ACCOUNT_ID" - azureStorageKeyKey string = "AZURE_STORAGE_KEY" - azureResourceGroupKey string = "AZURE_RESOURCE_GROUP" - - apiTimeoutKey = "apiTimeout" + azureClientIDKey = "AZURE_CLIENT_ID" + azureClientSecretKey = "AZURE_CLIENT_SECRET" + azureSubscriptionIDKey = "AZURE_SUBSCRIPTION_ID" + azureTenantIDKey = "AZURE_TENANT_ID" + azureStorageAccountIDKey = "AZURE_STORAGE_ACCOUNT_ID" + azureStorageKeyKey = "AZURE_STORAGE_KEY" + azureResourceGroupKey = "AZURE_RESOURCE_GROUP" + apiTimeoutKey = "apiTimeout" + snapshotsResource = "snapshots" + disksResource = "disks" ) type blockStore struct { @@ -55,6 +57,12 @@ type blockStore struct { apiTimeout time.Duration } +type snapshotIdentifier struct { + subscription string + resourceGroup string + name string +} + func getConfig() map[string]string { cfg := map[string]string{ azureClientIDKey: "", @@ -116,13 +124,17 @@ func (b *blockStore) Init(config map[string]string) error { } func (b *blockStore) CreateVolumeFromSnapshot(snapshotID, volumeType, volumeAZ string, iops *int64) (string, error) { + snapshotIdentifier, err := parseFullSnapshotName(snapshotID) + if err != nil { + return "", err + } + // Lookup snapshot info for its Location - snapshotInfo, err := b.snaps.Get(b.resourceGroup, snapshotID) + snapshotInfo, err := b.snaps.Get(snapshotIdentifier.resourceGroup, snapshotIdentifier.name) if err != nil { return "", errors.WithStack(err) } - fullSnapshotName := getFullSnapshotName(b.subscription, b.resourceGroup, snapshotID) diskName := "restore-" + uuid.NewV4().String() disk := disk.Model{ @@ -131,7 +143,7 @@ func (b *blockStore) CreateVolumeFromSnapshot(snapshotID, volumeType, volumeAZ s Properties: &disk.Properties{ CreationData: &disk.CreationData{ CreateOption: disk.Copy, - SourceResourceID: &fullSnapshotName, + SourceResourceID: &snapshotID, }, AccountType: disk.StorageAccountTypes(volumeType), }, @@ -179,7 +191,7 @@ func (b *blockStore) CreateSnapshot(volumeID, volumeAZ string, tags map[string]s return "", errors.WithStack(err) } - fullDiskName := getFullDiskName(b.subscription, b.resourceGroup, volumeID) + fullDiskName := getComputeResourceName(b.subscription, b.resourceGroup, disksResource, volumeID) // snapshot names must be <= 80 characters long var snapshotName string suffix := "-" + uuid.NewV4().String() @@ -211,14 +223,13 @@ func (b *blockStore) CreateSnapshot(volumeID, volumeAZ string, tags map[string]s defer cancel() _, errChan := b.snaps.CreateOrUpdate(b.resourceGroup, *snap.Name, snap, ctx.Done()) - err = <-errChan if err != nil { return "", errors.WithStack(err) } - return snapshotName, nil + return getComputeResourceName(b.subscription, b.resourceGroup, snapshotsResource, snapshotName), nil } func (b *blockStore) DeleteSnapshot(snapshotID string) error { @@ -232,12 +243,39 @@ func (b *blockStore) DeleteSnapshot(snapshotID string) error { return errors.WithStack(err) } -func getFullDiskName(subscription string, resourceGroup string, diskName string) string { - return fmt.Sprintf("/subscriptions/%v/resourceGroups/%v/providers/Microsoft.Compute/disks/%v", subscription, resourceGroup, diskName) +func getComputeResourceName(subscription, resourceGroup, resource, name string) string { + return fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/%s/%s", subscription, resourceGroup, resource, name) } -func getFullSnapshotName(subscription string, resourceGroup string, snapshotName string) string { - return fmt.Sprintf("/subscriptions/%v/resourceGroups/%v/providers/Microsoft.Compute/snapshots/%v", subscription, resourceGroup, snapshotName) +var snapshotURIRegexp = regexp.MustCompile( + `^\/subscriptions\/(?P.*)\/resourceGroups\/(?P.*)\/providers\/Microsoft.Compute\/snapshots\/(?P.*)$`) + +// parseFullSnapshotName takes a snapshot URI and returns a snapshot identifier +// or an error if the URI does not match the regexp. +func parseFullSnapshotName(name string) (*snapshotIdentifier, error) { + submatches := snapshotURIRegexp.FindStringSubmatch(name) + if len(submatches) != len(snapshotURIRegexp.SubexpNames()) { + return nil, errors.New("snapshot URI could not be parsed") + } + + snapshotID := &snapshotIdentifier{} + + // capture names start at index 1 to line up with the corresponding indexes + // of submatches (see godoc on SubexpNames()) + for i, names := 1, snapshotURIRegexp.SubexpNames(); i < len(names); i++ { + switch names[i] { + case "subscription": + snapshotID.subscription = submatches[i] + case "resourceGroup": + snapshotID.resourceGroup = submatches[i] + case "snapshotName": + snapshotID.name = submatches[i] + default: + return nil, errors.New("unexpected named capture from snapshot URI regex") + } + } + + return snapshotID, nil } func (b *blockStore) GetVolumeID(pv runtime.Unstructured) (string, error) { @@ -259,16 +297,8 @@ func (b *blockStore) SetVolumeID(pv runtime.Unstructured, volumeID string) (runt return nil, err } - if uri, err := collections.GetString(azure, "diskURI"); err == nil { - previousVolumeID, err := collections.GetString(azure, "diskName") - if err != nil { - return nil, err - } - - azure["diskURI"] = strings.Replace(uri, previousVolumeID, volumeID, -1) - } - azure["diskName"] = volumeID + azure["diskURI"] = getComputeResourceName(b.subscription, b.resourceGroup, disksResource, volumeID) return pv, nil } diff --git a/pkg/cloudprovider/azure/block_store_test.go b/pkg/cloudprovider/azure/block_store_test.go index 2c845d737..785c99b5e 100644 --- a/pkg/cloudprovider/azure/block_store_test.go +++ b/pkg/cloudprovider/azure/block_store_test.go @@ -53,7 +53,10 @@ func TestGetVolumeID(t *testing.T) { } func TestSetVolumeID(t *testing.T) { - b := &blockStore{} + b := &blockStore{ + resourceGroup: "rg", + subscription: "sub", + } pv := &unstructured.Unstructured{} @@ -71,7 +74,9 @@ func TestSetVolumeID(t *testing.T) { actual, err := collections.GetString(updatedPV.UnstructuredContent(), "spec.azureDisk.diskName") require.NoError(t, err) assert.Equal(t, "updated", actual) - assert.NotContains(t, azure, "diskURI") + actual, err = collections.GetString(updatedPV.UnstructuredContent(), "spec.azureDisk.diskURI") + require.NoError(t, err) + assert.Equal(t, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/disks/updated", actual) // with diskURI azure["diskURI"] = "/foo/bar/updated/blarg" @@ -82,5 +87,27 @@ func TestSetVolumeID(t *testing.T) { assert.Equal(t, "revised", actual) actual, err = collections.GetString(updatedPV.UnstructuredContent(), "spec.azureDisk.diskURI") require.NoError(t, err) - assert.Equal(t, "/foo/bar/revised/blarg", actual) + assert.Equal(t, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/disks/revised", actual) +} + +func TestParseFullSnapshotName(t *testing.T) { + // invalid name + fullName := "foo/bar" + _, err := parseFullSnapshotName(fullName) + assert.Error(t, err) + + // valid name + fullName = "/subscriptions/sub-1/resourceGroups/rg-1/providers/Microsoft.Compute/snapshots/snap-1" + snap, err := parseFullSnapshotName(fullName) + require.NoError(t, err) + + assert.Equal(t, "sub-1", snap.subscription) + assert.Equal(t, "rg-1", snap.resourceGroup) + assert.Equal(t, "snap-1", snap.name) +} + +func TestGetComputeResourceName(t *testing.T) { + assert.Equal(t, "/subscriptions/sub-1/resourceGroups/rg-1/providers/Microsoft.Compute/disks/disk-1", getComputeResourceName("sub-1", "rg-1", disksResource, "disk-1")) + + assert.Equal(t, "/subscriptions/sub-1/resourceGroups/rg-1/providers/Microsoft.Compute/snapshots/snap-1", getComputeResourceName("sub-1", "rg-1", snapshotsResource, "snap-1")) }