batch_task.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. package batchtask
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/goccy/go-json"
  6. "gorm.io/gorm"
  7. "gorm.io/gorm/clause"
  8. "log"
  9. "sparkteam-dash/pkg/config"
  10. "sparkteam-dash/pkg/db"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. )
  // TableData couples a single record with the name of the table it should be
  // written to. It is the element type carried on BatchWriter's queue.
  type TableData struct {
  	// TableName is the destination table.
  	TableName string
  	// Data is the record payload; the writer JSON-encodes it before insertion.
  	Data interface{}
  }
  20. // BatchWriter 批量写入器
  21. type BatchWriter struct {
  22. db *gorm.DB
  23. dataChan chan TableData
  24. stopChan chan struct{}
  25. wg sync.WaitGroup
  26. metrics *Metrics
  27. startTime time.Time
  28. mu sync.RWMutex
  29. //httpServer *http.Server
  30. ctx context.Context
  31. cancel context.CancelFunc
  32. flushChan chan struct{}
  33. // 配置
  34. batchSize int
  35. flushTimeout time.Duration
  36. channelSize int
  37. }
  // Metrics is the monitoring snapshot of a BatchWriter. The four Total*
  // counters are updated atomically; the remaining fields are guarded by the
  // writer's mutex.
  type Metrics struct {
  	// TotalPushed counts records the processing goroutine has taken off the queue.
  	TotalPushed int64 `json:"total_pushed"`
  	// TotalProcessed counts records handed to the database by flushes.
  	TotalProcessed int64 `json:"total_processed"`
  	// TotalBatches counts flushes that completed without error.
  	TotalBatches int64 `json:"total_batches"`
  	// TotalErrors counts flushes that failed.
  	TotalErrors int64 `json:"total_errors"`
  	// TableMetrics holds the per-table processed record counts.
  	TableMetrics map[string]int64 `json:"table_metrics"`
  	// CurrentBatchSize is the number of records buffered in memory awaiting a flush.
  	CurrentBatchSize int `json:"current_batch_size"`
  	// ChannelLength is the number of records waiting in the queue.
  	ChannelLength int `json:"channel_length"`
  	// ChannelCapacity is the capacity of the queue.
  	ChannelCapacity int `json:"channel_capacity"`
  	// LastFlushTime is when the most recent successful flush finished.
  	LastFlushTime time.Time `json:"last_flush_time"`
  }
  50. var Batch *BatchWriter
  51. // NewBatchWriter 创建新的批量写入器
  52. func NewBatchWriter() *BatchWriter {
  53. batchSize := config.App.Batch.BatchSize
  54. flushTimeout := config.App.Batch.FlushTimeout
  55. channelSize := config.App.Batch.ChannelSize
  56. // 设置默认值
  57. if batchSize <= 0 {
  58. batchSize = 1000
  59. }
  60. if flushTimeout <= 0 {
  61. flushTimeout = 5 * time.Second
  62. } else {
  63. flushTimeout = flushTimeout * time.Second
  64. }
  65. if channelSize <= 0 {
  66. channelSize = batchSize * 2
  67. }
  68. ctx, cancel := context.WithCancel(context.Background())
  69. Batch = &BatchWriter{
  70. dataChan: make(chan TableData, channelSize),
  71. stopChan: make(chan struct{}),
  72. startTime: time.Now(),
  73. ctx: ctx,
  74. cancel: cancel,
  75. batchSize: batchSize,
  76. flushTimeout: flushTimeout,
  77. channelSize: channelSize,
  78. metrics: &Metrics{
  79. TableMetrics: make(map[string]int64),
  80. ChannelCapacity: channelSize,
  81. },
  82. }
  83. // 启动处理协程
  84. Batch.wg.Add(1)
  85. go Batch.process()
  86. return Batch
  87. }
  88. // process 处理数据
  89. func (bw *BatchWriter) process() {
  90. defer bw.wg.Done()
  91. // 按表名分组的数据批次
  92. batchByTable := make(map[string][]interface{})
  93. ticker := time.NewTicker(bw.flushTimeout)
  94. defer ticker.Stop()
  95. for {
  96. select {
  97. case tableData, ok := <-bw.dataChan:
  98. if !ok {
  99. // 通道关闭,处理所有剩余数据
  100. bw.flushAllBatches(batchByTable)
  101. return
  102. }
  103. atomic.AddInt64(&bw.metrics.TotalPushed, 1)
  104. // 按表名分组
  105. if batchByTable[tableData.TableName] == nil {
  106. batchByTable[tableData.TableName] = make([]interface{}, 0, bw.batchSize)
  107. }
  108. batchByTable[tableData.TableName] = append(batchByTable[tableData.TableName], tableData.Data)
  109. // 更新监控指标
  110. bw.mu.Lock()
  111. bw.metrics.ChannelLength = len(bw.dataChan)
  112. bw.metrics.CurrentBatchSize = bw.calculateTotalBatchSize(batchByTable)
  113. bw.mu.Unlock()
  114. // 检查是否有表达到批次大小
  115. if bw.hasTableReachedBatchSize(batchByTable) {
  116. bw.flushAllBatches(batchByTable)
  117. batchByTable = make(map[string][]interface{})
  118. }
  119. case <-ticker.C:
  120. // 定时刷新所有表的数据
  121. fmt.Println(time.Now())
  122. if len(batchByTable) > 0 {
  123. bw.flushAllBatches(batchByTable)
  124. batchByTable = make(map[string][]interface{})
  125. }
  126. case <-bw.stopChan:
  127. // 收到停止信号
  128. if len(batchByTable) > 0 {
  129. bw.flushAllBatches(batchByTable)
  130. }
  131. return
  132. case <-bw.ctx.Done():
  133. // 上下文取消
  134. if len(batchByTable) > 0 {
  135. bw.flushAllBatches(batchByTable)
  136. }
  137. return
  138. }
  139. }
  140. }
  141. // flushAllBatches 批量处理所有表的数据
  142. func (bw *BatchWriter) flushAllBatches(batches map[string][]interface{}) {
  143. if len(batches) == 0 {
  144. return
  145. }
  146. // 使用事务处理所有表的数据
  147. err := db.LogEngine().Transaction(func(tx *gorm.DB) error {
  148. for tableName, batch := range batches {
  149. if len(batch) == 0 {
  150. continue
  151. }
  152. data := make([]map[string]interface{}, 0, len(batch))
  153. dataByte, _ := json.Marshal(batch)
  154. _ = json.Unmarshal(dataByte, &data)
  155. // 按表名进行批量插入
  156. if err := tx.Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil {
  157. return fmt.Errorf("table %s: %w", tableName, err)
  158. }
  159. // 更新表级别统计
  160. atomic.AddInt64(&bw.metrics.TotalProcessed, int64(len(batch)))
  161. bw.mu.Lock()
  162. bw.metrics.TableMetrics[tableName] += int64(len(batch))
  163. bw.mu.Unlock()
  164. }
  165. return nil
  166. })
  167. if err != nil {
  168. atomic.AddInt64(&bw.metrics.TotalErrors, 1)
  169. log.Printf("Failed to process batches: %v", err)
  170. } else {
  171. atomic.AddInt64(&bw.metrics.TotalBatches, 1)
  172. bw.mu.Lock()
  173. bw.metrics.LastFlushTime = time.Now()
  174. bw.metrics.CurrentBatchSize = 0
  175. bw.mu.Unlock()
  176. log.Printf("Successfully processed %d tables, total %d records",
  177. len(batches), bw.calculateTotalBatchSize(batches))
  178. }
  179. }
  180. // hasTableReachedBatchSize 检查是否有表达到批次大小
  181. func (bw *BatchWriter) hasTableReachedBatchSize(batches map[string][]interface{}) bool {
  182. for _, batch := range batches {
  183. if len(batch) >= bw.batchSize {
  184. return true
  185. }
  186. }
  187. return false
  188. }
  189. // calculateTotalBatchSize 计算总批次大小
  190. func (bw *BatchWriter) calculateTotalBatchSize(batches map[string][]interface{}) int {
  191. total := 0
  192. for _, batch := range batches {
  193. total += len(batch)
  194. }
  195. return total
  196. }
  197. // Push 推送数据到指定表
  198. func (bw *BatchWriter) Push(tableName string, data interface{}) {
  199. bw.dataChan <- TableData{
  200. TableName: tableName,
  201. Data: data,
  202. }
  203. }
  204. // PushWithContext 带上下文的推送
  205. func (bw *BatchWriter) PushWithContext(ctx context.Context, tableName string, data interface{}) error {
  206. select {
  207. case bw.dataChan <- TableData{TableName: tableName, Data: data}:
  208. return nil
  209. case <-ctx.Done():
  210. return ctx.Err()
  211. }
  212. }
  213. // Close 关闭写入器
  214. func (bw *BatchWriter) Close() {
  215. close(bw.dataChan)
  216. bw.wg.Wait()
  217. //bw.stopHTTPServer()
  218. }
  219. // Stop 停止写入器
  220. func (bw *BatchWriter) Stop() {
  221. bw.cancel()
  222. close(bw.stopChan)
  223. bw.wg.Wait()
  224. //bw.stopHTTPServer()
  225. }
  226. // GetMetrics 获取监控指标
  227. func (bw *BatchWriter) GetMetrics() Metrics {
  228. bw.mu.RLock()
  229. defer bw.mu.RUnlock()
  230. return Metrics{
  231. TotalPushed: atomic.LoadInt64(&bw.metrics.TotalPushed),
  232. TotalProcessed: atomic.LoadInt64(&bw.metrics.TotalProcessed),
  233. TotalBatches: atomic.LoadInt64(&bw.metrics.TotalBatches),
  234. TotalErrors: atomic.LoadInt64(&bw.metrics.TotalErrors),
  235. TableMetrics: bw.copyTableMetrics(),
  236. CurrentBatchSize: bw.metrics.CurrentBatchSize,
  237. ChannelLength: len(bw.dataChan),
  238. ChannelCapacity: bw.metrics.ChannelCapacity,
  239. LastFlushTime: bw.metrics.LastFlushTime,
  240. }
  241. }
  242. // copyTableMetrics 复制表指标(线程安全)
  243. func (bw *BatchWriter) copyTableMetrics() map[string]int64 {
  244. result := make(map[string]int64)
  245. for k, v := range bw.metrics.TableMetrics {
  246. result[k] = v
  247. }
  248. return result
  249. }
  250. // startMetricsServer 启动监控指标服务器
  251. //func (bw *BatchWriter) startMetricsServer(port int) {
  252. // router := mux.NewRouter()
  253. // router.HandleFunc("/metrics", bw.metricsHandler).Methods("GET")
  254. // router.HandleFunc("/health", bw.healthHandler).Methods("GET")
  255. // router.HandleFunc("/stats", bw.statsHandler).Methods("GET")
  256. //
  257. // addr := fmt.Sprintf(":%d", port)
  258. // bw.httpServer = &http.Server{
  259. // Addr: addr,
  260. // Handler: router,
  261. // }
  262. //
  263. // log.Printf("Metrics server started on port %d", port)
  264. // if err := bw.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
  265. // log.Printf("Metrics server error: %v", err)
  266. // }
  267. //}
  268. // metricsHandler 监控指标处理
  269. //func (bw *BatchWriter) metricsHandler(w http.ResponseWriter, r *http.Request) {
  270. // metrics := bw.GetMetrics()
  271. // w.Header().Set("Content-Type", "application/json")
  272. // json.NewEncoder(w).Encode(metrics)
  273. //}
  274. // healthHandler 健康检查处理
  275. //func (bw *BatchWriter) healthHandler(w http.ResponseWriter, r *http.Request) {
  276. // response := map[string]interface{}{
  277. // "status": "healthy",
  278. // "timestamp": time.Now(),
  279. // "uptime": time.Since(bw.startTime).String(),
  280. // "channel": map[string]interface{}{
  281. // "length": len(bw.dataChan),
  282. // "cap": cap(bw.dataChan),
  283. // },
  284. // }
  285. // w.Header().Set("Content-Type", "application/json")
  286. // json.NewEncoder(w).Encode(response)
  287. //}
  288. // statsHandler 统计信息处理
  289. //func (bw *BatchWriter) statsHandler(w http.ResponseWriter, r *http.Request) {
  290. // metrics := bw.GetMetrics()
  291. // stats := map[string]interface{}{
  292. // "metrics": metrics,
  293. // "config": map[string]interface{}{
  294. // "batch_size": bw.batchSize,
  295. // "flush_timeout": bw.flushTimeout.String(),
  296. // "channel_size": bw.channelSize,
  297. // },
  298. // }
  299. // w.Header().Set("Content-Type", "application/json")
  300. // json.NewEncoder(w).Encode(stats)
  301. //}
  302. // stopHTTPServer 停止HTTP服务器
  303. //func (bw *BatchWriter) stopHTTPServer() {
  304. // if bw.httpServer != nil {
  305. // ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  306. // defer cancel()
  307. // bw.httpServer.Shutdown(ctx)
  308. // }
  309. //}