batch_task.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. package batchtask
  2. import (
  3. "context"
  4. "fmt"
  5. "gorm.io/gorm"
  6. "gorm.io/gorm/clause"
  7. "log"
  8. "sparkteam-dash/orm/model"
  9. "sparkteam-dash/pkg/config"
  10. "sparkteam-dash/pkg/db"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. )
// TableData is a single record tagged with the destination table it should
// be written to.
type TableData struct {
	TableName string      // destination table name (routing key for createInBatches)
	Data      interface{} // the record; concrete type depends on the table (e.g. *model.LoginLog)
}
// BatchWriter groups incoming records per table and writes them to the
// database in batches: a table is flushed when it reaches batchSize, when
// the flushTimeout ticker fires, or on shutdown.
type BatchWriter struct {
	db        *gorm.DB       // NOTE(review): never assigned in NewBatchWriter — confirm it is set elsewhere or remove
	dataChan  chan TableData // producer -> consumer queue, drained by process()
	stopChan  chan struct{}  // closed by Stop() to abort the consumer goroutine
	wg        sync.WaitGroup // waits for the process() goroutine to finish
	metrics   *Metrics       // shared counters; int64s via atomic, the rest under mu
	startTime time.Time      // writer creation time
	mu        sync.RWMutex   // guards the non-atomic fields of metrics
	//httpServer *http.Server
	ctx       context.Context    // cancelled via cancel (Stop path)
	cancel    context.CancelFunc // cancels ctx
	flushChan chan struct{}      // NOTE(review): declared but never used in this file — confirm before removing
	// Configuration (fixed at construction time).
	batchSize    int           // per-table record count that forces a flush
	flushTimeout time.Duration // interval of the periodic flush ticker
	channelSize  int           // capacity of dataChan
}
// Metrics holds the writer's monitoring counters. The int64 counters are
// updated with sync/atomic; the map and remaining fields are guarded by
// BatchWriter.mu.
type Metrics struct {
	TotalPushed      int64            `json:"total_pushed"`       // records received from producers
	TotalProcessed   int64            `json:"total_processed"`    // records successfully written to the DB
	TotalBatches     int64            `json:"total_batches"`      // fully successful flush cycles
	TotalErrors      int64            `json:"total_errors"`       // failed flush cycles
	TableMetrics     map[string]int64 `json:"table_metrics"`      // per-table written record counts
	CurrentBatchSize int              `json:"current_batch_size"` // records currently buffered across all tables
	ChannelLength    int              `json:"channel_length"`     // current dataChan depth
	ChannelCapacity  int              `json:"channel_capacity"`   // dataChan capacity
	LastFlushTime    time.Time        `json:"last_flush_time"`    // time of the last fully successful flush
}
  50. var Batch *BatchWriter
  51. // NewBatchWriter 创建新的批量写入器
  52. func NewBatchWriter() *BatchWriter {
  53. batchSize := config.App.Batch.BatchSize
  54. flushTimeout := config.App.Batch.FlushTimeout
  55. channelSize := config.App.Batch.ChannelSize
  56. // 设置默认值
  57. if batchSize <= 0 {
  58. batchSize = 1000
  59. }
  60. if flushTimeout <= 0 {
  61. flushTimeout = 5 * time.Second
  62. } else {
  63. flushTimeout = flushTimeout * time.Second
  64. }
  65. if channelSize <= 0 {
  66. channelSize = batchSize * 2
  67. }
  68. ctx, cancel := context.WithCancel(context.Background())
  69. Batch = &BatchWriter{
  70. dataChan: make(chan TableData, channelSize),
  71. stopChan: make(chan struct{}),
  72. startTime: time.Now(),
  73. ctx: ctx,
  74. cancel: cancel,
  75. batchSize: batchSize,
  76. flushTimeout: flushTimeout,
  77. channelSize: channelSize,
  78. metrics: &Metrics{
  79. TableMetrics: make(map[string]int64),
  80. ChannelCapacity: channelSize,
  81. },
  82. }
  83. // 启动处理协程
  84. Batch.wg.Add(1)
  85. go Batch.process()
  86. return Batch
  87. }
  88. // process 处理数据
  89. func (bw *BatchWriter) process() {
  90. defer bw.wg.Done()
  91. // 按表名分组的数据批次
  92. batchByTable := make(map[string][]interface{})
  93. ticker := time.NewTicker(bw.flushTimeout)
  94. defer ticker.Stop()
  95. for {
  96. select {
  97. case tableData, ok := <-bw.dataChan:
  98. if !ok {
  99. // 通道关闭,处理所有剩余数据
  100. bw.flushAllBatches(batchByTable)
  101. return
  102. }
  103. atomic.AddInt64(&bw.metrics.TotalPushed, 1)
  104. // 按表名分组
  105. if batchByTable[tableData.TableName] == nil {
  106. batchByTable[tableData.TableName] = make([]interface{}, 0, bw.batchSize)
  107. }
  108. batchByTable[tableData.TableName] = append(batchByTable[tableData.TableName], tableData.Data)
  109. // 更新监控指标
  110. bw.mu.Lock()
  111. bw.metrics.ChannelLength = len(bw.dataChan)
  112. bw.metrics.CurrentBatchSize = bw.calculateTotalBatchSize(batchByTable)
  113. bw.mu.Unlock()
  114. // 检查是否有表达到批次大小
  115. if bw.hasTableReachedBatchSize(batchByTable) {
  116. bw.flushAllBatches(batchByTable)
  117. batchByTable = make(map[string][]interface{})
  118. }
  119. case <-ticker.C:
  120. // 定时刷新所有表的数据
  121. if len(batchByTable) > 0 {
  122. bw.flushAllBatches(batchByTable)
  123. batchByTable = make(map[string][]interface{})
  124. }
  125. case <-bw.stopChan:
  126. // 收到停止信号
  127. if len(batchByTable) > 0 {
  128. bw.flushAllBatches(batchByTable)
  129. }
  130. return
  131. case <-bw.ctx.Done():
  132. // 上下文取消
  133. if len(batchByTable) > 0 {
  134. bw.flushAllBatches(batchByTable)
  135. }
  136. return
  137. }
  138. }
  139. }
  140. // flushAllBatches 批量处理所有表的数据
  141. func (bw *BatchWriter) flushAllBatches(batches map[string][]interface{}) {
  142. if len(batches) == 0 {
  143. return
  144. }
  145. var err error
  146. for tableName, batch := range batches {
  147. if len(batch) == 0 {
  148. continue
  149. }
  150. if err = bw.createInBatches(tableName, batch); err != nil {
  151. continue
  152. }
  153. // 更新表级别统计
  154. atomic.AddInt64(&bw.metrics.TotalProcessed, int64(len(batch)))
  155. bw.mu.Lock()
  156. bw.metrics.TableMetrics[tableName] += int64(len(batch))
  157. bw.mu.Unlock()
  158. }
  159. if err != nil {
  160. atomic.AddInt64(&bw.metrics.TotalErrors, 1)
  161. log.Printf("Failed to process batches: %v", err)
  162. } else {
  163. atomic.AddInt64(&bw.metrics.TotalBatches, 1)
  164. bw.mu.Lock()
  165. bw.metrics.LastFlushTime = time.Now()
  166. bw.metrics.CurrentBatchSize = 0
  167. bw.mu.Unlock()
  168. log.Printf("Successfully processed %d tables, total %d records",
  169. len(batches), bw.calculateTotalBatchSize(batches))
  170. }
  171. }
  172. func (bw *BatchWriter) createInBatches(tableName string, batch []interface{}) error {
  173. switch tableName {
  174. case "login_log":
  175. data := make([]*model.LoginLog, 0, len(batch))
  176. for _, item := range batch {
  177. data = append(data, item.(*model.LoginLog))
  178. }
  179. if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil {
  180. return fmt.Errorf("table %s: %w", tableName, err)
  181. }
  182. case "ad_log":
  183. data := make([]*model.AdLog, 0, len(batch))
  184. for _, item := range batch {
  185. data = append(data, item.(*model.AdLog))
  186. }
  187. if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil {
  188. return fmt.Errorf("table %s: %w", tableName, err)
  189. }
  190. case "guild_log":
  191. data := make([]*model.GuideLog, 0, len(batch))
  192. for _, item := range batch {
  193. data = append(data, item.(*model.GuideLog))
  194. }
  195. if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil {
  196. return fmt.Errorf("table %s: %w", tableName, err)
  197. }
  198. case "battle_log":
  199. data := make([]*model.BattleLog, 0, len(batch))
  200. for _, item := range batch {
  201. data = append(data, item.(*model.BattleLog))
  202. }
  203. if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil {
  204. return fmt.Errorf("table %s: %w", tableName, err)
  205. }
  206. case "online_duration_log":
  207. data := make([]*model.OnlineDurationLog, 0, len(batch))
  208. for _, item := range batch {
  209. data = append(data, item.(*model.OnlineDurationLog))
  210. }
  211. if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil {
  212. return fmt.Errorf("table %s: %w", tableName, err)
  213. }
  214. default:
  215. return nil
  216. }
  217. return nil
  218. }
  219. // hasTableReachedBatchSize 检查是否有表达到批次大小
  220. func (bw *BatchWriter) hasTableReachedBatchSize(batches map[string][]interface{}) bool {
  221. for _, batch := range batches {
  222. if len(batch) >= bw.batchSize {
  223. return true
  224. }
  225. }
  226. return false
  227. }
  228. // calculateTotalBatchSize 计算总批次大小
  229. func (bw *BatchWriter) calculateTotalBatchSize(batches map[string][]interface{}) int {
  230. total := 0
  231. for _, batch := range batches {
  232. total += len(batch)
  233. }
  234. return total
  235. }
