123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- package goroutine
- import (
- "fmt"
- "math/rand"
- "sync"
- "sync/atomic"
- "time"
- )
// LongRunningTaskManager runs submitted tasks on a pool of goroutines that
// grows on demand, up to maxConcurrent, and shrinks again when the queue runs
// empty. Create it with NewTaskManager, submit work with AddTask, and shut it
// down with Close followed by Wait.
type LongRunningTaskManager struct {
	maxConcurrent int               // upper bound on concurrently running workers
	workerNum     int32             // number of live workers; accessed atomically
	taskChan      chan func() error // queued tasks; never closed, so AddTask cannot panic
	exitChan      chan struct{}     // closed by Close to request shutdown
	checkChan     chan struct{}     // coalesced "re-evaluate the pool size" signal
	doneChan      chan struct{}     // closed once shutdown has completed
	closeOnce     sync.Once         // makes Close idempotent
	wg            sync.WaitGroup    // worker goroutines; Add/Wait are only called by the scheduler
}

// NewTaskManager returns a manager that runs at most maxConcurrent tasks at a
// time (values <= 0 are treated as 1) and starts its background scheduler
// goroutine, which keeps running until Close is called.
func NewTaskManager(maxConcurrent int) *LongRunningTaskManager {
	const queueSize = 500 // AddTask blocks once this many tasks are pending

	if maxConcurrent <= 0 {
		maxConcurrent = 1
	}
	tm := &LongRunningTaskManager{
		maxConcurrent: maxConcurrent,
		taskChan:      make(chan func() error, queueSize),
		exitChan:      make(chan struct{}),
		checkChan:     make(chan struct{}, 1),
		doneChan:      make(chan struct{}),
	}
	go tm.taskParallelChecker()
	return tm
}

// notifyCheck asks the scheduler to re-evaluate the pool size. It never
// blocks: if a signal is already pending, the new one is coalesced into it.
func (tm *LongRunningTaskManager) notifyCheck() {
	select {
	case tm.checkChan <- struct{}{}:
	default:
	}
}

// taskParallelChecker is the scheduler goroutine. It re-evaluates the pool
// size when notified and on a one-second ticker, throttled to at most one
// evaluation per wall-clock second. After Close it keeps scheduling
// (unthrottled) until the queue is empty and every worker has exited, then
// closes doneChan.
func (tm *LongRunningTaskManager) taskParallelChecker() {
	defer close(tm.doneChan)

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	var lastCheck int64 // Unix second of the last evaluation
	due := func(now time.Time) bool {
		ct := now.Unix()
		if ct < lastCheck {
			// The wall clock stepped backwards: resynchronise and skip this round.
			lastCheck = ct
			return false
		}
		if ct-lastCheck < 1 {
			return false
		}
		lastCheck = ct
		return true
	}

	exitChan := tm.exitChan
	closing := false
	for {
		select {
		case <-tm.checkChan:
			if closing || due(time.Now()) {
				tm.scaleUp()
			}
		case now := <-ticker.C:
			if closing || due(now) {
				tm.scaleUp()
			}
		case <-exitChan:
			// A closed channel is always ready; stop selecting on it.
			exitChan = nil
			closing = true
			// Queued tasks must get a worker right away, even if none is alive.
			tm.scaleUp()
		}
		if closing && len(tm.taskChan) == 0 && atomic.LoadInt32(&tm.workerNum) == 0 {
			// Workers decrement workerNum before calling wg.Done, so wait for
			// them to finish unwinding.
			tm.wg.Wait()
			return
		}
	}
}

// scaleUp starts workers while the backlog justifies it. A single worker is
// assumed to get through about tasksPerWorker tasks per second, so another
// worker is added whenever the pending tasks would keep the current workers
// busy for a second or more. The pool never exceeds maxConcurrent.
func (tm *LongRunningTaskManager) scaleUp() {
	const tasksPerWorker = 20

	for {
		pending := len(tm.taskChan)
		if pending == 0 {
			return
		}
		workers := int(atomic.LoadInt32(&tm.workerNum))
		if workers >= tm.maxConcurrent {
			return
		}
		if workers > 0 && pending/tasksPerWorker < workers {
			return
		}
		tm.spawnWorker()
	}
}

// spawnWorker starts one worker goroutine. It is only called from the
// scheduler goroutine, which keeps wg.Add from racing with wg.Wait.
func (tm *LongRunningTaskManager) spawnWorker() {
	tm.wg.Add(1)
	atomic.AddInt32(&tm.workerNum, 1)
	go func() {
		defer tm.wg.Done()
		defer func() {
			atomic.AddInt32(&tm.workerNum, -1)
			// Let the scheduler notice the pool shrank (and, during shutdown, finish).
			tm.notifyCheck()
		}()
		tm.worker()
	}()
}

// worker runs queued tasks until it finds the queue empty, then returns so
// that idle workers do not linger.
func (tm *LongRunningTaskManager) worker() {
	for {
		select {
		case task := <-tm.taskChan:
			tm.runTask(task)
		default:
			return
		}
	}
}

// runTask executes one task and prints any returned error or recovered
// panic, so a misbehaving task cannot take its worker down with it.
func (tm *LongRunningTaskManager) runTask(task func() error) {
	if task == nil {
		return
	}
	defer func() {
		if r := recover(); r != nil {
			fmt.Println(r)
		}
	}()
	if err := task(); err != nil {
		fmt.Println(err)
	}
}

// AddTask queues task for execution, blocking while the queue is full.
// Errors returned by a task, and panics raised by it, are printed to stdout.
// nil tasks and tasks submitted after Close are dropped; a task submitted
// concurrently with Close may also be dropped.
func (tm *LongRunningTaskManager) AddTask(task func() error) {
	if task == nil {
		return
	}
	select {
	case <-tm.exitChan:
		return
	default:
	}
	tm.taskChan <- task
	tm.notifyCheck()
}

// Close stops accepting new tasks and starts shutdown; tasks already queued
// still run. It is safe to call Close more than once.
func (tm *LongRunningTaskManager) Close() {
	tm.closeOnce.Do(func() { close(tm.exitChan) })
}

// Wait blocks until the manager has shut down: Close has been called (possibly
// from another goroutine), every queued task has run, and all workers have
// exited.
func (tm *LongRunningTaskManager) Wait() {
	<-tm.doneChan
}
- func exampleTaskManager() {
- tm := NewTaskManager(10)
- for i := 0; i < 200; i++ {
- taskID := i
- tm.AddTask(func() error {
- n := rand.Int63n(5)
- time.Sleep(time.Second * time.Duration(n+1))
- fmt.Printf("Task %d is being processed\n", taskID)
- return nil
- })
- }
- tm.Close()
- tm.Wait()
- fmt.Println("the end")
- }
|