diff --git a/pkg/fs/file_test.go b/pkg/fs/file_test.go index bae91bb..fcb6ce7 100644 --- a/pkg/fs/file_test.go +++ b/pkg/fs/file_test.go @@ -3453,6 +3453,317 @@ var writeTests = []struct { true, true, }, + { + "Can not write to symlink to /", + "/existingsymlink", + func() writeArgs { + return writeArgs{strings.NewReader("Hello, world!")} + }, + true, + func(f symFs) error { + if err := f.SymlinkIfPossible("/", "/existingsymlink"); err != nil { + return err + } + + return nil + }, + func(f afero.File, i int) error { return nil }, + true, + true, + false, + }, + { + "Can not write to symlink to /mydir", + "/existingsymlink", + func() writeArgs { + return writeArgs{strings.NewReader("Hello, world!")} + }, + true, + func(f symFs) error { + if err := f.Mkdir("/mydir", os.ModePerm); err != nil { + return err + } + + if err := f.SymlinkIfPossible("/mydir", "/existingsymlink"); err != nil { + return err + } + + return nil + }, + func(f afero.File, i int) error { return nil }, + true, + true, + false, + }, + { + "Can write empty string to symlink to /test.txt", + "/existingsymlink", + func() writeArgs { + return writeArgs{strings.NewReader("")} + }, + false, + func(f symFs) error { + file, err := f.Create("/test.txt") + if err != nil { + return err + } + + if err := file.Close(); err != nil { + return err + } + + if err := f.SymlinkIfPossible("/test.txt", "/existingsymlink"); err != nil { + return err + } + + return nil + }, + func(f afero.File, i int) error { + wantContent := []byte{} + gotContent := make([]byte, len(wantContent)) + + wantLength := len(wantContent) + gotLength, err := f.Read(gotContent) + if err != io.EOF { + return err + } + + if wantLength != gotLength { + return fmt.Errorf("invalid write length, got %v, want %v", gotLength, wantLength) + } + + if wantLength != i { + return fmt.Errorf("invalid write length n, got %v, want %v", i, wantLength) + } + + if string(wantContent) != string(gotContent) { + return fmt.Errorf("invalid write content, got %v, want %v", gotContent, wantContent) + } + + return nil + }, + true, + true, + false, + }, + { + "Can write small amount of data to symlink to /test.txt if seeking afterwards", + "/existingsymlink", + func() writeArgs { + return writeArgs{strings.NewReader("Hello, world!")} + }, + false, + func(f symFs) error { + file, err := f.Create("/test.txt") + if err != nil { + return err + } + + if err := file.Close(); err != nil { + return err + } + + if err := f.SymlinkIfPossible("/test.txt", "/existingsymlink"); err != nil { + return err + } + + return nil + }, + func(f afero.File, i int) error { + if _, err := f.Seek(0, io.SeekStart); err != nil { + return err + } + + wantContent := []byte("Hello, world") + gotContent := make([]byte, len(wantContent)) + + wantLength := len(wantContent) + gotLength, err := f.Read(gotContent) + if err != io.EOF { + return err + } + + if wantLength != gotLength { + return fmt.Errorf("invalid write length, got %v, want %v", gotLength, wantLength) + } + + if wantLength != i { + return fmt.Errorf("invalid write length n, got %v, want %v", i, wantLength) + } + + if string(wantContent) != string(gotContent) { + return fmt.Errorf("invalid write content, got %v, want %v", gotContent, wantContent) + } + + return nil + }, + true, + true, + false, + }, + { + "Can write small amount of data to symlink to /test.txt if not seeking afterwards", + "/existingsymlink", + func() writeArgs { + return writeArgs{strings.NewReader("Hello, world!")} + }, + false, + func(f symFs) error { + file, err := f.Create("/test.txt") + if err != nil { + return err + } + + if err := file.Close(); err != nil { + return err + } + + if err := f.SymlinkIfPossible("/test.txt", "/existingsymlink"); err != nil { + return err + } + + return nil + }, + func(f afero.File, i int) error { + wantContent := []byte("") + gotContent := make([]byte, len(wantContent)) + + wantLength := len(wantContent) + gotLength, err := f.Read(gotContent) + if err != io.EOF { + return err + } + + if wantLength != gotLength { + return fmt.Errorf("invalid write length, got %v, want %v", gotLength, wantLength) + } + + if wantLength != i { + return fmt.Errorf("invalid write length n, got %v, want %v", i, wantLength) + } + + if string(wantContent) != string(gotContent) { + return fmt.Errorf("invalid write content, got %v, want %v", gotContent, wantContent) + } + + return nil + }, + true, + true, + false, + }, + { + "Can write 30 MB amount of data to symlink to /test.txt", + "/existingsymlink", + func() writeArgs { + return writeArgs{newDeterministicReader(1000)} + }, + false, + func(f symFs) error { + file, err := f.Create("/test.txt") + if err != nil { + return err + } + + if err := file.Close(); err != nil { + return err + } + + if err := f.SymlinkIfPossible("/test.txt", "/existingsymlink"); err != nil { + return err + } + + return nil + }, + func(f afero.File, i int) error { + if _, err := f.Seek(0, io.SeekStart); err != nil { + return err + } + + wantHash := "HTUi7GuNreHASha4hhl1xwuYk03pyTJ0IJbFLv04UdccT9m_NA2oBFTrnMxJhEu3VMGxDYk_04Th9C0zOj5MyA==" + wantLength := int64(32800768) + + if wantLength != int64(i) { + return fmt.Errorf("invalid write length n, got %v, want %v", i, wantLength) + } + + hasher := sha512.New() + gotLength, err := io.Copy(hasher, f) + if err != nil { + return err + } + gotHash := base64.URLEncoding.EncodeToString(hasher.Sum(nil)) + + if gotLength != wantLength { + return fmt.Errorf("invalid read length, got %v, want %v", gotLength, wantLength) + } + + if gotHash != wantHash { + return fmt.Errorf("invalid read hash, got %v, want %v", gotHash, wantHash) + } + + return nil + }, + true, + true, + false, + }, + { + "Can write 300 MB amount of data to symlink to /test.txt", + "/existingsymlink", + func() writeArgs { + return writeArgs{newDeterministicReader(10000)} + }, + false, + func(f symFs) error { + file, err := f.Create("/test.txt") + if err != nil { + return err + } + + if err := file.Close(); err != nil { + return err + } + + if err := f.SymlinkIfPossible("/test.txt", "/existingsymlink"); err != nil { + return err + } + + return nil + }, + func(f afero.File, i int) error { + if _, err := f.Seek(0, io.SeekStart); err != nil { + return err + } + + wantHash := "3NXGfwSdGiFZjd-sdIcx4xrUnsOPOb4LeDBYGZFVPoRyMqGdqTEHsTbk1Ow3Vn-wIdFqaO8Zj6eXhYvWBakkuQ==" + wantLength := int64(327712768) + + if wantLength != int64(i) { + return fmt.Errorf("invalid write length n, got %v, want %v", i, wantLength) + } + + hasher := sha512.New() + gotLength, err := io.Copy(hasher, f) + if err != nil { + return err + } + gotHash := base64.URLEncoding.EncodeToString(hasher.Sum(nil)) + + if gotLength != wantLength { + return fmt.Errorf("invalid read length, got %v, want %v", gotLength, wantLength) + } + + if gotHash != wantHash { + return fmt.Errorf("invalid read hash, got %v, want %v", gotHash, wantHash) + } + + return nil + }, + true, + true, + true, + }, } func TestFile_Write(t *testing.T) {