Skip to content

Commit

Permalink
Add parallel mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
totemcaf committed Aug 9, 2022
1 parent 39662c4 commit 8ccd1c1
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 0 deletions.
30 changes: 30 additions & 0 deletions syncs/parallel_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package syncs

import (
"context"
"time"

"github.com/totemcaf/gollections/slices"
)

// ParallelMap applies the given mapper to each element of the given slice in parallel.
// The results and errors are returned in the same order as the slice.
// If the context expires before all the mappers are finished, the remaining mappers are cancelled.
// If a mapper fails, the error is returned in the errs slice.
// If a mapper succeeds, the result is returned in the results slice.
//
// Example:
// values := []int{1, 2, 3, 4, 5}
// results, errs := ParallelMap(values, time.Second, func(ctx context.Context, v V) (T, error) {
// return v * v, nil
// })
func ParallelMap[V, T any](values []V, maxWait time.Duration, mapper func(context.Context, V) (T, error)) ([]T, []error) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(maxWait))
defer cancel()

return WaitAll(ctx, slices.Map(values, func(v V) Waitable[T] {
return func(ctx context.Context) (T, error) {
return mapper(ctx, v)
}
})...)
}
34 changes: 34 additions & 0 deletions syncs/parallel_map_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package syncs

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func Test_ParallelMap_maps_values(t *testing.T) {
values := []int{1, 2, 3, 4, 5}
results, errs := ParallelMap(values, time.Second, func(_ context.Context, v int) (int, error) {
return v * v, nil
})

assert.Equal(t, []int{1, 4, 9, 16, 25}, results)
assert.Equal(t, []error{nil, nil, nil, nil, nil}, errs)

}

func Test_ParallelMap_cancels_remaining_waitables_if_context_expires(t *testing.T) {
values := []int{1, 2, 3}

results, errs := ParallelMap(values, time.Millisecond*100, func(ctx context.Context, v int) (int, error) {
if Sleep(ctx, time.Second) {
return 0, ctx.Err()
}
return v * v, nil
})

assert.Equal(t, []int{0, 0, 0}, results)
assert.Equal(t, []error{context.DeadlineExceeded, context.DeadlineExceeded, context.DeadlineExceeded}, errs)
}
17 changes: 17 additions & 0 deletions syncs/sleep.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package syncs

import (
"context"
"time"
)

// Sleep waits for the given duration.
// Returns true if the context was canceled.
func Sleep(ctx context.Context, delay time.Duration) bool {
select {
case <-time.After(delay):
return false
case <-ctx.Done():
return true
}
}
49 changes: 49 additions & 0 deletions syncs/wait_all.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package syncs

import (
"context"
"sync"
"time"
)

// Waitable is a function that can be waited on.
type Waitable[T any] func(ctx context.Context) (T, error)

// WaitAll waits for all Waitable to complete successfully or to fail.
// Cancelling the context will cause all Waitable to fail.
//
// The results and errors are returned in the same order as the Waitable.
// If a Waitable fails, the error is returned in the errs slice.
// If a Waitable succeeds, the result is returned in the results slice.
// If a Waitable is cancelled, the result is returned in the results slice and the error is returned in the errs slice.
//
// ctx (context) can be used to cancel the WaitAll. If nil is provided, a context with a deadline of 1 second is used.
// Remember to cancel the context when you're done with it.
//
// Example:
// ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second * 3))
// defer cancel()
// results, errs := WaitAll(ctx, funcs...)
//
func WaitAll[T any](ctx context.Context, waitables ...Waitable[T]) ([]T, []error) {
var wg sync.WaitGroup

if ctx == nil {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Second))
defer cancel()
}

results := make([]T, len(waitables))
errs := make([]error, len(waitables))

for i, waitable := range waitables {
wg.Add(1)
go func(i int, waitable Waitable[T]) {
defer wg.Done()
results[i], errs[i] = waitable(ctx)
}(i, waitable)
}
wg.Wait()
return results, errs
}
55 changes: 55 additions & 0 deletions syncs/wait_all_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package syncs

import (
"context"
"fmt"
"math/rand"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/totemcaf/gollections/slices"
)

func Test_WaitAll_waits_for_all_the_waitables(t *testing.T) {
funcs := slices.Map([]int{1, 2, 3, 4, 5}, func(i int) Waitable[int] {
return func(ctx context.Context) (int, error) {
fmt.Println("waiting for", i)
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100)))
fmt.Println("finished for", i)
return i, nil
}
})

ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
defer cancel()

results, errs := WaitAll(ctx, funcs...)

assert.Equal(t, []int{1, 2, 3, 4, 5}, results)
assert.Equal(t, []error{nil, nil, nil, nil, nil}, errs)
}

func Test_WaitAll_report_failing_waitables(t *testing.T) {
funcs := slices.Map([]int{1, 2, 3, 4, 5}, func(i int) Waitable[int] {
return func(ctx context.Context) (int, error) {
fmt.Println("waiting for", i)
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100)))
fmt.Println("finished for", i)

if i == 3 || i == 4 {
return 0, fmt.Errorf("error for %d", i)
}

return i, nil
}
})

ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
defer cancel()

results, errs := WaitAll(ctx, funcs...)

assert.Equal(t, []int{1, 2, 0, 0, 5}, results)
assert.Equal(t, []error{nil, nil, fmt.Errorf("error for 3"), fmt.Errorf("error for 4"), nil}, errs)
}

0 comments on commit 8ccd1c1

Please sign in to comment.