// Push enqueues one record destined for tableName. The call blocks while
// the channel is full, applying backpressure to the producer.
//
// NOTE(review): pushing after Close() panics (send on a closed channel) —
// producers must stop before the writer is shut down. Use PushWithContext
// when the caller needs a bounded wait.
func (bw *BatchWriter) Push(tableName string, data interface{}) {
	bw.dataChan <- TableData{
		TableName: tableName,
		Data:      data,
	}
}
  243. // PushWithContext 带上下文的推送
  244. func (bw *BatchWriter) PushWithContext(ctx context.Context, tableName string, data interface{}) error {
  245. select {
  246. case bw.dataChan <- TableData{TableName: tableName, Data: data}:
  247. return nil
  248. case <-ctx.Done():
  249. return ctx.Err()
  250. }
  251. }
  252. // Close 关闭写入器
  253. func (bw *BatchWriter) Close() {
  254. close(bw.dataChan)
  255. bw.wg.Wait()
  256. //bw.stopHTTPServer()
  257. }
// Stop aborts the writer: it cancels the context, closes the stop channel,
// and waits for the process goroutine to exit. The goroutine flushes its
// in-memory buffer on the stop signal, but records still sitting in
// dataChan are NOT drained — use Close for a fully draining shutdown.
//
// Stop must be called at most once: a second call panics on the already
// closed stopChan.
func (bw *BatchWriter) Stop() {
	bw.cancel()
	close(bw.stopChan)
	bw.wg.Wait()
	//bw.stopHTTPServer()
}
  265. // GetMetrics 获取监控指标
  266. func (bw *BatchWriter) GetMetrics() Metrics {
  267. bw.mu.RLock()
  268. defer bw.mu.RUnlock()
  269. return Metrics{
  270. TotalPushed: atomic.LoadInt64(&bw.metrics.TotalPushed),
  271. TotalProcessed: atomic.LoadInt64(&bw.metrics.TotalProcessed),
  272. TotalBatches: atomic.LoadInt64(&bw.metrics.TotalBatches),
  273. TotalErrors: atomic.LoadInt64(&bw.metrics.TotalErrors),
  274. TableMetrics: bw.copyTableMetrics(),
  275. CurrentBatchSize: bw.metrics.CurrentBatchSize,
  276. ChannelLength: len(bw.dataChan),
  277. ChannelCapacity: bw.metrics.ChannelCapacity,
  278. LastFlushTime: bw.metrics.LastFlushTime,
  279. }
  280. }
