-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
185 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package syncs | ||
|
||
import ( | ||
"context" | ||
"time" | ||
|
||
"github.com/totemcaf/gollections/slices" | ||
) | ||
|
||
// ParallelMap applies the given mapper to each element of the given slice in parallel. | ||
// The results and errors are returned in the same order as the slice. | ||
// If the context expires before all the mappers are finished, the remaining mappers are cancelled. | ||
// If a mapper fails, the error is returned in the errs slice. | ||
// If a mapper succeeds, the result is returned in the results slice. | ||
// | ||
// Example: | ||
// values := []int{1, 2, 3, 4, 5} | ||
// results, errs := ParallelMap(values, time.Second, func(ctx context.Context, v V) (T, error) { | ||
// return v * v, nil | ||
// }) | ||
func ParallelMap[V, T any](values []V, maxWait time.Duration, mapper func(context.Context, V) (T, error)) ([]T, []error) { | ||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(maxWait)) | ||
defer cancel() | ||
|
||
return WaitAll(ctx, slices.Map(values, func(v V) Waitable[T] { | ||
return func(ctx context.Context) (T, error) { | ||
return mapper(ctx, v) | ||
} | ||
})...) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package syncs | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func Test_ParallelMap_maps_values(t *testing.T) { | ||
values := []int{1, 2, 3, 4, 5} | ||
results, errs := ParallelMap(values, time.Second, func(_ context.Context, v int) (int, error) { | ||
return v * v, nil | ||
}) | ||
|
||
assert.Equal(t, []int{1, 4, 9, 16, 25}, results) | ||
assert.Equal(t, []error{nil, nil, nil, nil, nil}, errs) | ||
|
||
} | ||
|
||
func Test_ParallelMap_cancels_remaining_waitables_if_context_expires(t *testing.T) { | ||
values := []int{1, 2, 3} | ||
|
||
results, errs := ParallelMap(values, time.Millisecond*100, func(ctx context.Context, v int) (int, error) { | ||
if Sleep(ctx, time.Second) { | ||
return 0, ctx.Err() | ||
} | ||
return v * v, nil | ||
}) | ||
|
||
assert.Equal(t, []int{0, 0, 0}, results) | ||
assert.Equal(t, []error{context.DeadlineExceeded, context.DeadlineExceeded, context.DeadlineExceeded}, errs) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package syncs | ||
|
||
import ( | ||
"context" | ||
"time" | ||
) | ||
|
||
// Sleep waits for the given duration. | ||
// Returns true if the context was canceled. | ||
func Sleep(ctx context.Context, delay time.Duration) bool { | ||
select { | ||
case <-time.After(delay): | ||
return false | ||
case <-ctx.Done(): | ||
return true | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package syncs | ||
|
||
import ( | ||
"context" | ||
"sync" | ||
"time" | ||
) | ||
|
||
// Waitable is a function that can be waited on. | ||
type Waitable[T any] func(ctx context.Context) (T, error) | ||
|
||
// WaitAll waits for all Waitable to complete successfully or to fail. | ||
// Cancelling the context will cause all Waitable to fail. | ||
// | ||
// The results and errors are returned in the same order as the Waitable. | ||
// If a Waitable fails, the error is returned in the errs slice. | ||
// If a Waitable succeeds, the result is returned in the results slice. | ||
// If a Waitable is cancelled, the result is returned in the results slice and the error is returned in the errs slice. | ||
// | ||
// ctx (context) can be used to cancel the WaitAll. If nil is provided, a context with a deadline of 1 second is used. | ||
// Remember to cancel the context when you're done with it. | ||
// | ||
// Example: | ||
// ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second * 3)) | ||
// defer cancel() | ||
// results, errs := WaitAll(ctx, funcs...) | ||
// | ||
func WaitAll[T any](ctx context.Context, waitables ...Waitable[T]) ([]T, []error) { | ||
var wg sync.WaitGroup | ||
|
||
if ctx == nil { | ||
var cancel context.CancelFunc | ||
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Second)) | ||
defer cancel() | ||
} | ||
|
||
results := make([]T, len(waitables)) | ||
errs := make([]error, len(waitables)) | ||
|
||
for i, waitable := range waitables { | ||
wg.Add(1) | ||
go func(i int, waitable Waitable[T]) { | ||
defer wg.Done() | ||
results[i], errs[i] = waitable(ctx) | ||
}(i, waitable) | ||
} | ||
wg.Wait() | ||
return results, errs | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package syncs | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"math/rand" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/totemcaf/gollections/slices" | ||
) | ||
|
||
func Test_WaitAll_waits_for_all_the_waitables(t *testing.T) { | ||
funcs := slices.Map([]int{1, 2, 3, 4, 5}, func(i int) Waitable[int] { | ||
return func(ctx context.Context) (int, error) { | ||
fmt.Println("waiting for", i) | ||
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100))) | ||
fmt.Println("finished for", i) | ||
return i, nil | ||
} | ||
}) | ||
|
||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) | ||
defer cancel() | ||
|
||
results, errs := WaitAll(ctx, funcs...) | ||
|
||
assert.Equal(t, []int{1, 2, 3, 4, 5}, results) | ||
assert.Equal(t, []error{nil, nil, nil, nil, nil}, errs) | ||
} | ||
|
||
func Test_WaitAll_report_failing_waitables(t *testing.T) { | ||
funcs := slices.Map([]int{1, 2, 3, 4, 5}, func(i int) Waitable[int] { | ||
return func(ctx context.Context) (int, error) { | ||
fmt.Println("waiting for", i) | ||
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100))) | ||
fmt.Println("finished for", i) | ||
|
||
if i == 3 || i == 4 { | ||
return 0, fmt.Errorf("error for %d", i) | ||
} | ||
|
||
return i, nil | ||
} | ||
}) | ||
|
||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) | ||
defer cancel() | ||
|
||
results, errs := WaitAll(ctx, funcs...) | ||
|
||
assert.Equal(t, []int{1, 2, 0, 0, 5}, results) | ||
assert.Equal(t, []error{nil, nil, fmt.Errorf("error for 3"), fmt.Errorf("error for 4"), nil}, errs) | ||
} |