summaryrefslogtreecommitdiff
path: root/services/worker/internal/pool/pool.go
blob: 0576636fc3f8bc2a64f9ef4c9b2a511d6bbb2d0f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
package pool

import (
	"context"
	"log/slog"
	"runtime/debug"
	"sync"
)

// WorkFunction is a unit of work that a WorkerPool runs on its own goroutine.
// It is called with the context that was handed to Submit together with it.
type WorkFunction func(workContext context.Context)

// WorkerPool runs submitted work functions on separate goroutines while
// capping how many of them may run at the same time.
//
// Create one with NewWorkerPool. The struct contains a sync.WaitGroup by
// value, so it must be used through a pointer and never copied.
type WorkerPool struct {
	// concurrencyLimit is the maximum number of work functions allowed to run
	// at once, as set by NewWorkerPool.
	concurrencyLimit int
	// semaphoreChannel is used as a counting semaphore: Submit sends one token
	// before starting a work function and the token is received back when that
	// function finishes, so its length is the number of running work functions.
	semaphoreChannel chan struct{}
	// waitGroup tracks started work functions so that Wait can block until all
	// of them have finished.
	waitGroup sync.WaitGroup
	// logger receives an error record whenever a work function panics.
	logger *slog.Logger
}

// NewWorkerPool returns a pool that runs at most concurrencyLimit work
// functions at the same time and reports recovered panics to logger.
//
// A concurrencyLimit below 1 is raised to 1: a negative value would make the
// channel allocation panic, and zero would create a pool that can never
// accept any work. A nil logger is replaced by slog.Default() so that logging
// a recovered panic cannot itself crash the process.
func NewWorkerPool(concurrencyLimit int, logger *slog.Logger) *WorkerPool {
	if concurrencyLimit < 1 {
		concurrencyLimit = 1
	}

	if logger == nil {
		logger = slog.Default()
	}

	return &WorkerPool{
		concurrencyLimit: concurrencyLimit,
		// The buffer size is what enforces the limit: each running work
		// function occupies one buffer slot.
		semaphoreChannel: make(chan struct{}, concurrencyLimit),
		logger:           logger,
	}
}

// Submit runs workFunction on its own goroutine once a concurrency slot is
// available, passing it workContext.
//
// It blocks while the pool is at capacity. It returns true when the work was
// started, and false when workContext was done before a slot could be
// acquired; in that case workFunction is never called. A context that is
// already done on entry is always rejected, even if a slot is free.
//
// A panic raised by workFunction is recovered and logged together with a
// stack trace instead of terminating the process.
func (workerPool *WorkerPool) Submit(workContext context.Context, workFunction WorkFunction) bool {
	// select chooses at random among ready cases, so without this check an
	// already-cancelled context could still get its work started whenever a
	// slot happens to be free.
	if workContext.Err() != nil {
		return false
	}

	select {
	case workerPool.semaphoreChannel <- struct{}{}:
		// Slot acquired; fall through and start the work.
	case <-workContext.Done():
		return false
	}

	workerPool.waitGroup.Add(1)

	go func() {
		// Deferred calls run last-in-first-out: recover from a panic first,
		// then free the slot, and only then mark the task as finished, so
		// Wait never returns while a slot is still held.
		defer workerPool.waitGroup.Done()
		defer func() { <-workerPool.semaphoreChannel }()
		defer func() {
			if recoveredPanic := recover(); recoveredPanic != nil {
				workerPool.logger.Error(
					"worker panic recovered",
					"panic_value", recoveredPanic,
					"stack_trace", string(debug.Stack()),
				)
			}
		}()

		workFunction(workContext)
	}()

	return true
}

// Wait blocks until every work function that was started through Submit has
// returned. It returns immediately when nothing is outstanding.
func (workerPool *WorkerPool) Wait() {
	outstanding := &workerPool.waitGroup
	outstanding.Wait()
}

// ActiveWorkerCount reports how many concurrency slots are held at the moment
// of the call, which is the number of work functions that have been started
// and have not yet finished. The value is a snapshot and may be stale by the
// time the caller uses it.
func (workerPool *WorkerPool) ActiveWorkerCount() int {
	occupiedSlots := len(workerPool.semaphoreChannel)
	return occupiedSlots
}