fix: Normalize behaviour between os.BasePathFs and STFS

This commit is contained in:
Felix Pojtinger
2022-01-11 03:31:22 +01:00
parent 65ea0322ab
commit 2d3ef7f7e1
2 changed files with 138 additions and 11 deletions

View File

@@ -89,6 +89,10 @@ func (f *STFS) Create(name string) (afero.File, error) {
return nil, os.ErrPermission
}
if checkName(name) {
return nil, os.ErrInvalid
}
return f.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
}
@@ -324,6 +328,10 @@ func (f *STFS) Open(name string) (afero.File, error) {
"name": name,
})
if checkName(name) {
return nil, os.ErrInvalid
}
return f.OpenFile(name, os.O_RDONLY, 0)
}
@@ -334,6 +342,10 @@ func (f *STFS) OpenFile(name string, flag int, perm os.FileMode) (afero.File, er
"perm": perm,
})
if checkName(name) {
return nil, os.ErrInvalid
}
f.ioLock.Lock()
defer f.ioLock.Unlock()
@@ -431,6 +443,10 @@ func (f *STFS) Remove(name string) error {
return os.ErrPermission
}
if checkName(name) {
return os.ErrInvalid
}
f.ioLock.Lock()
defer f.ioLock.Unlock()
@@ -446,6 +462,10 @@ func (f *STFS) RemoveAll(path string) error {
return os.ErrPermission
}
if checkName(path) {
return os.ErrInvalid
}
f.ioLock.Lock()
defer f.ioLock.Unlock()
@@ -462,6 +482,14 @@ func (f *STFS) Rename(oldname, newname string) error {
return os.ErrPermission
}
if checkName(oldname) {
return os.ErrInvalid
}
if checkName(newname) {
return os.ErrInvalid
}
f.ioLock.Lock()
defer f.ioLock.Unlock()
@@ -535,6 +563,10 @@ func (f *STFS) Chmod(name string, mode os.FileMode) error {
return os.ErrPermission
}
if checkName(name) {
return os.ErrInvalid
}
f.ioLock.Lock()
defer f.ioLock.Unlock()
@@ -570,6 +602,10 @@ func (f *STFS) Chown(name string, uid, gid int) error {
return os.ErrPermission
}
if checkName(name) {
return os.ErrInvalid
}
f.ioLock.Lock()
defer f.ioLock.Unlock()
@@ -606,6 +642,10 @@ func (f *STFS) Chtimes(name string, atime time.Time, mtime time.Time) error {
return os.ErrPermission
}
if checkName(name) {
return os.ErrInvalid
}
f.ioLock.Lock()
defer f.ioLock.Unlock()
@@ -636,6 +676,10 @@ func (f *STFS) lstatIfPossibleWithoutLocking(name string) (os.FileInfo, bool, er
"name": name,
})
if checkName(name) {
return nil, true, os.ErrInvalid
}
hdr, err := inventory.Stat(
f.metadata,
@@ -660,6 +704,10 @@ func (f *STFS) LstatIfPossible(name string) (os.FileInfo, bool, error) {
"name": name,
})
if checkName(name) {
return nil, true, os.ErrInvalid
}
f.ioLock.Lock()
defer f.ioLock.Unlock()
@@ -676,6 +724,14 @@ func (f *STFS) SymlinkIfPossible(oldname, newname string) error {
return os.ErrPermission
}
if checkName(oldname) {
return os.ErrInvalid
}
if checkName(newname) {
return os.ErrInvalid
}
f.ioLock.Lock()
defer f.ioLock.Unlock()
@@ -687,6 +743,10 @@ func (f *STFS) ReadlinkIfPossible(name string) (string, error) {
"name": name,
})
if checkName(name) {
return "", os.ErrInvalid
}
f.ioLock.Lock()
defer f.ioLock.Unlock()
@@ -697,3 +757,7 @@ func (f *STFS) ReadlinkIfPossible(name string) (string, error) {
return info.Name(), nil
}
func checkName(name string) bool {
return len(name) == 0
}

View File

