feat: Add io.Writer-based API to restore and fetch cmds

This commit is contained in:
Felicitas Pojtinger
2021-12-17 15:30:09 +01:00
parent e11a3bffcc
commit 2a67c243fd
4 changed files with 59 additions and 8 deletions

View File

@@ -1,6 +1,10 @@
package cmd package cmd
import ( import (
"io"
"io/fs"
"os"
"github.com/pojntfx/stfs/internal/keys" "github.com/pojntfx/stfs/internal/keys"
"github.com/pojntfx/stfs/internal/logging" "github.com/pojntfx/stfs/internal/logging"
"github.com/pojntfx/stfs/internal/persisters" "github.com/pojntfx/stfs/internal/persisters"
@@ -93,6 +97,22 @@ var operationRestoreCmd = &cobra.Command{
) )
return ops.Restore( return ops.Restore(
func(path string, mode fs.FileMode) (io.WriteCloser, error) {
dstFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, mode)
if err != nil {
return nil, err
}
if err := dstFile.Truncate(0); err != nil {
return nil, err
}
return dstFile, nil
},
func(path string, mode fs.FileMode) error {
return os.MkdirAll(path, mode)
},
viper.GetString(fromFlag), viper.GetString(fromFlag),
viper.GetString(toFlag), viper.GetString(toFlag),
viper.GetBool(flattenFlag), viper.GetBool(flattenFlag),

View File

@@ -1,6 +1,10 @@
package cmd package cmd
import ( import (
"io"
"io/fs"
"os"
"github.com/pojntfx/stfs/internal/keys" "github.com/pojntfx/stfs/internal/keys"
"github.com/pojntfx/stfs/internal/logging" "github.com/pojntfx/stfs/internal/logging"
"github.com/pojntfx/stfs/pkg/config" "github.com/pojntfx/stfs/pkg/config"
@@ -80,6 +84,22 @@ var recoveryFetchCmd = &cobra.Command{
Password: viper.GetString(passwordFlag), Password: viper.GetString(passwordFlag),
}, },
func(path string, mode fs.FileMode) (io.WriteCloser, error) {
dstFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, mode)
if err != nil {
return nil, err
}
if err := dstFile.Truncate(0); err != nil {
return nil, err
}
return dstFile, nil
},
func(path string, mode fs.FileMode) error {
return os.MkdirAll(path, mode)
},
viper.GetInt(recordSizeFlag), viper.GetInt(recordSizeFlag),
viper.GetInt(recordFlag), viper.GetInt(recordFlag),
viper.GetInt(blockFlag), viper.GetInt(blockFlag),

View File

@@ -4,6 +4,8 @@ import (
"archive/tar" "archive/tar"
"context" "context"
"database/sql" "database/sql"
"io"
"io/fs"
"path" "path"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -13,7 +15,14 @@ import (
"github.com/pojntfx/stfs/pkg/recovery" "github.com/pojntfx/stfs/pkg/recovery"
) )
func (o *Operations) Restore(from string, to string, flatten bool) error { func (o *Operations) Restore(
getDst func(path string, mode fs.FileMode) (io.WriteCloser, error),
mkdirAll func(path string, mode fs.FileMode) error,
from string,
to string,
flatten bool,
) error {
o.diskOperationLock.Lock() o.diskOperationLock.Lock()
defer o.diskOperationLock.Unlock() defer o.diskOperationLock.Unlock()
@@ -84,6 +93,9 @@ func (o *Operations) Restore(from string, to string, flatten bool) error {
o.pipes, o.pipes,
o.crypto, o.crypto,
getDst,
mkdirAll,
o.pipes.RecordSize, o.pipes.RecordSize,
int(dbhdr.Record), int(dbhdr.Record),
int(dbhdr.Block), int(dbhdr.Block),

View File

@@ -4,7 +4,7 @@ import (
"archive/tar" "archive/tar"
"bufio" "bufio"
"io" "io"
"os" "io/fs"
"path/filepath" "path/filepath"
"github.com/pojntfx/stfs/internal/compression" "github.com/pojntfx/stfs/internal/compression"
@@ -23,6 +23,9 @@ func Fetch(
pipes config.PipeConfig, pipes config.PipeConfig,
crypto config.CryptoConfig, crypto config.CryptoConfig,
getDst func(path string, mode fs.FileMode) (io.WriteCloser, error),
mkdirAll func(path string, mode fs.FileMode) error,
recordSize int, recordSize int,
record int, record int,
block int, block int,
@@ -82,18 +85,14 @@ func Fetch(
} }
if hdr.Typeflag == tar.TypeDir { if hdr.Typeflag == tar.TypeDir {
return os.MkdirAll(to, hdr.FileInfo().Mode()) return mkdirAll(to, hdr.FileInfo().Mode())
} }
dstFile, err := os.OpenFile(to, os.O_WRONLY|os.O_CREATE, hdr.FileInfo().Mode()) dstFile, err := getDst(to, hdr.FileInfo().Mode())
if err != nil { if err != nil {
return err return err
} }
if err := dstFile.Truncate(0); err != nil {
return err
}
// Don't decompress non-regular files // Don't decompress non-regular files
if !hdr.FileInfo().Mode().IsRegular() { if !hdr.FileInfo().Mode().IsRegular() {
if _, err := io.Copy(dstFile, tr); err != nil { if _, err := io.Copy(dstFile, tr); err != nil {