diff --git a/pkg/fs/filesystem.go b/pkg/fs/filesystem.go index 9778022..d325fd9 100644 --- a/pkg/fs/filesystem.go +++ b/pkg/fs/filesystem.go @@ -451,52 +451,66 @@ func (f *STFS) OpenFile(name string, flag int, perm os.FileMode) (afero.File, er ) if err != nil { if err == sql.ErrNoRows { - if !f.readOnly && flag&os.O_CREATE != 0 && flag&os.O_EXCL == 0 { - if _, err := inventory.Stat( - f.metadata, + hdr, err = inventory.Stat( + f.metadata, - filepath.Dir(name), - false, + name, + true, - f.onHeader, - ); err != nil { - if err == sql.ErrNoRows { + f.onHeader, + ) + if err != nil { + if err == sql.ErrNoRows { + if !f.readOnly && flag&os.O_CREATE != 0 && flag&os.O_EXCL == 0 { + if _, err := inventory.Stat( + f.metadata, + + filepath.Dir(name), + false, + + f.onHeader, + ); err != nil { + if err == sql.ErrNoRows { + return nil, os.ErrNotExist + } + + return nil, err + } + + if target, err := inventory.Stat( + f.metadata, + + name, + true, + + f.onHeader, + ); err == nil { + if target.Typeflag == tar.TypeDir { + return nil, config.ErrIsDirectory + } + } + + if err := f.mknodeWithoutLocking(false, name, perm, false, "", false); err != nil { + return nil, err + } + + hdr, err = inventory.Stat( + f.metadata, + + name, + false, + + f.onHeader, + ) + if err != nil { + return nil, err + } + } else { return nil, os.ErrNotExist } - + } else { return nil, err } - - if target, err := inventory.Stat( - f.metadata, - - name, - true, - - f.onHeader, - ); err == nil { - if target.Typeflag == tar.TypeDir { - return nil, config.ErrIsDirectory - } - } - - if err := f.mknodeWithoutLocking(false, name, perm, false, "", false); err != nil { - return nil, err - } - - hdr, err = inventory.Stat( - f.metadata, - - name, - false, - - f.onHeader, - ) - if err != nil { - return nil, err - } - } else { - return nil, os.ErrNotExist } } else { return nil, err @@ -717,10 +731,24 @@ func (f *STFS) Stat(name string) (os.FileInfo, error) { ) if err != nil { if err == sql.ErrNoRows { - return nil, os.ErrNotExist - } + hdr, err = inventory.Stat( + f.metadata, - return nil, err + name, + true, + + f.onHeader, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, os.ErrNotExist + } + + return nil, err + } + } else { + return nil, err + } } return NewFileInfoFromTarHeader(hdr, f.log), nil diff --git a/pkg/fs/filesystem_test.go b/pkg/fs/filesystem_test.go index 2b1280d..41d2ed5 100644 --- a/pkg/fs/filesystem_test.go +++ b/pkg/fs/filesystem_test.go @@ -1267,42 +1267,42 @@ var openTests = []struct { name string args openArgs wantErr bool - prepare func(afero.Fs) error + prepare func(symFs) error check func(afero.File) error }{ { "Can open /", openArgs{"/"}, false, - func(f afero.Fs) error { return nil }, + func(f symFs) error { return nil }, func(f afero.File) error { return nil }, }, { "Can not open ' '", openArgs{" "}, true, - func(f afero.Fs) error { return nil }, + func(f symFs) error { return nil }, func(f afero.File) error { return nil }, }, { "Can open ''", openArgs{""}, false, - func(f afero.Fs) error { return nil }, + func(f symFs) error { return nil }, func(f afero.File) error { return nil }, }, { "Can not open /test.txt without creating it", openArgs{"/test.txt"}, true, - func(f afero.Fs) error { return nil }, + func(f symFs) error { return nil }, func(f afero.File) error { return nil }, }, { "Can open /test.txt after creating it", openArgs{"/test.txt"}, false, - func(f afero.Fs) error { + func(f symFs) error { if _, err := f.Create("/test.txt"); err != nil { return err } @@ -1324,14 +1324,36 @@ var openTests = []struct { "Can not open /mydir/test.txt without creating it", openArgs{"/mydir/test.txt"}, true, - func(f afero.Fs) error { return nil }, + func(f symFs) error { return nil }, func(f afero.File) error { return nil }, }, + { + "Can open /mydir after creating it", + openArgs{"/mydir"}, + false, + func(f symFs) error { + if err := f.Mkdir("/mydir", os.ModePerm); err != nil { + return err + } + + return nil + }, + func(f afero.File) error { + want := "/mydir" + got := f.Name() + + if want != got { + return fmt.Errorf("invalid name, got %v, want %v", got, want) + } + + return nil + }, + }, { "Can open /mydir/test.txt after creating it", openArgs{"/mydir/test.txt"}, false, - func(f afero.Fs) error { + func(f symFs) error { if err := f.Mkdir("/mydir", os.ModePerm); err != nil { return err } @@ -1353,6 +1375,50 @@ var openTests = []struct { return nil }, }, + { + "Can open symlink to root", + openArgs{"/existingsymlink"}, + false, + func(sf symFs) error { + if err := sf.SymlinkIfPossible("/", "/existingsymlink"); err != nil { + return nil + } + + return nil + }, + func(f afero.File) error { return nil }, + }, + // FIXME: Since we can't differentiate between broken and non-broken symlinks, this does not work yet + // { + // "Can not broken symlink to /brokensymlink", + // openArgs{"/brokensymlink"}, + // true, + // func(sf symFs) error { + // if err := sf.SymlinkIfPossible("/test.txt", "/brokensymlink"); err != nil { + // return nil + // } + + // return nil + // }, + // func(f afero.File) error { return nil }, + // }, + { + "Can open symlink /existingsymlink to directory", + openArgs{"/existingsymlink"}, + false, + func(sf symFs) error { + if err := sf.Mkdir("/mydir", os.ModePerm); err != nil { + return err + } + + if err := sf.SymlinkIfPossible("/mydir", "/existingsymlink"); err != nil { + return nil + } + + return nil + }, + func(f afero.File) error { return nil }, + }, } func TestSTFS_Open(t *testing.T) { @@ -1360,21 +1426,26 @@ func TestSTFS_Open(t *testing.T) { tt := tt runTestForAllFss(t, tt.name, true, true, true, func(t *testing.T, fs fsConfig) { - if err := tt.prepare(fs.fs); err != nil { - t.Errorf("%v prepare() error = %v", fs.fs.Name(), err) + symFs, ok := fs.fs.(symFs) + if !ok { + return + } + + if err := tt.prepare(symFs); err != nil { + t.Errorf("%v prepare() error = %v", symFs.Name(), err) return } - got, err := fs.fs.Open(tt.args.name) + got, err := symFs.Open(tt.args.name) if (err != nil) != tt.wantErr { - t.Errorf("%v.Open() error = %v, wantErr %v", fs.fs.Name(), err, tt.wantErr) + t.Errorf("%v.Open() error = %v, wantErr %v", symFs.Name(), err, tt.wantErr) return } if err := tt.check(got); err != nil { - t.Errorf("%v check() error = %v", fs.fs.Name(), err) + t.Errorf("%v check() error = %v", symFs.Name(), err) return }