diff --git a/backend/posix/posix.go b/backend/posix/posix.go index 9b5ab5d..834d314 100644 --- a/backend/posix/posix.go +++ b/backend/posix/posix.go @@ -2013,7 +2013,7 @@ func (p *Posix) CopyObject(ctx context.Context, input *s3.CopyObjectInput) (*s3. }, nil } -func (p *Posix) ListObjects(_ context.Context, input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) { +func (p *Posix) ListObjects(ctx context.Context, input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) { if input.Bucket == nil { return nil, s3err.GetAPIError(s3err.ErrInvalidBucketName) } @@ -2044,7 +2044,7 @@ func (p *Posix) ListObjects(_ context.Context, input *s3.ListObjectsInput) (*s3. } fileSystem := os.DirFS(bucket) - results, err := backend.Walk(fileSystem, prefix, delim, marker, maxkeys, + results, err := backend.Walk(ctx, fileSystem, prefix, delim, marker, maxkeys, p.fileToObj(bucket), []string{metaTmpDir}) if err != nil { return nil, fmt.Errorf("walk %v: %w", bucket, err) @@ -2126,7 +2126,7 @@ func (p *Posix) fileToObj(bucket string) backend.GetObjFunc { } } -func (p *Posix) ListObjectsV2(_ context.Context, input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) { +func (p *Posix) ListObjectsV2(ctx context.Context, input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) { if input.Bucket == nil { return nil, s3err.GetAPIError(s3err.ErrInvalidBucketName) } @@ -2165,7 +2165,7 @@ func (p *Posix) ListObjectsV2(_ context.Context, input *s3.ListObjectsV2Input) ( } fileSystem := os.DirFS(bucket) - results, err := backend.Walk(fileSystem, prefix, delim, marker, maxkeys, + results, err := backend.Walk(ctx, fileSystem, prefix, delim, marker, maxkeys, p.fileToObj(bucket), []string{metaTmpDir}) if err != nil { return nil, fmt.Errorf("walk %v: %w", bucket, err) diff --git a/backend/scoutfs/scoutfs.go b/backend/scoutfs/scoutfs.go index 606475f..5ef1979 100644 --- a/backend/scoutfs/scoutfs.go +++ b/backend/scoutfs/scoutfs.go @@ -714,7 +714,7 @@ func (s *ScoutFS) getXattrTags(bucket, object string) (map[string]string, error) return tags, nil } -func (s *ScoutFS) ListObjects(_ context.Context, input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) { +func (s *ScoutFS) ListObjects(ctx context.Context, input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) { if input.Bucket == nil { return nil, s3err.GetAPIError(s3err.ErrInvalidBucketName) } @@ -745,7 +745,7 @@ func (s *ScoutFS) ListObjects(_ context.Context, input *s3.ListObjectsInput) (*s } fileSystem := os.DirFS(bucket) - results, err := backend.Walk(fileSystem, prefix, delim, marker, maxkeys, + results, err := backend.Walk(ctx, fileSystem, prefix, delim, marker, maxkeys, s.fileToObj(bucket), []string{metaTmpDir}) if err != nil { return nil, fmt.Errorf("walk %v: %w", bucket, err) @@ -764,7 +764,7 @@ func (s *ScoutFS) ListObjects(_ context.Context, input *s3.ListObjectsInput) (*s }, nil } -func (s *ScoutFS) ListObjectsV2(_ context.Context, input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) { +func (s *ScoutFS) ListObjectsV2(ctx context.Context, input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) { if input.Bucket == nil { return nil, s3err.GetAPIError(s3err.ErrInvalidBucketName) } @@ -795,7 +795,7 @@ func (s *ScoutFS) ListObjectsV2(_ context.Context, input *s3.ListObjectsV2Input) } fileSystem := os.DirFS(bucket) - results, err := backend.Walk(fileSystem, prefix, delim, marker, int32(maxkeys), + results, err := backend.Walk(ctx, fileSystem, prefix, delim, marker, int32(maxkeys), s.fileToObj(bucket), []string{metaTmpDir}) if err != nil { return nil, fmt.Errorf("walk %v: %w", bucket, err) diff --git a/backend/walk.go b/backend/walk.go index d34c223..157522a 100644 --- a/backend/walk.go +++ b/backend/walk.go @@ -15,6 +15,7 @@ package backend import ( + "context" "errors" "fmt" "io/fs" @@ -38,7 +39,7 @@ var ErrSkipObj = errors.New("skip this object") // Walk walks the supplied fs.FS and returns results compatible with list // objects responses -func Walk(fileSystem fs.FS, prefix, delimiter, marker string, max int32, getObj GetObjFunc, skipdirs []string) (WalkResults, error) { +func Walk(ctx context.Context, fileSystem fs.FS, prefix, delimiter, marker string, max int32, getObj GetObjFunc, skipdirs []string) (WalkResults, error) { cpmap := make(map[string]struct{}) var objects []types.Object @@ -55,6 +56,9 @@ func Walk(fileSystem fs.FS, prefix, delimiter, marker string, max int32, getObj if err != nil { return err } + if ctx.Err() != nil { + return ctx.Err() + } // Ignore the root directory if path == "." { return nil diff --git a/backend/walk_test.go b/backend/walk_test.go index a7b4b82..19f0315 100644 --- a/backend/walk_test.go +++ b/backend/walk_test.go @@ -15,12 +15,15 @@ package backend_test import ( + "context" "crypto/md5" "encoding/hex" "fmt" "io/fs" + "sync" "testing" "testing/fstest" + "time" "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/versity/versitygw/backend" @@ -108,7 +111,7 @@ func TestWalk(t *testing.T) { } for _, tt := range tests { - res, err := backend.Walk(tt.fsys, "", "/", "", 1000, tt.getobj, []string{}) + res, err := backend.Walk(context.Background(), tt.fsys, "", "/", "", 1000, tt.getobj, []string{}) if err != nil { t.Fatalf("walk: %v", err) } @@ -204,3 +207,50 @@ func printObjects(list []types.Object) string { } return res + "]" } + +type slowFS struct { + fstest.MapFS +} + +const ( + readDirPause = 100 * time.Millisecond + + // walkTimeOut should be less than the tree traversal time + // which is the readdirPause time * the number of directories + walkTimeOut = 500 * time.Millisecond +) + +func (s *slowFS) ReadDir(name string) ([]fs.DirEntry, error) { + time.Sleep(readDirPause) + return s.MapFS.ReadDir(name) +} + +func TestWalkStop(t *testing.T) { + s := &slowFS{MapFS: fstest.MapFS{ + "/a/b/c/d/e/f/g/h/i/g/k/l/m/n": &fstest.MapFile{}, + }} + + ctx, cancel := context.WithTimeout(context.Background(), walkTimeOut) + defer cancel() + + var err error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _, err = backend.Walk(ctx, s, "", "/", "", 1000, + func(path string, d fs.DirEntry) (types.Object, error) { + return types.Object{}, nil + }, []string{}) + }() + + select { + case <-time.After(1 * time.Second): + t.Fatalf("walk is not terminated in time") + case <-ctx.Done(): + } + wg.Wait() + if err != ctx.Err() { + t.Fatalf("unexpected error: %v", err) + } +}