diff --git a/copy.go b/copy.go index 158cf2e..8dab89b 100644 --- a/copy.go +++ b/copy.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "hash" + "io" "os" "path" "sync" @@ -12,7 +13,6 @@ import ( "syscall" "time" - "github.com/abc950309/acp/mmap" mapset "github.com/deckarep/golang-set/v2" "github.com/hashicorp/go-multierror" sha256 "github.com/minio/sha256-simd" @@ -203,21 +203,18 @@ func (c *Copyer) write(ctx context.Context, job *writeJob, ch chan<- *baseJob, c readErr = c.streamCopy(ctx, chans, job.src, &cntr.bytes) } -func (c *Copyer) streamCopy(ctx context.Context, dsts []chan []byte, src *mmap.ReaderAt, bytes *int64) error { - if src.Len() == 0 { - return nil - } - +func (c *Copyer) streamCopy(ctx context.Context, dsts []chan []byte, src io.ReadCloser, bytes *int64) error { for idx := int64(0); ; idx += batchSize { - buf, err := src.Slice(idx, batchSize) + buf := make([]byte, batchSize) + + n, err := io.ReadFull(src, buf) if err != nil { return fmt.Errorf("slice mmap fail, %w", err) } - copyed := make([]byte, len(buf)) - copy(copyed, buf) + buf = buf[:n] for _, ch := range dsts { - ch <- copyed + ch <- buf } nr := len(buf) diff --git a/job.go b/job.go index 71d73a6..0110197 100644 --- a/job.go +++ b/job.go @@ -2,11 +2,10 @@ package acp import ( "encoding/hex" + "io" "io/fs" "sync" "time" - - "github.com/abc950309/acp/mmap" ) type jobStatus uint8 @@ -106,11 +105,11 @@ func (j *baseJob) report() *Job { type writeJob struct { *baseJob - src *mmap.ReaderAt + src io.ReadCloser ch chan struct{} } -func newWriteJob(job *baseJob, src *mmap.ReaderAt, needWait bool) *writeJob { +func newWriteJob(job *baseJob, src io.ReadCloser, needWait bool) *writeJob { j := &writeJob{ baseJob: job, src: src, diff --git a/mmap/mmap_reader.go b/mmap/mmap_reader.go new file mode 100644 index 0000000..13be2ab --- /dev/null +++ b/mmap/mmap_reader.go @@ -0,0 +1,16 @@ +package mmap + +type Reader struct { + *ReaderAt + index int64 +} + +func NewReader(readerAt *ReaderAt) *Reader { + return &Reader{ReaderAt: readerAt} +} + +func (r *Reader) Read(buf []byte) (n int, err error) { + n, err = r.ReadAt(buf, r.index) + r.index += int64(n) + return +} diff --git a/prepare.go b/prepare.go index df3a035..c3b6790 100644 --- a/prepare.go +++ b/prepare.go @@ -3,6 +3,8 @@ package acp import ( "context" "fmt" + "io" + "os" "sync" "github.com/abc950309/acp/mmap" @@ -39,10 +41,36 @@ func (c *Copyer) prepare(ctx context.Context, indexed <-chan *baseJob) <-chan *w job.setStatus(jobStatusPreparing) - file, err := mmap.Open(job.path) + file, err := func(path string) (io.ReadCloser, error) { + if c.fromDevice.linear { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open src file fail, %w", err) + } + + fileInfo, err := file.Stat() + if err != nil { + return nil, fmt.Errorf("get src file stat fail, %w", err) + } + if fileInfo.Size() == 0 { + return nil, fmt.Errorf("get src file, size is zero") + } + + return file, nil + } + + readerAt, err := mmap.Open(path) + if err != nil { + return nil, fmt.Errorf("open src file by mmap fail, %w", err) + } + if readerAt.Len() == 0 { + return nil, fmt.Errorf("get src file by mmap, size is zero") + } + + return mmap.NewReader(readerAt), nil + }(job.path) if err != nil { - c.reportError(job.path, "", fmt.Errorf("open src file fail, %w", err)) - return + c.reportError(job.path, "", err) } wj := newWriteJob(job, file, c.fromDevice.linear)