diff --git a/changelogs/unreleased/7175-blackpiglet b/changelogs/unreleased/7175-blackpiglet new file mode 100644 index 000000000..4ca4de1e3 --- /dev/null +++ b/changelogs/unreleased/7175-blackpiglet @@ -0,0 +1 @@ +Refactor DownloadRequest Stream function \ No newline at end of file diff --git a/pkg/cmd/util/downloadrequest/downloadrequest.go b/pkg/cmd/util/downloadrequest/downloadrequest.go index 12df30f32..7a85920fa 100644 --- a/pkg/cmd/util/downloadrequest/downloadrequest.go +++ b/pkg/cmd/util/downloadrequest/downloadrequest.go @@ -30,10 +30,9 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" - "k8s.io/apimachinery/pkg/util/wait" kbclient "sigs.k8s.io/controller-runtime/pkg/client" - velerov1api "github.com/vmware-tanzu/velero/pkg/apis/velero/v1" + veleroV1api "github.com/vmware-tanzu/velero/pkg/apis/velero/v1" "github.com/vmware-tanzu/velero/pkg/builder" ) @@ -42,53 +41,75 @@ import ( var ErrNotFound = errors.New("file not found") var ErrDownloadRequestDownloadURLTimeout = errors.New("download request download url timeout, check velero server logs for errors. backup storage location may not be available") -func Stream(ctx context.Context, kbClient kbclient.Client, namespace, name string, kind velerov1api.DownloadTargetKind, w io.Writer, timeout time.Duration, insecureSkipTLSVerify bool, caCertFile string) error { +func Stream( + ctx context.Context, + kbClient kbclient.Client, + namespace, name string, + kind veleroV1api.DownloadTargetKind, + w io.Writer, + timeout time.Duration, + insecureSkipTLSVerify bool, + caCertFile string, +) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + downloadURL, err := getDownloadURL(ctx, kbClient, namespace, name, kind) + if err != nil { + return err + } + + if err := download(ctx, downloadURL, kind, w, insecureSkipTLSVerify, caCertFile); err != nil { + return err + } + + return nil +} + +func getDownloadURL( + ctx context.Context, + kbClient kbclient.Client, + namespace, name string, + kind veleroV1api.DownloadTargetKind, +) (string, error) { uuid, err := uuid.NewRandom() if err != nil { - return errors.WithStack(err) + return "", err } reqName := fmt.Sprintf("%s-%s", name, uuid.String()) created := builder.ForDownloadRequest(namespace, reqName).Target(kind, name).Result() - if err := kbClient.Create(context.Background(), created, &kbclient.CreateOptions{}); err != nil { - return errors.WithStack(err) + if err := kbClient.Create(ctx, created, &kbclient.CreateOptions{}); err != nil { + return "", errors.WithStack(err) } - ctx, cancel := context.WithCancel(ctx) - defer cancel() + for { + select { + case <-ctx.Done(): + return "", ErrDownloadRequestDownloadURLTimeout - key := kbclient.ObjectKey{Name: created.Name, Namespace: namespace} - timeStreamFirstCheck := time.Now() - downloadURLTimeout := false - checkFunc := func() { - // if timeout has been reached, cancel request - if time.Now().After(timeStreamFirstCheck.Add(timeout)) { - downloadURLTimeout = true - cancel() - } - updated := &velerov1api.DownloadRequest{} - if err := kbClient.Get(ctx, key, updated); err != nil { - return - } + case <-time.After(25 * time.Millisecond): + updated := &veleroV1api.DownloadRequest{} + if err := kbClient.Get(ctx, kbclient.ObjectKey{Name: created.Name, Namespace: namespace}, updated); err != nil { + return "", errors.WithStack(err) + } - // TODO: once the minimum supported Kubernetes version is v1.9.0, remove the following check. - // See http://issue.k8s.io/51046 for details. - if updated.Name != created.Name { - return - } - - if updated.Status.DownloadURL != "" { - created = updated - cancel() + if updated.Status.DownloadURL != "" { + return updated.Status.DownloadURL, nil + } } } +} - wait.Until(checkFunc, 25*time.Millisecond, ctx.Done()) - if downloadURLTimeout { - return ErrDownloadRequestDownloadURLTimeout - } - +func download( + ctx context.Context, + downloadURL string, + kind veleroV1api.DownloadTargetKind, + w io.Writer, + insecureSkipTLSVerify bool, + caCertFile string, +) error { var caPool *x509.CertPool if len(caCertFile) > 0 { caCert, err := os.ReadFile(caCertFile) @@ -107,14 +128,13 @@ func Stream(ctx context.Context, kbClient kbclient.Client, namespace, name strin defaultTransport := http.DefaultTransport.(*http.Transport) // same settings as the default transport - // aside from timeout and TLSClientConfig + // aside from TLSClientConfig httpClient := new(http.Client) httpClient.Transport = &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: insecureSkipTLSVerify, //nolint:gosec // This parameter is useful for some scenarios. RootCAs: caPool, }, - IdleConnTimeout: timeout, DialContext: defaultTransport.DialContext, ForceAttemptHTTP2: defaultTransport.ForceAttemptHTTP2, MaxIdleConns: defaultTransport.MaxIdleConns, @@ -123,7 +143,7 @@ func Stream(ctx context.Context, kbClient kbclient.Client, namespace, name strin ExpectContinueTimeout: defaultTransport.ExpectContinueTimeout, } - httpReq, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, created.Status.DownloadURL, nil) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil) if err != nil { return err } @@ -153,7 +173,7 @@ func Stream(ctx context.Context, kbClient kbclient.Client, namespace, name strin } reader := resp.Body - if kind != velerov1api.DownloadTargetKindBackupContents { + if kind != veleroV1api.DownloadTargetKindBackupContents { // need to decompress logs gzipReader, err := gzip.NewReader(resp.Body) if err != nil {