diff --git a/cmd/stbak/cmd/archive.go b/cmd/stbak/cmd/archive.go index 73e2f15..ea8728c 100644 --- a/cmd/stbak/cmd/archive.go +++ b/cmd/stbak/cmd/archive.go @@ -25,6 +25,7 @@ import ( "github.com/pojntfx/stfs/pkg/controllers" "github.com/pojntfx/stfs/pkg/counters" "github.com/pojntfx/stfs/pkg/formatting" + "github.com/pojntfx/stfs/pkg/noop" "github.com/pojntfx/stfs/pkg/pax" "github.com/pojntfx/stfs/pkg/persisters" "github.com/spf13/cobra" @@ -58,32 +59,6 @@ var ( errMissingTarHeader = errors.New("tar header is missing") ) -type flusher interface { - io.WriteCloser - - Flush() error -} - -func nopCloserWriter(w io.Writer) nopCloser { - return nopCloser{w} -} - -type nopCloser struct { - io.Writer -} - -func (nopCloser) Close() error { return nil } - -func nopFlusherWriter(w io.WriteCloser) nopFlusher { - return nopFlusher{w} -} - -type nopFlusher struct { - io.WriteCloser -} - -func (nopFlusher) Flush() error { return nil } - var archiveCmd = &cobra.Command{ Use: "archive", Aliases: []string{"arc", "a", "c"}, @@ -97,13 +72,7 @@ var archiveCmd = &cobra.Command{ return err } - if viper.GetString(encryptionFlag) != encryptionFormatNoneKey { - if _, err := os.Stat(viper.GetString(recipientFlag)); err != nil { - return errRecipientNotAccessible - } - } - - return nil + return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(recipientFlag)) }, RunE: func(cmd *cobra.Command, args []string) error { if viper.GetBool(verboseFlag) { @@ -417,6 +386,18 @@ func archive( }) } +func checkKeyAccessible(encryptionFormat string, pathToKey string) error { + if encryptionFormat == encryptionFormatNoneKey { + return nil + } + + if _, err := os.Stat(pathToKey); err != nil { + return errRecipientNotAccessible + } + + return nil +} + func checkCompressionLevel(compressionLevel string) error { compressionLevelIsKnown := false @@ -494,7 +475,7 @@ func encrypt( return age.Encrypt(dst, recipient) case encryptionFormatNoneKey: - return nopCloserWriter(dst), nil + return noop.AddClose(dst), nil default: return nil, errUnsupportedEncryptionFormat } @@ -540,7 +521,7 @@ func compress( compressionLevel string, isRegular bool, recordSize int, -) (flusher, error) { +) (noop.Flusher, error) { switch compressionFormat { case compressionFormatGZipKey: fallthrough @@ -592,7 +573,7 @@ func compress( return nil, err } - return nopFlusherWriter(lz), nil + return noop.AddFlush(lz), nil case compressionFormatZStandardKey: l := zstd.SpeedDefault switch compressionLevel { @@ -650,9 +631,9 @@ func compress( return nil, err } - return nopFlusherWriter(bz), nil + return noop.AddFlush(bz), nil case compressionFormatNoneKey: - return nopFlusherWriter(nopCloserWriter(dst)), nil + return noop.AddFlush(noop.AddClose(dst)), nil default: return nil, errUnsupportedCompressionFormat } diff --git a/cmd/stbak/cmd/delete.go b/cmd/stbak/cmd/delete.go index d0de0b0..c445469 100644 --- a/cmd/stbak/cmd/delete.go +++ b/cmd/stbak/cmd/delete.go @@ -32,13 +32,7 @@ var deleteCmd = &cobra.Command{ return err } - if viper.GetString(encryptionFlag) != encryptionFormatNoneKey { - if _, err := os.Stat(viper.GetString(recipientFlag)); err != nil { - return errRecipientNotAccessible - } - } - - return nil + return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(recipientFlag)) }, RunE: func(cmd *cobra.Command, args []string) error { if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil { diff --git a/cmd/stbak/cmd/move.go b/cmd/stbak/cmd/move.go index b635366..6e9a0a6 100644 --- a/cmd/stbak/cmd/move.go +++ b/cmd/stbak/cmd/move.go @@ -4,7 +4,6 @@ import ( "archive/tar" "context" "io/ioutil" - "os" "strings" "github.com/pojntfx/stfs/pkg/converters" @@ -26,13 +25,7 @@ var moveCmd = &cobra.Command{ return err } - if viper.GetString(encryptionFlag) != encryptionFormatNoneKey { - if _, err := os.Stat(viper.GetString(recipientFlag)); err != nil { - return errRecipientNotAccessible - } - } - - return nil + return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(recipientFlag)) }, RunE: func(cmd *cobra.Command, args []string) error { if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil { diff --git a/cmd/stbak/cmd/recovery_fetch.go b/cmd/stbak/cmd/recovery_fetch.go index 3c81cec..1da2de5 100644 --- a/cmd/stbak/cmd/recovery_fetch.go +++ b/cmd/stbak/cmd/recovery_fetch.go @@ -47,7 +47,7 @@ var recoveryFetchCmd = &cobra.Command{ } } - return nil + return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(identityFlag)) }, RunE: func(cmd *cobra.Command, args []string) error { if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil { diff --git a/cmd/stbak/cmd/recovery_index.go b/cmd/stbak/cmd/recovery_index.go index 1f11a6b..b5f1de5 100644 --- a/cmd/stbak/cmd/recovery_index.go +++ b/cmd/stbak/cmd/recovery_index.go @@ -31,13 +31,7 @@ var recoveryIndexCmd = &cobra.Command{ return err } - if viper.GetString(encryptionFlag) != encryptionFormatNoneKey { - if _, err := os.Stat(viper.GetString(identityFlag)); err != nil { - return errIdentityNotAccessible - } - } - - return nil + return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(identityFlag)) }, RunE: func(cmd *cobra.Command, args []string) error { if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil { diff --git a/cmd/stbak/cmd/recovery_query.go b/cmd/stbak/cmd/recovery_query.go index 729ade7..a35950f 100644 --- a/cmd/stbak/cmd/recovery_query.go +++ b/cmd/stbak/cmd/recovery_query.go @@ -6,7 +6,6 @@ import ( "io" "io/ioutil" "math" - "os" "github.com/pojntfx/stfs/pkg/controllers" "github.com/pojntfx/stfs/pkg/counters" @@ -24,13 +23,7 @@ var recoveryQueryCmd = &cobra.Command{ return err } - if viper.GetString(encryptionFlag) != encryptionFormatNoneKey { - if _, err := os.Stat(viper.GetString(identityFlag)); err != nil { - return errIdentityNotAccessible - } - } - - return nil + return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(identityFlag)) }, RunE: func(cmd *cobra.Command, args []string) error { if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil { diff --git a/cmd/stbak/cmd/restore.go b/cmd/stbak/cmd/restore.go index bb6f2f4..0ab74a1 100644 --- a/cmd/stbak/cmd/restore.go +++ b/cmd/stbak/cmd/restore.go @@ -5,7 +5,6 @@ import ( "context" "database/sql" "io/ioutil" - "os" "path" "path/filepath" "strings" @@ -32,13 +31,7 @@ var restoreCmd = &cobra.Command{ return err } - if viper.GetString(encryptionFlag) != encryptionFormatNoneKey { - if _, err := os.Stat(viper.GetString(identityFlag)); err != nil { - return errIdentityNotAccessible - } - } - - return nil + return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(identityFlag)) }, RunE: func(cmd *cobra.Command, args []string) error { if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil { diff --git a/cmd/stbak/cmd/update.go b/cmd/stbak/cmd/update.go index 1d1df9f..65f314d 100644 --- a/cmd/stbak/cmd/update.go +++ b/cmd/stbak/cmd/update.go @@ -35,13 +35,7 @@ var updateCmd = &cobra.Command{ return err } - if viper.GetString(encryptionFlag) != encryptionFormatNoneKey { - if _, err := os.Stat(viper.GetString(recipientFlag)); err != nil { - return errRecipientNotAccessible - } - } - - return nil + return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(recipientFlag)) }, RunE: func(cmd *cobra.Command, args []string) error { if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil { diff --git a/pkg/noop/flusher.go b/pkg/noop/flusher.go new file mode 100644 index 0000000..a319898 --- /dev/null +++ b/pkg/noop/flusher.go @@ -0,0 +1,19 @@ +package noop + +import "io" + +type Flusher interface { + io.WriteCloser + + Flush() error +} + +type NoOpFlusher struct { + io.WriteCloser +} + +func (NoOpFlusher) Flush() error { return nil } + +func AddFlush(w io.WriteCloser) NoOpFlusher { + return NoOpFlusher{w} +} diff --git a/pkg/noop/write.go b/pkg/noop/write.go new file mode 100644 index 0000000..2b1a4ac --- /dev/null +++ b/pkg/noop/write.go @@ -0,0 +1,13 @@ +package noop + +import "io" + +type NoOpCloser struct { + io.Writer +} + +func (NoOpCloser) Close() error { return nil } + +func AddClose(w io.Writer) NoOpCloser { + return NoOpCloser{w} +}