diff --git a/src/extract.go b/src/extract.go index 652a8dc..27d028f 100644 --- a/src/extract.go +++ b/src/extract.go @@ -18,13 +18,33 @@ import ( var ErrArchiveTooLarge = errors.New("archive too large") -func ExtractTar(reader io.Reader) (*Manifest, error) { - // If the tar stream is itself compressed, both the outer and the inner bounds checks - // are load-bearing. - boundedReader := ReadAtMost(reader, int64(config.Limits.MaxSiteSize.Bytes()), +func boundArchiveStream(reader io.Reader) io.Reader { + return ReadAtMost(reader, int64(config.Limits.MaxSiteSize.Bytes()), fmt.Errorf("%w: %s limit exceeded", ErrArchiveTooLarge, config.Limits.MaxSiteSize.HR())) +} - archive := tar.NewReader(boundedReader) +func ExtractGzip(reader io.Reader, next func(io.Reader) (*Manifest, error)) (*Manifest, error) { + stream, err := gzip.NewReader(reader) + if err != nil { + return nil, err + } + defer stream.Close() + + return next(boundArchiveStream(stream)) +} + +func ExtractZstd(reader io.Reader, next func(io.Reader) (*Manifest, error)) (*Manifest, error) { + stream, err := zstd.NewReader(reader) + if err != nil { + return nil, err + } + defer stream.Close() + + return next(boundArchiveStream(stream)) +} + +func ExtractTar(reader io.Reader) (*Manifest, error) { + archive := tar.NewReader(reader) manifest := Manifest{ Contents: map[string]*Entry{ @@ -84,28 +104,6 @@ func ExtractTar(reader io.Reader) (*Manifest, error) { return &manifest, nil } -func ExtractTarGzip(reader io.Reader) (*Manifest, error) { - stream, err := gzip.NewReader(reader) - if err != nil { - return nil, err - } - defer stream.Close() - - // stream length is limited in `ExtractTar` - return ExtractTar(stream) -} - -func ExtractTarZstd(reader io.Reader) (*Manifest, error) { - stream, err := zstd.NewReader(reader) - if err != nil { - return nil, err - } - defer stream.Close() - - // stream length is limited in `ExtractTar` - return ExtractTar(stream) -} - func ExtractZip(reader io.Reader) (*Manifest, error) { data, err := io.ReadAll(reader) if err != nil { diff --git a/src/update.go b/src/update.go index f33d7ac..713970a 100644 --- a/src/update.go +++ b/src/update.go @@ -125,10 +125,10 @@ func UpdateFromArchive( manifest, err = ExtractTar(reader) // yellow? case "application/x-tar+gzip": logc.Printf(ctx, "update %s: (tar.gz)", webRoot) - manifest, err = ExtractTarGzip(reader) // definitely yellow. + manifest, err = ExtractGzip(reader, ExtractTar) // definitely yellow. case "application/x-tar+zstd": logc.Printf(ctx, "update %s: (tar.zst)", webRoot) - manifest, err = ExtractTarZstd(reader) + manifest, err = ExtractZstd(reader, ExtractTar) case "application/zip": logc.Printf(ctx, "update %s: (zip)", webRoot) manifest, err = ExtractZip(reader)