From 2a67c243fd9624052aa4f0d9f3db9aa0b9aae17a Mon Sep 17 00:00:00 2001 From: Felicitas Pojtinger Date: Fri, 17 Dec 2021 15:30:09 +0100 Subject: [PATCH] feat: Add `io.Writer`-based API to `restore` and `fetch` cmds --- cmd/stbak/cmd/operation_restore.go | 20 ++++++++++++++++++++ cmd/stbak/cmd/recovery_fetch.go | 20 ++++++++++++++++++++ pkg/operations/restore.go | 14 +++++++++++++- pkg/recovery/fetch.go | 13 ++++++------- 4 files changed, 59 insertions(+), 8 deletions(-) diff --git a/cmd/stbak/cmd/operation_restore.go b/cmd/stbak/cmd/operation_restore.go index dd3d9ec..937f7ee 100644 --- a/cmd/stbak/cmd/operation_restore.go +++ b/cmd/stbak/cmd/operation_restore.go @@ -1,6 +1,10 @@ package cmd import ( + "io" + "io/fs" + "os" + "github.com/pojntfx/stfs/internal/keys" "github.com/pojntfx/stfs/internal/logging" "github.com/pojntfx/stfs/internal/persisters" @@ -93,6 +97,22 @@ var operationRestoreCmd = &cobra.Command{ ) return ops.Restore( + func(path string, mode fs.FileMode) (io.WriteCloser, error) { + dstFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, mode) + if err != nil { + return nil, err + } + + if err := dstFile.Truncate(0); err != nil { + return nil, err + } + + return dstFile, nil + }, + func(path string, mode fs.FileMode) error { + return os.MkdirAll(path, mode) + }, + viper.GetString(fromFlag), viper.GetString(toFlag), viper.GetBool(flattenFlag), diff --git a/cmd/stbak/cmd/recovery_fetch.go b/cmd/stbak/cmd/recovery_fetch.go index 0d47261..5e0c785 100644 --- a/cmd/stbak/cmd/recovery_fetch.go +++ b/cmd/stbak/cmd/recovery_fetch.go @@ -1,6 +1,10 @@ package cmd import ( + "io" + "io/fs" + "os" + "github.com/pojntfx/stfs/internal/keys" "github.com/pojntfx/stfs/internal/logging" "github.com/pojntfx/stfs/pkg/config" @@ -80,6 +84,22 @@ var recoveryFetchCmd = &cobra.Command{ Password: viper.GetString(passwordFlag), }, + func(path string, mode fs.FileMode) (io.WriteCloser, error) { + dstFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, mode) + if err != nil { + return nil, err + } + + if err := dstFile.Truncate(0); err != nil { + return nil, err + } + + return dstFile, nil + }, + func(path string, mode fs.FileMode) error { + return os.MkdirAll(path, mode) + }, + viper.GetInt(recordSizeFlag), viper.GetInt(recordFlag), viper.GetInt(blockFlag), diff --git a/pkg/operations/restore.go b/pkg/operations/restore.go index 1e4736b..2dad881 100644 --- a/pkg/operations/restore.go +++ b/pkg/operations/restore.go @@ -4,6 +4,8 @@ import ( "archive/tar" "context" "database/sql" + "io" + "io/fs" "path" "path/filepath" "strings" @@ -13,7 +15,14 @@ import ( "github.com/pojntfx/stfs/pkg/recovery" ) -func (o *Operations) Restore(from string, to string, flatten bool) error { +func (o *Operations) Restore( + getDst func(path string, mode fs.FileMode) (io.WriteCloser, error), + mkdirAll func(path string, mode fs.FileMode) error, + + from string, + to string, + flatten bool, +) error { o.diskOperationLock.Lock() defer o.diskOperationLock.Unlock() @@ -84,6 +93,9 @@ func (o *Operations) Restore(from string, to string, flatten bool) error { o.pipes, o.crypto, + getDst, + mkdirAll, + o.pipes.RecordSize, int(dbhdr.Record), int(dbhdr.Block), diff --git a/pkg/recovery/fetch.go b/pkg/recovery/fetch.go index 22a3825..101346e 100644 --- a/pkg/recovery/fetch.go +++ b/pkg/recovery/fetch.go @@ -4,7 +4,7 @@ import ( "archive/tar" "bufio" "io" - "os" + "io/fs" "path/filepath" "github.com/pojntfx/stfs/internal/compression" @@ -23,6 +23,9 @@ func Fetch( pipes config.PipeConfig, crypto config.CryptoConfig, + getDst func(path string, mode fs.FileMode) (io.WriteCloser, error), + mkdirAll func(path string, mode fs.FileMode) error, + recordSize int, record int, block int, @@ -82,18 +85,14 @@ func Fetch( } if hdr.Typeflag == tar.TypeDir { - return os.MkdirAll(to, hdr.FileInfo().Mode()) + return mkdirAll(to, hdr.FileInfo().Mode()) } - dstFile, err := os.OpenFile(to, os.O_WRONLY|os.O_CREATE, hdr.FileInfo().Mode()) + dstFile, err := getDst(to, hdr.FileInfo().Mode()) if err != nil { return err } - if err := dstFile.Truncate(0); err != nil { - return err - } - // Don't decompress non-regular files if !hdr.FileInfo().Mode().IsRegular() { if _, err := io.Copy(dstFile, tr); err != nil {