fix: cancel filesystem traversal when listing request cancelled

For large directories, the treewalk can take longer than the
client request timeout. If the client times out the request
then we need to stop walking the filesystem and just return
the context error.

This should prevent the gateway from consuming system resources
uneccessarily after an incoming request is terminated.
This commit is contained in:
Ben McClelland
2024-07-13 09:34:34 -07:00
parent 5d33c7bde5
commit a8adb471fe
4 changed files with 64 additions and 10 deletions

View File

@@ -2013,7 +2013,7 @@ func (p *Posix) CopyObject(ctx context.Context, input *s3.CopyObjectInput) (*s3.
}, nil
}
func (p *Posix) ListObjects(_ context.Context, input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
func (p *Posix) ListObjects(ctx context.Context, input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
if input.Bucket == nil {
return nil, s3err.GetAPIError(s3err.ErrInvalidBucketName)
}
@@ -2044,7 +2044,7 @@ func (p *Posix) ListObjects(_ context.Context, input *s3.ListObjectsInput) (*s3.
}
fileSystem := os.DirFS(bucket)
results, err := backend.Walk(fileSystem, prefix, delim, marker, maxkeys,
results, err := backend.Walk(ctx, fileSystem, prefix, delim, marker, maxkeys,
p.fileToObj(bucket), []string{metaTmpDir})
if err != nil {
return nil, fmt.Errorf("walk %v: %w", bucket, err)
@@ -2126,7 +2126,7 @@ func (p *Posix) fileToObj(bucket string) backend.GetObjFunc {
}
}
func (p *Posix) ListObjectsV2(_ context.Context, input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
func (p *Posix) ListObjectsV2(ctx context.Context, input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
if input.Bucket == nil {
return nil, s3err.GetAPIError(s3err.ErrInvalidBucketName)
}
@@ -2165,7 +2165,7 @@ func (p *Posix) ListObjectsV2(_ context.Context, input *s3.ListObjectsV2Input) (
}
fileSystem := os.DirFS(bucket)
results, err := backend.Walk(fileSystem, prefix, delim, marker, maxkeys,
results, err := backend.Walk(ctx, fileSystem, prefix, delim, marker, maxkeys,
p.fileToObj(bucket), []string{metaTmpDir})
if err != nil {
return nil, fmt.Errorf("walk %v: %w", bucket, err)

View File

@@ -714,7 +714,7 @@ func (s *ScoutFS) getXattrTags(bucket, object string) (map[string]string, error)
return tags, nil
}
func (s *ScoutFS) ListObjects(_ context.Context, input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
func (s *ScoutFS) ListObjects(ctx context.Context, input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
if input.Bucket == nil {
return nil, s3err.GetAPIError(s3err.ErrInvalidBucketName)
}
@@ -745,7 +745,7 @@ func (s *ScoutFS) ListObjects(_ context.Context, input *s3.ListObjectsInput) (*s
}
fileSystem := os.DirFS(bucket)
results, err := backend.Walk(fileSystem, prefix, delim, marker, maxkeys,
results, err := backend.Walk(ctx, fileSystem, prefix, delim, marker, maxkeys,
s.fileToObj(bucket), []string{metaTmpDir})
if err != nil {
return nil, fmt.Errorf("walk %v: %w", bucket, err)
@@ -764,7 +764,7 @@ func (s *ScoutFS) ListObjects(_ context.Context, input *s3.ListObjectsInput) (*s
}, nil
}
func (s *ScoutFS) ListObjectsV2(_ context.Context, input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
func (s *ScoutFS) ListObjectsV2(ctx context.Context, input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
if input.Bucket == nil {
return nil, s3err.GetAPIError(s3err.ErrInvalidBucketName)
}
@@ -795,7 +795,7 @@ func (s *ScoutFS) ListObjectsV2(_ context.Context, input *s3.ListObjectsV2Input)
}
fileSystem := os.DirFS(bucket)
results, err := backend.Walk(fileSystem, prefix, delim, marker, int32(maxkeys),
results, err := backend.Walk(ctx, fileSystem, prefix, delim, marker, int32(maxkeys),
s.fileToObj(bucket), []string{metaTmpDir})
if err != nil {
return nil, fmt.Errorf("walk %v: %w", bucket, err)

View File

@@ -15,6 +15,7 @@
package backend
import (
"context"
"errors"
"fmt"
"io/fs"
@@ -38,7 +39,7 @@ var ErrSkipObj = errors.New("skip this object")
// Walk walks the supplied fs.FS and returns results compatible with list
// objects responses
func Walk(fileSystem fs.FS, prefix, delimiter, marker string, max int32, getObj GetObjFunc, skipdirs []string) (WalkResults, error) {
func Walk(ctx context.Context, fileSystem fs.FS, prefix, delimiter, marker string, max int32, getObj GetObjFunc, skipdirs []string) (WalkResults, error) {
cpmap := make(map[string]struct{})
var objects []types.Object
@@ -55,6 +56,9 @@ func Walk(fileSystem fs.FS, prefix, delimiter, marker string, max int32, getObj
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
// Ignore the root directory
if path == "." {
return nil

View File

@@ -15,12 +15,15 @@
package backend_test
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"io/fs"
"sync"
"testing"
"testing/fstest"
"time"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/versity/versitygw/backend"
@@ -108,7 +111,7 @@ func TestWalk(t *testing.T) {
}
for _, tt := range tests {
res, err := backend.Walk(tt.fsys, "", "/", "", 1000, tt.getobj, []string{})
res, err := backend.Walk(context.Background(), tt.fsys, "", "/", "", 1000, tt.getobj, []string{})
if err != nil {
t.Fatalf("walk: %v", err)
}
@@ -204,3 +207,50 @@ func printObjects(list []types.Object) string {
}
return res + "]"
}
type slowFS struct {
fstest.MapFS
}
const (
readDirPause = 100 * time.Millisecond
// walkTimeOut should be less than the tree traversal time
// which is the readdirPause time * the number of directories
walkTimeOut = 500 * time.Millisecond
)
func (s *slowFS) ReadDir(name string) ([]fs.DirEntry, error) {
time.Sleep(readDirPause)
return s.MapFS.ReadDir(name)
}
func TestWalkStop(t *testing.T) {
s := &slowFS{MapFS: fstest.MapFS{
"/a/b/c/d/e/f/g/h/i/g/k/l/m/n": &fstest.MapFile{},
}}
ctx, cancel := context.WithTimeout(context.Background(), walkTimeOut)
defer cancel()
var err error
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_, err = backend.Walk(ctx, s, "", "/", "", 1000,
func(path string, d fs.DirEntry) (types.Object, error) {
return types.Object{}, nil
}, []string{})
}()
select {
case <-time.After(1 * time.Second):
t.Fatalf("walk is not terminated in time")
case <-ctx.Done():
}
wg.Wait()
if err != ctx.Err() {
t.Fatalf("unexpected error: %v", err)
}
}