From bf24f2dcc529772dbb943043eca319441108796c Mon Sep 17 00:00:00 2001
From: Jae Kwon
Date: Tue, 20 Mar 2018 19:24:18 +0100
Subject: [PATCH] Implement better Parallel (#174)

* Implement better Parallel
---
Review fixups folded into this revision (the blob ids on the "index" lines
predate them):
- Parallel now signals completion when a task ends via runtime.Goexit (or a
  panic with a nil value on Go < 1.21); previously it would wait forever.
- Tests: failure messages are built before calling assert.Fail, expected
  values come first in assert.Equal, and waitFor waits instead of polling once.

 common/async.go      |  83 +++++++++++++++++---
 common/async_test.go | 181 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 253 insertions(+), 11 deletions(-)
 create mode 100644 common/async_test.go

diff --git a/common/async.go b/common/async.go
index 1d302c344..23d1a42b3 100644
--- a/common/async.go
+++ b/common/async.go
@@ -1,15 +1,76 @@
 package common
 
-import "sync"
+// Task is one unit of work for Parallel. It is called with i, its index in
+// the argument list given to Parallel.
+// val: the value returned after task execution.
+// err: the error returned during task completion.
+// abort: tells Parallel to return, whether or not all tasks have completed.
+type Task func(i int) (val interface{}, err error, abort bool)
 
-func Parallel(tasks ...func()) {
-	var wg sync.WaitGroup
-	wg.Add(len(tasks))
-	for _, task := range tasks {
-		go func(task func()) {
-			task()
-			wg.Done()
-		}(task)
-	}
-	wg.Wait()
+// TaskResult is what a single Task produced: the Value and Error it
+// returned, or, if it panicked, the recovered value in Panic.
+type TaskResult struct {
+	Value interface{}
+	Error error
+	Panic interface{}
+}
+
+// TaskResultCh is the receive-only channel on which the TaskResult of one
+// task is delivered.
+type TaskResultCh <-chan TaskResult
+
+// Parallel runs tasks in parallel, with ability to abort early. It returns
+// one result channel per task (same index as the task) once every task is
+// done, or as soon as any task returns abort=true. Each channel has capacity
+// for its single result, so tasks still running after an early return never
+// block on delivery.
+// NOTE: Do not implement quit features here. Instead, provide convenient
+// concurrent quit-like primitives, passed implicitly via Task closures. (e.g.
+// it's not Parallel's concern how you quit/abort your tasks).
+func Parallel(tasks ...Task) []TaskResultCh {
+	var taskResultChz = make([]TaskResultCh, len(tasks)) // To return.
+	var taskDoneCh = make(chan bool, len(tasks))         // A "wait group" channel, early abort if any true received.
+
+	// Start all tasks in parallel in separate goroutines.
+	// When the task is complete, it will appear in the
+	// respective taskResultCh (associated by task index).
+	for i, task := range tasks {
+		var taskResultCh = make(chan TaskResult, 1) // Capacity for 1 result.
+		taskResultChz[i] = taskResultCh
+		go func(i int, task Task, taskResultCh chan TaskResult) {
+			// Recovery. finished stays false if the task panicked or
+			// called runtime.Goexit (recover returns nil for Goexit), so
+			// taskDoneCh is signalled on every exit path and Parallel
+			// cannot wait forever on a task that never returned.
+			var finished bool
+			defer func() {
+				if finished {
+					return
+				}
+				taskResultCh <- TaskResult{nil, nil, recover()}
+				taskDoneCh <- false
+			}()
+			// Run the task.
+			var val, err, abort = task(i)
+			// Send val/err to taskResultCh.
+			// NOTE: Below this line, nothing must panic.
+			taskResultCh <- TaskResult{val, err, nil}
+			finished = true
+			// Decrement waitgroup.
+			taskDoneCh <- abort
+		}(i, task, taskResultCh)
+	}
+
+	// Wait until all tasks are done, or until abort.
+	for i := 0; i < len(tasks); i++ {
+		abort := <-taskDoneCh
+		if abort {
+			break
+		}
+	}
+
+	// Caller can use this however they want.
+	// TODO: implement convenience functions to
+	// make sense of this structure safely.
+	return taskResultChz
 }
diff --git a/common/async_test.go b/common/async_test.go
new file mode 100644
index 000000000..1d6b0e7b0
--- /dev/null
+++ b/common/async_test.go
@@ -0,0 +1,181 @@
+package common
+
+import (
+	"errors"
+	"runtime"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// TestParallel runs many non-aborting tasks and checks that Parallel only
+// returns once each of them has run and delivered its own result.
+func TestParallel(t *testing.T) {
+
+	// Create tasks.
+	var counter = new(int32)
+	var tasks = make([]Task, 100*1000)
+	for i := 0; i < len(tasks); i++ {
+		tasks[i] = func(i int) (res interface{}, err error, abort bool) {
+			atomic.AddInt32(counter, 1)
+			return -1 * i, nil, false
+		}
+	}
+
+	// Run in parallel.
+	var taskResultChz = Parallel(tasks...)
+
+	// Verify.
+	assert.Equal(t, len(tasks), int(atomic.LoadInt32(counter)), "Each task should have incremented the counter already")
+	var failedTasks int
+	for i := 0; i < len(tasks); i++ {
+		select {
+		case taskResult := <-taskResultChz[i]:
+			if taskResult.Error != nil {
+				// assert.Fail does not format its failure message, so build it here.
+				assert.Fail(t, "Task should not have errored but got "+taskResult.Error.Error())
+				failedTasks++
+			} else if !assert.Equal(t, -1*i, taskResult.Value.(int)) {
+				failedTasks++
+			} else {
+				// Good!
+			}
+		default:
+			// Parallel has returned, so every result must already be buffered.
+			failedTasks++
+		}
+	}
+	assert.Equal(t, 0, failedTasks, "No task should have failed")
+
+}
+
+// TestParallelAbort checks that Parallel returns as soon as one task asks to
+// abort, while another task is still blocked. The flow channels force tasks
+// #0, #1 and #2 to proceed in that order.
+func TestParallelAbort(t *testing.T) {
+
+	var flow1 = make(chan struct{}, 1)
+	var flow2 = make(chan struct{}, 1)
+	var flow3 = make(chan struct{}, 1) // Cap must be > 0 to prevent blocking.
+	var flow4 = make(chan struct{}, 1)
+
+	// Create tasks.
+	var tasks = []Task{
+		func(i int) (res interface{}, err error, abort bool) {
+			assert.Equal(t, 0, i)
+			flow1 <- struct{}{}
+			return 0, nil, false
+		},
+		func(i int) (res interface{}, err error, abort bool) {
+			assert.Equal(t, 1, i)
+			flow2 <- <-flow1
+			return 1, errors.New("some error"), false
+		},
+		func(i int) (res interface{}, err error, abort bool) {
+			assert.Equal(t, 2, i)
+			flow3 <- <-flow2
+			return 2, nil, true
+		},
+		func(i int) (res interface{}, err error, abort bool) {
+			assert.Equal(t, 3, i)
+			<-flow4
+			return 3, nil, false
+		},
+	}
+
+	// Run in parallel.
+	var taskResultChz = Parallel(tasks...)
+
+	// Verify task #3.
+	// Initially taskResultCh[3] sends nothing since flow4 didn't send.
+	waitTimeout(t, taskResultChz[3], "Task #3")
+
+	// Now let the last task (#3) complete after abort.
+	flow4 <- <-flow3
+
+	// Verify task #0, #1, #2.
+	waitFor(t, taskResultChz[0], "Task #0", 0, nil, nil)
+	waitFor(t, taskResultChz[1], "Task #1", 1, errors.New("some error"), nil)
+	waitFor(t, taskResultChz[2], "Task #2", 2, nil, nil)
+}
+
+// TestParallelRecover checks that a panicking task is reported through
+// TaskResult.Panic and does not disturb the results of the other tasks.
+func TestParallelRecover(t *testing.T) {
+
+	// Create tasks.
+	var tasks = []Task{
+		func(i int) (res interface{}, err error, abort bool) {
+			return 0, nil, false
+		},
+		func(i int) (res interface{}, err error, abort bool) {
+			return 1, errors.New("some error"), false
+		},
+		func(i int) (res interface{}, err error, abort bool) {
+			panic(2)
+		},
+	}
+
+	// Run in parallel.
+	var taskResultChz = Parallel(tasks...)
+
+	// Verify task #0, #1, #2.
+	waitFor(t, taskResultChz[0], "Task #0", 0, nil, nil)
+	waitFor(t, taskResultChz[1], "Task #1", 1, errors.New("some error"), nil)
+	waitFor(t, taskResultChz[2], "Task #2", nil, nil, 2)
+}
+
+// TestParallelGoexit checks that Parallel still returns when a task ends its
+// goroutine with runtime.Goexit instead of returning.
+func TestParallelGoexit(t *testing.T) {
+
+	// Create tasks.
+	var tasks = []Task{
+		func(i int) (res interface{}, err error, abort bool) {
+			return 0, nil, false
+		},
+		func(i int) (res interface{}, err error, abort bool) {
+			runtime.Goexit()
+			return 1, nil, false
+		},
+	}
+
+	// Run in parallel.
+	var taskResultChz = Parallel(tasks...)
+
+	// Verify task #0, #1.
+	waitFor(t, taskResultChz[0], "Task #0", 0, nil, nil)
+	waitFor(t, taskResultChz[1], "Task #1", nil, nil, nil)
+}
+
+// waitFor asserts that taskResultCh delivers a result with the given value,
+// error and panic. It waits up to a second rather than polling once, because
+// after an abort Parallel may return before earlier tasks delivered theirs.
+func waitFor(t *testing.T, taskResultCh TaskResultCh, taskName string, val interface{}, err error, pnk interface{}) {
+	select {
+	case taskResult, ok := <-taskResultCh:
+		assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName)
+		assert.Equal(t, val, taskResult.Value, taskName)
+		assert.Equal(t, err, taskResult.Error, taskName)
+		assert.Equal(t, pnk, taskResult.Panic, taskName)
+	case <-time.After(1 * time.Second):
+		assert.Fail(t, "Failed to receive result for "+taskName)
+	}
+}
+
+// waitTimeout asserts that taskResultCh delivers nothing (and is not closed)
+// for one second.
+func waitTimeout(t *testing.T, taskResultCh TaskResultCh, taskName string) {
+	select {
+	case _, ok := <-taskResultCh:
+		if !ok {
+			assert.Fail(t, "TaskResultCh unexpectedly closed ("+taskName+")")
+		} else {
+			assert.Fail(t, "TaskResultCh unexpectedly returned for "+taskName)
+		}
+	case <-time.After(1 * time.Second): // TODO use deterministic time?
+		// Good!
+	}
+}