task.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. package goroutine
  2. import (
  3. "fmt"
  4. "math/rand"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. )
  9. type LongRunningTaskManager struct {
  10. maxConcurrent int
  11. workerNum int32
  12. taskChan chan func() error
  13. exitChan chan struct{}
  14. checkChan chan struct{}
  15. wg sync.WaitGroup
  16. }
  17. func NewTaskManager(maxConcurrent int) *LongRunningTaskManager {
  18. if maxConcurrent <= 0 {
  19. maxConcurrent = 1
  20. }
  21. tm := &LongRunningTaskManager{
  22. taskChan: make(chan func() error, 500),
  23. maxConcurrent: maxConcurrent,
  24. exitChan: make(chan struct{}),
  25. checkChan: make(chan struct{}, 1),
  26. }
  27. go tm.taskParallelChecker()
  28. return tm
  29. }
  30. func (tm *LongRunningTaskManager) notifyCheck() {
  31. select {
  32. case tm.checkChan <- struct{}{}:
  33. default:
  34. }
  35. }
  36. func (tm *LongRunningTaskManager) taskParallelChecker() {
  37. checkTaskParallel := func() {
  38. // 检查是否需要怎加协程
  39. // 检查任务数量,单协程一秒处理20个,如果现有协程一秒内处理不完,就增加新协程
  40. // 协程数量大于最大并发数就不用加了
  41. taskNum := len(tm.taskChan)
  42. if taskNum <= 0 {
  43. return
  44. }
  45. wn := int(atomic.LoadInt32(&tm.workerNum))
  46. if wn >= tm.maxConcurrent {
  47. return
  48. }
  49. if wn > 0 && taskNum/20 < wn {
  50. return
  51. }
  52. // 增加协程
  53. tm.wg.Add(1)
  54. atomic.AddInt32(&tm.workerNum, 1)
  55. go func() {
  56. defer tm.wg.Done()
  57. defer func() {
  58. if err := recover(); err != nil {
  59. fmt.Println(err)
  60. }
  61. atomic.AddInt32(&tm.workerNum, -1)
  62. tm.notifyCheck()
  63. }()
  64. tm.worker()
  65. }()
  66. }
  67. var lastCheck int64
  68. ticker := time.NewTicker(time.Second)
  69. for {
  70. select {
  71. case <-tm.checkChan:
  72. ct := time.Now().Unix()
  73. if ct < lastCheck {
  74. lastCheck = ct
  75. continue
  76. }
  77. if ct-lastCheck < 1 {
  78. continue
  79. }
  80. checkTaskParallel()
  81. case tm := <-ticker.C:
  82. ct := tm.Unix()
  83. if ct < lastCheck {
  84. lastCheck = ct
  85. continue
  86. }
  87. if ct-lastCheck < 1 {
  88. continue
  89. }
  90. checkTaskParallel()
  91. case <-tm.exitChan:
  92. ticker.Stop()
  93. close(tm.taskChan)
  94. tm.wg.Wait()
  95. return
  96. }
  97. }
  98. }
  99. func (tm *LongRunningTaskManager) worker() {
  100. for {
  101. select {
  102. case task, ok := <-tm.taskChan:
  103. if !ok || task == nil {
  104. return
  105. }
  106. if err := task(); err != nil {
  107. fmt.Println(err)
  108. }
  109. default:
  110. return
  111. }
  112. }
  113. }
  114. func (tm *LongRunningTaskManager) AddTask(task func() error) {
  115. tm.taskChan <- task
  116. tm.notifyCheck()
  117. }
  118. func (tm *LongRunningTaskManager) Close() {
  119. close(tm.exitChan)
  120. }
  121. func (tm *LongRunningTaskManager) Wait() {
  122. tm.wg.Wait()
  123. }
  124. func exampleTaskManager() {
  125. tm := NewTaskManager(10)
  126. for i := 0; i < 200; i++ {
  127. taskID := i
  128. tm.AddTask(func() error {
  129. n := rand.Int63n(5)
  130. time.Sleep(time.Second * time.Duration(n+1))
  131. fmt.Printf("Task %d is being processed\n", taskID)
  132. return nil
  133. })
  134. }
  135. tm.Close()
  136. tm.Wait()
  137. fmt.Println("the end")
  138. }