diff --git a/pkg/cloudprovider/azure/block_store.go b/pkg/cloudprovider/azure/block_store.go index b99ff1301..6e818a321 100644 --- a/pkg/cloudprovider/azure/block_store.go +++ b/pkg/cloudprovider/azure/block_store.go @@ -21,6 +21,7 @@ import ( "fmt" "os" "regexp" + "strings" "time" "github.com/Azure/azure-sdk-for-go/arm/disk" @@ -129,7 +130,7 @@ func (b *blockStore) CreateVolumeFromSnapshot(snapshotID, volumeType, volumeAZ s return "", err } - // Lookup snapshot info for its Location + // Lookup snapshot info for its Location & Tags so we can apply them to the volume snapshotInfo, err := b.snaps.Get(snapshotIdentifier.resourceGroup, snapshotIdentifier.name) if err != nil { return "", errors.WithStack(err) @@ -147,6 +148,7 @@ func (b *blockStore) CreateVolumeFromSnapshot(snapshotID, volumeType, volumeAZ s }, AccountType: disk.StorageAccountTypes(volumeType), }, + Tags: snapshotInfo.Tags, } ctx, cancel := context.WithTimeout(context.Background(), b.apiTimeout) @@ -210,15 +212,10 @@ func (b *blockStore) CreateSnapshot(volumeID, volumeAZ string, tags map[string]s SourceResourceID: &fullDiskName, }, }, - Tags: &map[string]*string{}, + Tags: getSnapshotTags(tags, diskInfo.Tags), Location: diskInfo.Location, } - for k, v := range tags { - val := v - (*snap.Tags)[k] = &val - } - ctx, cancel := context.WithTimeout(context.Background(), b.apiTimeout) defer cancel() @@ -232,6 +229,37 @@ func (b *blockStore) CreateSnapshot(volumeID, volumeAZ string, tags map[string]s return getComputeResourceName(b.subscription, b.resourceGroup, snapshotsResource, snapshotName), nil } +func getSnapshotTags(arkTags map[string]string, diskTags *map[string]*string) *map[string]*string { + if diskTags == nil && len(arkTags) == 0 { + return nil + } + + snapshotTags := make(map[string]*string) + + // copy tags from disk to snapshot + if diskTags != nil { + for k, v := range *diskTags { + snapshotTags[k] = stringPtr(*v) + } + } + + // merge Ark-assigned tags with the disk's tags (note that we want current + // Ark-assigned tags to overwrite any older versions of them that may exist + // due to prior snapshots/restores) + for k, v := range arkTags { + // Azure does not allow slashes in tag keys, so replace + // with dash (inline with what Kubernetes does) + key := strings.Replace(k, "/", "-", -1) + snapshotTags[key] = stringPtr(v) + } + + return &snapshotTags +} + +func stringPtr(s string) *string { + return &s +} + func (b *blockStore) DeleteSnapshot(snapshotID string) error { ctx, cancel := context.WithTimeout(context.Background(), b.apiTimeout) defer cancel() diff --git a/pkg/cloudprovider/azure/block_store_test.go b/pkg/cloudprovider/azure/block_store_test.go index 785c99b5e..9cb038747 100644 --- a/pkg/cloudprovider/azure/block_store_test.go +++ b/pkg/cloudprovider/azure/block_store_test.go @@ -111,3 +111,96 @@ func TestGetComputeResourceName(t *testing.T) { assert.Equal(t, "/subscriptions/sub-1/resourceGroups/rg-1/providers/Microsoft.Compute/snapshots/snap-1", getComputeResourceName("sub-1", "rg-1", snapshotsResource, "snap-1")) } + +func TestGetSnapshotTags(t *testing.T) { + tests := []struct { + name string + arkTags map[string]string + diskTags *map[string]*string + expected *map[string]*string + }{ + { + name: "degenerate case (no tags)", + arkTags: nil, + diskTags: nil, + expected: nil, + }, + { + name: "ark tags only get applied", + arkTags: map[string]string{ + "ark-key1": "ark-val1", + "ark-key2": "ark-val2", + }, + diskTags: nil, + expected: &map[string]*string{ + "ark-key1": stringPtr("ark-val1"), + "ark-key2": stringPtr("ark-val2"), + }, + }, + { + name: "slashes in ark tag keys get replaces with dashes", + arkTags: map[string]string{ + "ark/key1": "ark-val1", + "ark/key/2": "ark-val2", + }, + diskTags: nil, + expected: &map[string]*string{ + "ark-key1": stringPtr("ark-val1"), + "ark-key-2": stringPtr("ark-val2"), + }, + }, + { + name: "volume tags only get applied", + arkTags: nil, + diskTags: &map[string]*string{ + "azure-key1": stringPtr("azure-val1"), + "azure-key2": stringPtr("azure-val2"), + }, + expected: &map[string]*string{ + "azure-key1": stringPtr("azure-val1"), + "azure-key2": stringPtr("azure-val2"), + }, + }, + { + name: "non-overlapping ark and volume tags both get applied", + arkTags: map[string]string{"ark-key": "ark-val"}, + diskTags: &map[string]*string{"azure-key": stringPtr("azure-val")}, + expected: &map[string]*string{ + "ark-key": stringPtr("ark-val"), + "azure-key": stringPtr("azure-val"), + }, + }, + { + name: "when tags overlap, ark tags take precedence", + arkTags: map[string]string{ + "ark-key": "ark-val", + "overlapping-key": "ark-val", + }, + diskTags: &map[string]*string{ + "azure-key": stringPtr("azure-val"), + "overlapping-key": stringPtr("azure-val"), + }, + expected: &map[string]*string{ + "ark-key": stringPtr("ark-val"), + "azure-key": stringPtr("azure-val"), + "overlapping-key": stringPtr("ark-val"), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + res := getSnapshotTags(test.arkTags, test.diskTags) + + if test.expected == nil { + assert.Nil(t, res) + return + } + + assert.Equal(t, len(*test.expected), len(*res)) + for k, v := range *test.expected { + assert.Equal(t, v, (*res)[k]) + } + }) + } +}