diff --git a/cmd/stbak/cmd/archive.go b/cmd/stbak/cmd/archive.go index d3e7bee..db010bb 100644 --- a/cmd/stbak/cmd/archive.go +++ b/cmd/stbak/cmd/archive.go @@ -2,7 +2,6 @@ package cmd import ( "archive/tar" - "bufio" "context" "io" "io/fs" @@ -11,7 +10,6 @@ import ( "github.com/pojntfx/stfs/pkg/adapters" "github.com/pojntfx/stfs/pkg/controllers" - "github.com/pojntfx/stfs/pkg/counters" "github.com/pojntfx/stfs/pkg/formatting" "github.com/pojntfx/stfs/pkg/persisters" "github.com/spf13/cobra" @@ -76,89 +74,12 @@ func archive( src string, overwrite bool, ) error { - isRegular := true - stat, err := os.Stat(tape) - if err == nil { - isRegular = stat.Mode().IsRegular() - } else { - if os.IsNotExist(err) { - isRegular = true - } else { - return err - } - } - - var f *os.File - if isRegular { - if overwrite { - f, err = os.OpenFile(tape, os.O_WRONLY|os.O_CREATE, 0600) - if err != nil { - return err - } - - if err := f.Truncate(0); err != nil { - return err - } - } else { - f, err = os.OpenFile(tape, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600) - if err != nil { - return err - } - } - - // No need to go to end manually due to `os.O_APPEND` - } else { - f, err = os.OpenFile(tape, os.O_APPEND|os.O_WRONLY, os.ModeCharDevice) - if err != nil { - return err - } - - if overwrite { - // Go to start of tape - if err := controllers.SeekToRecordOnTape(f, 0); err != nil { - return err - } - } else { - // Go to end of tape - if err := controllers.GoToEndOfTape(f); err != nil { - return err - } - } - } - defer f.Close() - dirty := false - var tw *tar.Writer - var bw *bufio.Writer - var counter *counters.CounterWriter - if isRegular { - tw = tar.NewWriter(f) - } else { - bw = bufio.NewWriterSize(f, controllers.BlockSize*recordSize) - counter = &counters.CounterWriter{Writer: bw, BytesRead: 0} - tw = tar.NewWriter(counter) + tw, isRegular, cleanup, err := openTapeWriter(viper.GetString(tapeFlag)) + if err != nil { + return err } - defer func() { - // Only write the trailer if we wrote to the archive - if dirty { - if err := tw.Close(); err != nil { - panic(err) - } - - if !isRegular { - if controllers.BlockSize*recordSize-counter.BytesRead > 0 { - // Fill the rest of the record with zeros - if _, err := bw.Write(make([]byte, controllers.BlockSize*recordSize-counter.BytesRead)); err != nil { - panic(err) - } - } - - if err := bw.Flush(); err != nil { - panic(err) - } - } - } - }() + defer cleanup(dirty) first := true return filepath.Walk(src, func(path string, info fs.FileInfo, err error) error { diff --git a/cmd/stbak/cmd/remove.go b/cmd/stbak/cmd/remove.go index 8a144d7..8f94c6f 100644 --- a/cmd/stbak/cmd/remove.go +++ b/cmd/stbak/cmd/remove.go @@ -30,71 +30,12 @@ var removeCmd = &cobra.Command{ return err } - isRegular := true - stat, err := os.Stat(viper.GetString(tapeFlag)) - if err == nil { - isRegular = stat.Mode().IsRegular() - } else { - if os.IsNotExist(err) { - isRegular = true - } else { - return err - } - } - - var f *os.File - if isRegular { - f, err = os.OpenFile(viper.GetString(tapeFlag), os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600) - if err != nil { - return err - } - - // No need to go to end manually due to `os.O_APPEND` - } else { - f, err = os.OpenFile(viper.GetString(tapeFlag), os.O_APPEND|os.O_WRONLY, os.ModeCharDevice) - if err != nil { - return err - } - - // Go to end of tape - if err := controllers.GoToEndOfTape(f); err != nil { - return err - } - } - defer f.Close() - dirty := false - var tw *tar.Writer - var bw *bufio.Writer - var counter *counters.CounterWriter - if isRegular { - tw = tar.NewWriter(f) - } else { - bw = bufio.NewWriterSize(f, controllers.BlockSize*viper.GetInt(recordSizeFlag)) - counter = &counters.CounterWriter{Writer: bw, BytesRead: 0} - tw = tar.NewWriter(counter) + tw, _, cleanup, err := openTapeWriter(viper.GetString(tapeFlag)) + if err != nil { + return err } - defer func() { - // Only write the trailer if we wrote to the archive - if dirty { - if err := tw.Close(); err != nil { - panic(err) - } - - if !isRegular { - if controllers.BlockSize*viper.GetInt(recordSizeFlag)-counter.BytesRead > 0 { - // Fill the rest of the record with zeros - if _, err := bw.Write(make([]byte, controllers.BlockSize*viper.GetInt(recordSizeFlag)-counter.BytesRead)); err != nil { - panic(err) - } - } - - if err := bw.Flush(); err != nil { - panic(err) - } - } - } - }() + defer cleanup(dirty) metadataPersister := persisters.NewMetadataPersister(viper.GetString(metadataFlag)) if err := metadataPersister.Open(); err != nil { @@ -153,6 +94,73 @@ var removeCmd = &cobra.Command{ }, } +func openTapeWriter(path string) (tw *tar.Writer, isRegular bool, cleanup func(dirty bool) error, err error) { + stat, err := os.Stat(path) + if err == nil { + isRegular = stat.Mode().IsRegular() + } else { + if os.IsNotExist(err) { + isRegular = true + } else { + return nil, false, nil, err + } + } + + var f *os.File + if isRegular { + f, err = os.OpenFile(path, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600) + if err != nil { + return nil, false, nil, err + } + + // No need to go to end manually due to `os.O_APPEND` + } else { + f, err = os.OpenFile(path, os.O_APPEND|os.O_WRONLY, os.ModeCharDevice) + if err != nil { + return nil, false, nil, err + } + + // Go to end of tape + if err := controllers.GoToEndOfTape(f); err != nil { + return nil, false, nil, err + } + } + + var bw *bufio.Writer + var counter *counters.CounterWriter + if isRegular { + tw = tar.NewWriter(f) + } else { + bw = bufio.NewWriterSize(f, controllers.BlockSize*viper.GetInt(recordSizeFlag)) + counter = &counters.CounterWriter{Writer: bw, BytesRead: 0} + tw = tar.NewWriter(counter) + } + + return tw, isRegular, func(dirty bool) error { + // Only write the trailer if we wrote to the archive + if dirty { + if err := tw.Close(); err != nil { + return err + } + + if !isRegular { + if controllers.BlockSize*viper.GetInt(recordSizeFlag)-counter.BytesRead > 0 { + // Fill the rest of the record with zeros + if _, err := bw.Write(make([]byte, controllers.BlockSize*viper.GetInt(recordSizeFlag)-counter.BytesRead)); err != nil { + return err + } + } + + if err := bw.Flush(); err != nil { + return err + } + } + } + + return f.Close() + }, nil +} + func init() { removeCmd.PersistentFlags().IntP(recordSizeFlag, "e", 20, "Amount of 512-bit blocks per record") removeCmd.PersistentFlags().StringP(nameFlag, "n", "", "Name of the file to remove")