Skip to content

Commit

Permalink
removed the struct scoped context and moved the context as an argumen…
Browse files Browse the repository at this point in the history
…t to support timeout in close and wait
  • Loading branch information
JustinMason committed Jul 2, 2024
1 parent 523fc2c commit 3195072
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 29 deletions.
50 changes: 27 additions & 23 deletions pkg/jobpool/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,59 @@ import (
)

type JobPool[T any] struct {
jobFunc func(T)
wg sync.WaitGroup
jobChan chan *T
closed *atomic.Bool
ctx context.Context
jobFunc func(T)
wg *sync.WaitGroup
jobChan chan *T
closed *atomic.Bool
closeOnce sync.Once
}

func NewJobPool[T any](ctx context.Context, jobFunc func(T), concurrency int) *JobPool[T] {
func NewJobPool[T any](jobFunc func(T), concurrency int) *JobPool[T] {
job := &JobPool[T]{
jobFunc: jobFunc,
wg: sync.WaitGroup{},
wg: &sync.WaitGroup{},
jobChan: make(chan *T, concurrency*10),
closed: &atomic.Bool{},
ctx: ctx,
}

for i := 0; i < concurrency; i++ {
go func() {
for j := range job.jobChan {
job.jobFunc(*j)
job.wg.Done()
}
}()
go job.processJobs()
}

return job
}

func (j *JobPool[T]) Wait() {
func (j *JobPool[T]) processJobs() {
for job := range j.jobChan {
j.jobFunc(*job)
j.wg.Done()
}
}

func (j *JobPool[T]) Wait(ctx context.Context) {
j.wg.Wait()
}

func (j *JobPool[T]) Process(jobFunc *T) error {
j.wg.Add(1)

if j.closed.Load() {
return errors.New("job pool is closed")
}

j.wg.Add(1)

j.jobChan <- jobFunc

return nil
}

func (j *JobPool[T]) Close() {
j.closed.Store(true)
close(j.jobChan)
for len(j.jobChan) > 0 {
<-j.jobChan
}
j.wg.Wait()
func (j *JobPool[T]) Close(ctx context.Context) {
j.closeOnce.Do(func() {
j.closed.Store(true)
close(j.jobChan)
for len(j.jobChan) > 0 {
<-j.jobChan
}
j.Wait(ctx)
})
}
30 changes: 24 additions & 6 deletions pkg/jobpool/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *ProcessorMock) ProcessResults(job testJob) {

func TestProcessJobPool(t *testing.T) {
processorMock := new(ProcessorMock)
processJobPool := NewJobPool(context.Background(), processorMock.ProcessResults, 2)
processJobPool := NewJobPool(processorMock.ProcessResults, 2)

testCases := []string{"test1", "test2", "test3"}
processJobs := []*testJob{}
Expand All @@ -44,7 +44,7 @@ func TestProcessJobPool(t *testing.T) {
processJobs = append(processJobs, process)
}

processJobPool.Wait()
processJobPool.Wait(context.Background())

for range processJobs {
processorMock.AssertCalled(t, "ProcessResults", mock.Anything)
Expand All @@ -61,7 +61,7 @@ func TestProcessJobPoolWithTimeout(t *testing.T) {

ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
processJobPool := NewJobPool(ctx, processorMock.ProcessResultsTimeout, 2)
processJobPool := NewJobPool(processorMock.ProcessResultsTimeout, 2)

testCases := []string{"test1"}

Expand All @@ -71,7 +71,7 @@ func TestProcessJobPoolWithTimeout(t *testing.T) {
processJobPool.Process(processJob)
}

processJobPool.Wait()
processJobPool.Wait(ctx)
// Check if the context was timed out
err := ctx.Err()
assert.Error(t, err, "Expected an error but got none")
Expand All @@ -83,12 +83,30 @@ func TestProcessJobPoolWithTimeout(t *testing.T) {

func TestProcessJobPoolWithClose(t *testing.T) {
processorMock := NewProcessorMock(time.Duration(200 * time.Millisecond))
processJobPool := NewJobPool(context.Background(), processorMock.ProcessResultsTimeout, 2)
processJobPool := NewJobPool(processorMock.ProcessResultsTimeout, 2)

processJob := &testJob{}
processorMock.On("ProcessResultsTimeout", mock.Anything).Return()
processJobPool.Process(processJob)
processJobPool.Close()
processJobPool.Close(context.Background())

err := processJobPool.Process(processJob)
if err == nil {
t.Error("Expected error when adding job after close, got nil")
}

}

func TestProcessJobPoolWithCloseTimeout(t *testing.T) {
processorMock := NewProcessorMock(time.Duration(200 * time.Millisecond))
processJobPool := NewJobPool(processorMock.ProcessResultsTimeout, 2)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()

processJob := &testJob{}
processorMock.On("ProcessResultsTimeout", mock.Anything).Return()
processJobPool.Process(processJob)
processJobPool.Close(ctx)

err := processJobPool.Process(processJob)
if err == nil {
Expand Down

0 comments on commit 3195072

Please sign in to comment.