@@ -613,6 +613,8 @@ func TestSTFS_Name(t *testing.T) {
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
if got := tt.f.Name(); got != tt.want {
t.Errorf("%v.Name() = %v, want %v", t.Name(), got, tt.want)
@@ -643,10 +645,10 @@ var createTests = []struct {
{
"Can not create file ' '",
createArgs{" "},
true,
false,
},
{
"Can not create file ''",
"Can create file ''",
createArgs{""},
true,
},
@@ -660,6 +662,8 @@ var createTests = []struct {
func TestSTFS_Create(t *testing.T) {
for _, tt := range createTests {
tt := tt
runTestForAllFss(t, tt.name, true, func(t *testing.T, fs fsConfig) {
file, err := fs.fs.Create(tt.args.name)
if (err != nil) != tt.wantErr {
@@ -762,6 +766,8 @@ var initializeTests = []struct {
func TestSTFS_Initialize(t *testing.T) {
for _, tt := range initializeTests {
tt := tt
runTestForAllFss(t, tt.name, false, func(t *testing.T, fs fsConfig) {
f, ok := fs.fs.(*STFS)
if !ok {
@@ -778,6 +784,7 @@ func TestSTFS_Initialize(t *testing.T) {
gotRoot, err := f.Initialize(tt.args.rootProposal, tt.args.rootPerm)
if (err != nil) != tt.wantErr {
t.Errorf("%v.Initialize() error = %v, wantErr %v", f.Name(), err, tt.wantErr)
return
}
if gotRoot != tt.wantRoot {
@@ -813,12 +820,12 @@ var mkdirTests = []struct {
true,
},
{
"Can not create directory ' '",
"Can create directory ' '",
mkdirArgs{" ", os.ModePerm},
true,
false,
},
{
"Can not create directory ''",
"Can create directory ''",
mkdirArgs{"", os.ModePerm},
true,
},
@@ -832,6 +839,8 @@ var mkdirTests = []struct {
func TestSTFS_Mkdir(t *testing.T) {
for _, tt := range mkdirTests {
tt := tt
runTestForAllFss(t, tt.name, true, func(t *testing.T, fs fsConfig) {
if err := fs.fs.Mkdir(tt.args.name, tt.args.perm); (err != nil) != tt.wantErr {
t.Errorf("%v.Mkdir() error = %v, wantErr %v", fs.fs.Name(), err, tt.wantErr)
@@ -874,19 +883,19 @@ var mkdirAllTests = []struct {
false,
},
{
"Can not create existing directory /",
"Can create existing directory /",
mkdirAllArgs{"/", os.ModePerm},
true,
false,
},
{
"Can not create directory ' '",
"Can create directory ' '",
mkdirAllArgs{" ", os.ModePerm},
true,
false,
},
{
"Can not create directory ''",
"Can create directory ''",
mkdirAllArgs{"", os.ModePerm},
true,
false,
},
{
"Can create /nonexistent/test.txt",
@@ -912,6 +921,8 @@ var mkdirAllTests = []struct {
func TestSTFS_MkdirAll(t *testing.T) {
for _, tt := range mkdirAllTests {
tt := tt
runTestForAllFss(t, tt.name, true, func(t *testing.T, fs fsConfig) {
if err := fs.fs.MkdirAll(tt.args.name, tt.args.perm); (err != nil) != tt.wantErr {
t.Errorf("%v.MkdirAll() error = %v, wantErr %v", fs.fs.Name(), err, tt.wantErr)
@@ -932,3 +943,55 @@ func TestSTFS_MkdirAll(t *testing.T) {
})
}
}
func TestSTFS_Open(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
wantErr bool
prepare func(afero.Fs) error
check func(afero.File) error
}{
{
"Can open /",
args{"/"},
false,
func(f afero.Fs) error { return nil },
func(f afero.File) error { return nil },
},
{
"Can not open /test.txt without creating it",
args{"/test.txt"},
true,
func(f afero.Fs) error { return nil },
func(f afero.File) error { return nil },
},
}
for _, tt := range tests {
tt := tt
runTestForAllFss(t, tt.name, true, func(t *testing.T, fs fsConfig) {
if err := tt.prepare(fs.fs); err != nil {
t.Error(err)
return
}
got, err := fs.fs.Open(tt.args.name)
if err == nil && tt.wantErr {
t.Fatalf("%v.Open() error = %v, wantErr %v", fs.fs.Name(), err, tt.wantErr)
return
}
if err := tt.check(got); err != nil {
t.Error(err)
return
}
})
}
}