refactor: Decompose key checks

This commit is contained in:
Felix Pojtinger
2021-12-02 23:10:34 +01:00
parent 96afddeb22
commit 7d4214c328
10 changed files with 58 additions and 84 deletions

View File

@@ -25,6 +25,7 @@ import (
"github.com/pojntfx/stfs/pkg/controllers"
"github.com/pojntfx/stfs/pkg/counters"
"github.com/pojntfx/stfs/pkg/formatting"
"github.com/pojntfx/stfs/pkg/noop"
"github.com/pojntfx/stfs/pkg/pax"
"github.com/pojntfx/stfs/pkg/persisters"
"github.com/spf13/cobra"
@@ -58,32 +59,6 @@ var (
errMissingTarHeader = errors.New("tar header is missing")
)
type flusher interface {
io.WriteCloser
Flush() error
}
func nopCloserWriter(w io.Writer) nopCloser {
return nopCloser{w}
}
type nopCloser struct {
io.Writer
}
func (nopCloser) Close() error { return nil }
func nopFlusherWriter(w io.WriteCloser) nopFlusher {
return nopFlusher{w}
}
type nopFlusher struct {
io.WriteCloser
}
func (nopFlusher) Flush() error { return nil }
var archiveCmd = &cobra.Command{
Use: "archive",
Aliases: []string{"arc", "a", "c"},
@@ -97,13 +72,7 @@ var archiveCmd = &cobra.Command{
return err
}
if viper.GetString(encryptionFlag) != encryptionFormatNoneKey {
if _, err := os.Stat(viper.GetString(recipientFlag)); err != nil {
return errRecipientNotAccessible
}
}
return nil
return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(recipientFlag))
},
RunE: func(cmd *cobra.Command, args []string) error {
if viper.GetBool(verboseFlag) {
@@ -417,6 +386,18 @@ func archive(
})
}
func checkKeyAccessible(encryptionFormat string, pathToKey string) error {
if encryptionFormat == encryptionFormatNoneKey {
return nil
}
if _, err := os.Stat(pathToKey); err != nil {
return errRecipientNotAccessible
}
return nil
}
func checkCompressionLevel(compressionLevel string) error {
compressionLevelIsKnown := false
@@ -494,7 +475,7 @@ func encrypt(
return age.Encrypt(dst, recipient)
case encryptionFormatNoneKey:
return nopCloserWriter(dst), nil
return noop.AddClose(dst), nil
default:
return nil, errUnsupportedEncryptionFormat
}
@@ -540,7 +521,7 @@ func compress(
compressionLevel string,
isRegular bool,
recordSize int,
) (flusher, error) {
) (noop.Flusher, error) {
switch compressionFormat {
case compressionFormatGZipKey:
fallthrough
@@ -592,7 +573,7 @@ func compress(
return nil, err
}
return nopFlusherWriter(lz), nil
return noop.AddFlush(lz), nil
case compressionFormatZStandardKey:
l := zstd.SpeedDefault
switch compressionLevel {
@@ -650,9 +631,9 @@ func compress(
return nil, err
}
return nopFlusherWriter(bz), nil
return noop.AddFlush(bz), nil
case compressionFormatNoneKey:
return nopFlusherWriter(nopCloserWriter(dst)), nil
return noop.AddFlush(noop.AddClose(dst)), nil
default:
return nil, errUnsupportedCompressionFormat
}

View File

@@ -32,13 +32,7 @@ var deleteCmd = &cobra.Command{
return err
}
if viper.GetString(encryptionFlag) != encryptionFormatNoneKey {
if _, err := os.Stat(viper.GetString(recipientFlag)); err != nil {
return errRecipientNotAccessible
}
}
return nil
return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(recipientFlag))
},
RunE: func(cmd *cobra.Command, args []string) error {
if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil {

View File

@@ -4,7 +4,6 @@ import (
"archive/tar"
"context"
"io/ioutil"
"os"
"strings"
"github.com/pojntfx/stfs/pkg/converters"
@@ -26,13 +25,7 @@ var moveCmd = &cobra.Command{
return err
}
if viper.GetString(encryptionFlag) != encryptionFormatNoneKey {
if _, err := os.Stat(viper.GetString(recipientFlag)); err != nil {
return errRecipientNotAccessible
}
}
return nil
return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(recipientFlag))
},
RunE: func(cmd *cobra.Command, args []string) error {
if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil {

View File

@@ -47,7 +47,7 @@ var recoveryFetchCmd = &cobra.Command{
}
}
return nil
return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(identityFlag))
},
RunE: func(cmd *cobra.Command, args []string) error {
if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil {

View File

@@ -31,13 +31,7 @@ var recoveryIndexCmd = &cobra.Command{
return err
}
if viper.GetString(encryptionFlag) != encryptionFormatNoneKey {
if _, err := os.Stat(viper.GetString(identityFlag)); err != nil {
return errIdentityNotAccessible
}
}
return nil
return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(identityFlag))
},
RunE: func(cmd *cobra.Command, args []string) error {
if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil {

View File

@@ -6,7 +6,6 @@ import (
"io"
"io/ioutil"
"math"
"os"
"github.com/pojntfx/stfs/pkg/controllers"
"github.com/pojntfx/stfs/pkg/counters"
@@ -24,13 +23,7 @@ var recoveryQueryCmd = &cobra.Command{
return err
}
if viper.GetString(encryptionFlag) != encryptionFormatNoneKey {
if _, err := os.Stat(viper.GetString(identityFlag)); err != nil {
return errIdentityNotAccessible
}
}
return nil
return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(identityFlag))
},
RunE: func(cmd *cobra.Command, args []string) error {
if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil {

View File

@@ -5,7 +5,6 @@ import (
"context"
"database/sql"
"io/ioutil"
"os"
"path"
"path/filepath"
"strings"
@@ -32,13 +31,7 @@ var restoreCmd = &cobra.Command{
return err
}
if viper.GetString(encryptionFlag) != encryptionFormatNoneKey {
if _, err := os.Stat(viper.GetString(identityFlag)); err != nil {
return errIdentityNotAccessible
}
}
return nil
return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(identityFlag))
},
RunE: func(cmd *cobra.Command, args []string) error {
if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil {

View File

@@ -35,13 +35,7 @@ var updateCmd = &cobra.Command{
return err
}
if viper.GetString(encryptionFlag) != encryptionFormatNoneKey {
if _, err := os.Stat(viper.GetString(recipientFlag)); err != nil {
return errRecipientNotAccessible
}
}
return nil
return checkKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(recipientFlag))
},
RunE: func(cmd *cobra.Command, args []string) error {
if err := viper.BindPFlags(cmd.PersistentFlags()); err != nil {

19
pkg/noop/flusher.go Normal file
View File

@@ -0,0 +1,19 @@
package noop
import "io"
type Flusher interface {
io.WriteCloser
Flush() error
}
type NoOpFlusher struct {
io.WriteCloser
}
func (NoOpFlusher) Flush() error { return nil }
func AddFlush(w io.WriteCloser) NoOpFlusher {
return NoOpFlusher{w}
}

13
pkg/noop/write.go Normal file
View File

@@ -0,0 +1,13 @@
package noop
import "io"
type NoOpCloser struct {
io.Writer
}
func (NoOpCloser) Close() error { return nil }
func AddClose(w io.Writer) NoOpCloser {
return NoOpCloser{w}
}