// copyTableMetrics returns a shallow copy of the per-table counters.
// It takes no lock itself — the caller must hold bw.mu (GetMetrics calls it
// under RLock); the original comment claimed it was thread-safe on its own,
// but calling it unlocked would race with the consumer goroutine.
func (bw *BatchWriter) copyTableMetrics() map[string]int64 {
	result := make(map[string]int64)
	for k, v := range bw.metrics.TableMetrics {
		result[k] = v
	}
	return result
}
  289. // startMetricsServer 启动监控指标服务器
  290. //func (bw *BatchWriter) startMetricsServer(port int) {
  291. // router := mux.NewRouter()
  292. // router.HandleFunc("/metrics", bw.metricsHandler).Methods("GET")
  293. // router.HandleFunc("/health", bw.healthHandler).Methods("GET")
  294. // router.HandleFunc("/stats", bw.statsHandler).Methods("GET")
  295. //
  296. // addr := fmt.Sprintf(":%d", port)
  297. // bw.httpServer = &http.Server{
  298. // Addr: addr,
  299. // Handler: router,
  300. // }
  301. //
  302. // log.Printf("Metrics server started on port %d", port)
  303. // if err := bw.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
  304. // log.Printf("Metrics server error: %v", err)
  305. // }
  306. //}
  307. // metricsHandler 监控指标处理
  308. //func (bw *BatchWriter) metricsHandler(w http.ResponseWriter, r *http.Request) {
  309. // metrics := bw.GetMetrics()
  310. // w.Header().Set("Content-Type", "application/json")
  311. // json.NewEncoder(w).Encode(metrics)
  312. //}
  313. // healthHandler 健康检查处理
  314. //func (bw *BatchWriter) healthHandler(w http.ResponseWriter, r *http.Request) {
  315. // response := map[string]interface{}{
  316. // "status": "healthy",
  317. // "timestamp": time.Now(),
  318. // "uptime": time.Since(bw.startTime).String(),
  319. // "channel": map[string]interface{}{
  320. // "length": len(bw.dataChan),
  321. // "cap": cap(bw.dataChan),
  322. // },
  323. // }
  324. // w.Header().Set("Content-Type", "application/json")
  325. // json.NewEncoder(w).Encode(response)
  326. //}
  327. // statsHandler 统计信息处理
  328. //func (bw *BatchWriter) statsHandler(w http.ResponseWriter, r *http.Request) {
  329. // metrics := bw.GetMetrics()
  330. // stats := map[string]interface{}{
  331. // "metrics": metrics,
  332. // "config": map[string]interface{}{
  333. // "batch_size": bw.batchSize,
  334. // "flush_timeout": bw.flushTimeout.String(),
  335. // "channel_size": bw.channelSize,
  336. // },
  337. // }
  338. // w.Header().Set("Content-Type", "application/json")
  339. // json.NewEncoder(w).Encode(stats)
  340. //}
  341. // stopHTTPServer 停止HTTP服务器
  342. //func (bw *BatchWriter) stopHTTPServer() {
  343. // if bw.httpServer != nil {
  344. // ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  345. // defer cancel()
  346. // bw.httpServer.Shutdown(ctx)
  347. // }
  348. //}