diff --git a/cmd/serve_ftp.go b/cmd/serve_ftp.go index 2fa9a38..a6ac5e6 100644 --- a/cmd/serve_ftp.go +++ b/cmd/serve_ftp.go @@ -2,10 +2,14 @@ package cmd import ( "context" + "fmt" "log" + "os" + "path/filepath" "time" ftpserver "github.com/fclairamb/ftpserverlib" + "github.com/pojntfx/stfs/internal/cache" sfs "github.com/pojntfx/stfs/internal/fs" "github.com/pojntfx/stfs/internal/ftp" "github.com/pojntfx/stfs/internal/keys" @@ -14,7 +18,6 @@ import ( "github.com/pojntfx/stfs/pkg/config" "github.com/pojntfx/stfs/pkg/operations" "github.com/pojntfx/stfs/pkg/tape" - "github.com/spf13/afero" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -28,6 +31,10 @@ var serveFTPCmd = &cobra.Command{ return err } + if err := cache.CheckCacheType(viper.GetString(cacheFlag)); err != nil { + return err + } + if err := keys.CheckKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(identityFlag)); err != nil { return err } @@ -113,11 +120,15 @@ var serveFTPCmd = &cobra.Command{ logger.PrintHeader, ) - var fs afero.Fs - if viper.GetBool(cacheFlag) { - fs = afero.NewCacheOnReadFs(afero.NewBasePathFs(stfs, root), afero.NewMemMapFs(), time.Hour) - } else { - fs = afero.NewBasePathFs(stfs, root) + fs, err := cache.Cache( + stfs, + root, + viper.GetString(cacheFlag), + viper.GetDuration(cacheDurationFlag), + viper.GetString(cacheDirFlag), + ) + if err != nil { + return err } srv := ftpserver.NewFtpServer( @@ -145,7 +156,9 @@ func init() { serveFTPCmd.PersistentFlags().StringP(passwordFlag, "p", "", "Password for the private key") serveFTPCmd.PersistentFlags().StringP(recipientFlag, "r", "", "Path to the public key to verify with") serveFTPCmd.PersistentFlags().StringP(laddrFlag, "a", "localhost:1337", "Listen address") - serveFTPCmd.PersistentFlags().BoolP(cacheFlag, "n", true, "Enable in-memory caching") + serveFTPCmd.PersistentFlags().StringP(cacheFlag, "n", config.NoneKey, fmt.Sprintf("Cache to use (default %v, available are %v)", config.NoneKey, cache.KnownCacheTypes)) + serveFTPCmd.PersistentFlags().DurationP(cacheDurationFlag, "u", time.Hour, "Duration until cache is invalidated") + serveFTPCmd.PersistentFlags().StringP(cacheDirFlag, "w", filepath.Join(os.TempDir(), "stfs", "cache"), "Directory to use if dir cache is enabled") viper.AutomaticEnv() diff --git a/cmd/serve_http.go b/cmd/serve_http.go index e0f5930..5b65521 100644 --- a/cmd/serve_http.go +++ b/cmd/serve_http.go @@ -2,10 +2,14 @@ package cmd import ( "context" + "fmt" "log" "net/http" + "os" + "path/filepath" "time" + "github.com/pojntfx/stfs/internal/cache" sfs "github.com/pojntfx/stfs/internal/fs" "github.com/pojntfx/stfs/internal/handlers" "github.com/pojntfx/stfs/internal/keys" @@ -20,8 +24,10 @@ import ( ) const ( - laddrFlag = "laddr" - cacheFlag = "cache" + laddrFlag = "laddr" + cacheFlag = "cache" + cacheDirFlag = "cache-dir" + cacheDurationFlag = "cache-duration" ) var serveHTTPCmd = &cobra.Command{ @@ -33,6 +39,10 @@ var serveHTTPCmd = &cobra.Command{ return err } + if err := cache.CheckCacheType(viper.GetString(cacheFlag)); err != nil { + return err + } + if err := keys.CheckKeyAccessible(viper.GetString(encryptionFlag), viper.GetString(identityFlag)); err != nil { return err } @@ -118,11 +128,15 @@ var serveHTTPCmd = &cobra.Command{ logger.PrintHeader, ) - var fs afero.Fs - if viper.GetBool(cacheFlag) { - fs = afero.NewCacheOnReadFs(afero.NewBasePathFs(stfs, root), afero.NewMemMapFs(), time.Hour) - } else { - fs = afero.NewBasePathFs(stfs, root) + fs, err := cache.Cache( + stfs, + root, + viper.GetString(cacheFlag), + viper.GetDuration(cacheDurationFlag), + viper.GetString(cacheDirFlag), + ) + if err != nil { + return err } log.Println("Listening on", viper.GetString(laddrFlag)) @@ -144,7 +158,9 @@ func init() { serveHTTPCmd.PersistentFlags().StringP(passwordFlag, "p", "", "Password for the private key") serveHTTPCmd.PersistentFlags().StringP(recipientFlag, "r", "", "Path to the public key to verify with") serveHTTPCmd.PersistentFlags().StringP(laddrFlag, "a", "localhost:1337", "Listen address") - serveHTTPCmd.PersistentFlags().BoolP(cacheFlag, "n", true, "Enable in-memory caching") + serveHTTPCmd.PersistentFlags().StringP(cacheFlag, "n", config.NoneKey, fmt.Sprintf("Cache to use (default %v, available are %v)", config.NoneKey, cache.KnownCacheTypes)) + serveHTTPCmd.PersistentFlags().DurationP(cacheDurationFlag, "u", time.Hour, "Duration until cache is invalidated") + serveHTTPCmd.PersistentFlags().StringP(cacheDirFlag, "w", filepath.Join(os.TempDir(), "stfs", "cache"), "Directory to use if dir cache is enabled") viper.AutomaticEnv() diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..edf962e --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,32 @@ +package cache + +import ( + "os" + "time" + + "github.com/pojntfx/stfs/pkg/config" + "github.com/spf13/afero" +) + +func Cache( + base afero.Fs, + root string, + cacheType string, + ttl time.Duration, + cacheDir string, +) (afero.Fs, error) { + switch cacheType { + case CacheTypeMemory: + return afero.NewCacheOnReadFs(afero.NewBasePathFs(base, root), afero.NewMemMapFs(), ttl), nil + case CacheTypeDir: + if err := os.MkdirAll(cacheDir, os.ModePerm); err != nil { + return nil, err + } + + return afero.NewCacheOnReadFs(afero.NewBasePathFs(base, root), afero.NewBasePathFs(afero.NewOsFs(), cacheDir), ttl), nil + case config.NoneKey: + return afero.NewBasePathFs(base, root), nil + default: + return nil, ErrCacheTypeUnsupported + } +} diff --git a/internal/cache/check.go b/internal/cache/check.go new file mode 100644 index 0000000..577d209 --- /dev/null +++ b/internal/cache/check.go @@ -0,0 +1,17 @@ +package cache + +func CheckCacheType(cacheType string) error { + cacheTypeIsKnown := false + + for _, candidate := range KnownCacheTypes { + if cacheType == candidate { + cacheTypeIsKnown = true + } + } + + if !cacheTypeIsKnown { + return ErrCacheTypeUnknown + } + + return nil +} diff --git a/internal/cache/constants.go b/internal/cache/constants.go new file mode 100644 index 0000000..b6a5a91 --- /dev/null +++ b/internal/cache/constants.go @@ -0,0 +1,14 @@ +package cache + +import ( + "github.com/pojntfx/stfs/pkg/config" +) + +const ( + CacheTypeMemory = "memory" + CacheTypeDir = "dir" +) + +var ( + KnownCacheTypes = []string{config.NoneKey, CacheTypeMemory, CacheTypeDir} +) diff --git a/internal/cache/errors.go b/internal/cache/errors.go new file mode 100644 index 0000000..7722600 --- /dev/null +++ b/internal/cache/errors.go @@ -0,0 +1,8 @@ +package cache + +import "errors" + +var ( + ErrCacheTypeUnsupported = errors.New("cache type unsupported") + ErrCacheTypeUnknown = errors.New("cache type unknown") +)