package pool

import (
	"context"
	"log/slog"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// newTestLogger returns a logger that drops every record, keeping test output quiet.
func newTestLogger() *slog.Logger {
	return slog.New(slog.DiscardHandler)
}

// TestWorkerPoolConcurrencyLimit checks that no more than poolSize jobs run at
// once and that every submitted job eventually completes.
func TestWorkerPoolConcurrencyLimit(t *testing.T) {
	const (
		poolSize = 2
		jobCount = 10
	)
	workerPool := NewWorkerPool(poolSize, newTestLogger())
	ctx := context.Background()

	var peakConcurrency, currentConcurrency, completedJobs atomic.Int32
	for range jobCount {
		submitted := workerPool.Submit(ctx, func(context.Context) {
			current := currentConcurrency.Add(1)
			// Raise the recorded peak with a CAS loop; sync/atomic has no Max.
			for {
				peak := peakConcurrency.Load()
				if current <= peak || peakConcurrency.CompareAndSwap(peak, current) {
					break
				}
			}
			// Hold the slot long enough for later submissions to overlap.
			time.Sleep(10 * time.Millisecond)
			currentConcurrency.Add(-1)
			completedJobs.Add(1)
		})
		assert.True(t, submitted)
	}

	workerPool.Wait()
	assert.Equal(t, int32(jobCount), completedJobs.Load())
	assert.LessOrEqual(t, peakConcurrency.Load(), int32(poolSize))
}

// TestWorkerPoolPanicRecovery checks that a panicking job neither crashes the
// pool nor leaks its slot.
func TestWorkerPoolPanicRecovery(t *testing.T) {
	// With a single slot, the follow-up job can only run if the panicking job
	// released its slot. A second slot would hide such a leak.
	workerPool := NewWorkerPool(1, newTestLogger())

	var completedAfterPanic atomic.Bool
	assert.True(t, workerPool.Submit(context.Background(), func(context.Context) {
		panic("intentional test panic")
	}))

	// Bound the second submission so a leaked slot fails the test instead of
	// hanging it until the global test timeout.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	submitted := workerPool.Submit(ctx, func(context.Context) {
		completedAfterPanic.Store(true)
	})

	workerPool.Wait()
	assert.True(t, submitted)
	assert.True(t, completedAfterPanic.Load())
}

// TestWorkerPoolCancelledContext checks that Submit refuses work when its
// context is already cancelled and the pool has no free slot.
func TestWorkerPoolCancelledContext(t *testing.T) {
	workerPool := NewWorkerPool(1, newTestLogger())

	started := make(chan struct{})
	release := make(chan struct{})
	// Occupy the only slot and wait until the job is really running, so the
	// cancelled submission below cannot find a free slot.
	if !assert.True(t, workerPool.Submit(context.Background(), func(context.Context) {
		close(started)
		<-release
	})) {
		return
	}
	<-started

	cancelledContext, cancelFunction := context.WithCancel(context.Background())
	cancelFunction()
	submitted := workerPool.Submit(cancelledContext, func(context.Context) {})
	assert.False(t, submitted)

	close(release)
	workerPool.Wait()
}

// TestWorkerPoolWaitBlocksUntilDone checks that Wait returns only after every
// submitted job has finished.
func TestWorkerPoolWaitBlocksUntilDone(t *testing.T) {
	const jobCount = 20
	workerPool := NewWorkerPool(4, newTestLogger())
	ctx := context.Background()

	var counter atomic.Int32
	for range jobCount {
		assert.True(t, workerPool.Submit(ctx, func(context.Context) {
			time.Sleep(5 * time.Millisecond)
			counter.Add(1)
		}))
	}

	// No extra synchronization: Wait alone must guarantee all increments happened.
	workerPool.Wait()
	assert.Equal(t, int32(jobCount), counter.Load())
}

// TestWorkerPoolActiveWorkerCount checks that ActiveWorkerCount reports the
// number of jobs currently executing.
func TestWorkerPoolActiveWorkerCount(t *testing.T) {
	const blockedJobs = 2
	workerPool := NewWorkerPool(3, newTestLogger())
	assert.Equal(t, 0, workerPool.ActiveWorkerCount())

	var started sync.WaitGroup
	release := make(chan struct{})
	started.Add(blockedJobs)
	for range blockedJobs {
		submitted := workerPool.Submit(context.Background(), func(context.Context) {
			started.Done()
			<-release
		})
		if !assert.True(t, submitted) {
			started.Done() // the job will never run; keep the WaitGroup balanced
		}
	}

	// Both jobs are now inside their bodies and cannot finish until release is
	// closed, so the count is stable without relying on a sleep.
	started.Wait()
	assert.Equal(t, blockedJobs, workerPool.ActiveWorkerCount())

	close(release)
	workerPool.Wait()
}