package batchtask

import (
	"context"
	"fmt"
	"log"
	"sync"
	"sync/atomic"
	"time"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"sparkteam-dash/orm/model"
	"sparkteam-dash/pkg/config"
	"sparkteam-dash/pkg/db"
)

// TableData is a single record tagged with its destination table.
type TableData struct {
	TableName string      // destination table name
	Data      interface{} // record to write
}

// BatchWriter buffers records per table and writes them to the database in batches.
type BatchWriter struct {
	db        *gorm.DB
	dataChan  chan TableData
	stopChan  chan struct{}
	wg        sync.WaitGroup
	metrics   *Metrics
	startTime time.Time
	mu        sync.RWMutex
	//httpServer *http.Server
	ctx       context.Context
	cancel    context.CancelFunc
	flushChan chan struct{}

	// Configuration
	batchSize    int
	flushTimeout time.Duration
	channelSize  int
}

// Metrics holds the writer's monitoring counters.
type Metrics struct {
	TotalPushed      int64            `json:"total_pushed"`
	TotalProcessed   int64            `json:"total_processed"`
	TotalBatches     int64            `json:"total_batches"`
	TotalErrors      int64            `json:"total_errors"`
	TableMetrics     map[string]int64 `json:"table_metrics"` // per-table record counts
	CurrentBatchSize int              `json:"current_batch_size"`
	ChannelLength    int              `json:"channel_length"`
	ChannelCapacity  int              `json:"channel_capacity"`
	LastFlushTime    time.Time        `json:"last_flush_time"`
}

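// Batch is the package-level writer instance; it is assigned by NewBatchWriter.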
var Batch *BatchWriter

// NewBatchWriter creates a new batch writer, stores it in the package-level
// Batch variable, and starts its processing goroutine.
func NewBatchWriter() *BatchWriter {
	batchSize := config.App.Batch.BatchSize
	flushTimeout := config.App.Batch.FlushTimeout
	channelSize := config.App.Batch.ChannelSize

	// Apply defaults; a configured flush timeout is interpreted as seconds.
	if batchSize <= 0 {
		batchSize = 1000
	}
	if flushTimeout <= 0 {
		flushTimeout = 5 * time.Second
	} else {
		flushTimeout = flushTimeout * time.Second
	}
	if channelSize <= 0 {
		channelSize = batchSize * 2
	}

	ctx, cancel := context.WithCancel(context.Background())
	Batch = &BatchWriter{
		dataChan:     make(chan TableData, channelSize),
		stopChan:     make(chan struct{}),
		startTime:    time.Now(),
		ctx:          ctx,
		cancel:       cancel,
		batchSize:    batchSize,
		flushTimeout: flushTimeout,
		channelSize:  channelSize,
		metrics: &Metrics{
			TableMetrics:    make(map[string]int64),
			ChannelCapacity: channelSize,
		},
	}

	// Start the processing goroutine.
	Batch.wg.Add(1)
	go Batch.process()
	return Batch
}

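// An illustrative wiring of the writer into an application's lifecycle
// (a sketch only; the surrounding application code is assumed and not part
// of this package):
//
//	w := batchtask.NewBatchWriter()
//	defer w.Close() // drain and flush whatever is still buffered
//	w.Push("login_log", &model.LoginLog{})
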
// process consumes the data channel, groups records by table, and flushes
// them when any table reaches the batch size or the flush timer fires.
func (bw *BatchWriter) process() {
	defer bw.wg.Done()
	// Pending batches, grouped by table name.
	batchByTable := make(map[string][]interface{})
	ticker := time.NewTicker(bw.flushTimeout)
	defer ticker.Stop()

	for {
		select {
		case tableData, ok := <-bw.dataChan:
			if !ok {
				// Channel closed: flush all remaining data and exit.
				bw.flushAllBatches(batchByTable)
				return
			}
			atomic.AddInt64(&bw.metrics.TotalPushed, 1)

			// Group by table name.
			if batchByTable[tableData.TableName] == nil {
				batchByTable[tableData.TableName] = make([]interface{}, 0, bw.batchSize)
			}
			batchByTable[tableData.TableName] = append(batchByTable[tableData.TableName], tableData.Data)

			// Update monitoring metrics.
			bw.mu.Lock()
			bw.metrics.ChannelLength = len(bw.dataChan)
			bw.metrics.CurrentBatchSize = bw.calculateTotalBatchSize(batchByTable)
			bw.mu.Unlock()

			// If any single table has reached the batch size, flush every table.
			if bw.hasTableReachedBatchSize(batchByTable) {
				bw.flushAllBatches(batchByTable)
				batchByTable = make(map[string][]interface{})
			}
		case <-ticker.C:
			// Periodic flush of all tables.
			if len(batchByTable) > 0 {
				bw.flushAllBatches(batchByTable)
				batchByTable = make(map[string][]interface{})
			}
		case <-bw.stopChan:
			// Stop signal received.
			if len(batchByTable) > 0 {
				bw.flushAllBatches(batchByTable)
			}
			return
		case <-bw.ctx.Done():
			// Context cancelled.
			if len(batchByTable) > 0 {
				bw.flushAllBatches(batchByTable)
			}
			return
		}
	}
}

// flushAllBatches writes every table's pending batch to the database.
func (bw *BatchWriter) flushAllBatches(batches map[string][]interface{}) {
	if len(batches) == 0 {
		return
	}
	// Remember the last failure so a later successful table does not hide it.
	var flushErr error
	for tableName, batch := range batches {
		if len(batch) == 0 {
			continue
		}
		if err := bw.createInBatches(tableName, batch); err != nil {
			flushErr = err
			continue
		}
		// Update per-table statistics.
		atomic.AddInt64(&bw.metrics.TotalProcessed, int64(len(batch)))
		bw.mu.Lock()
		bw.metrics.TableMetrics[tableName] += int64(len(batch))
		bw.mu.Unlock()
	}
	if flushErr != nil {
		atomic.AddInt64(&bw.metrics.TotalErrors, 1)
		log.Printf("Failed to process batches: %v", flushErr)
	} else {
		atomic.AddInt64(&bw.metrics.TotalBatches, 1)
		bw.mu.Lock()
		bw.metrics.LastFlushTime = time.Now()
		bw.metrics.CurrentBatchSize = 0
		bw.mu.Unlock()
		log.Printf("Successfully processed %d tables, total %d records",
			len(batches), bw.calculateTotalBatchSize(batches))
	}
}

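// createInBatches upserts one table's batch via GORM's CreateInBatches with a
// clause.OnConflict{UpdateAll: true} clause. Unknown table names are silently
// ignored, and the type assertions below panic if a caller pushes a value of
// the wrong model type for its table.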
func (bw *BatchWriter) createInBatches(tableName string, batch []interface{}) error {
	switch tableName {
	case "login_log":
		data := make([]*model.LoginLog, 0, len(batch))
		for _, item := range batch {
			data = append(data, item.(*model.LoginLog))
		}
		if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil {
			return fmt.Errorf("table %s: %w", tableName, err)
		}
	case "ad_log":
		data := make([]*model.AdLog, 0, len(batch))
		for _, item := range batch {
			data = append(data, item.(*model.AdLog))
		}
		if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil {
			return fmt.Errorf("table %s: %w", tableName, err)
		}
	case "guild_log":
		data := make([]*model.GuideLog, 0, len(batch))
		for _, item := range batch {
			data = append(data, item.(*model.GuideLog))
		}
		if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil {
			return fmt.Errorf("table %s: %w", tableName, err)
		}
	case "battle_log":
		data := make([]*model.BattleLog, 0, len(batch))
		for _, item := range batch {
			data = append(data, item.(*model.BattleLog))
		}
		if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil {
			return fmt.Errorf("table %s: %w", tableName, err)
		}
	case "online_duration_log":
		data := make([]*model.OnlineDurationLog, 0, len(batch))
		for _, item := range batch {
			data = append(data, item.(*model.OnlineDurationLog))
		}
		if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil {
			return fmt.Errorf("table %s: %w", tableName, err)
		}
	default:
		return nil
	}
	return nil
}

// hasTableReachedBatchSize reports whether any table's pending batch has
// reached the configured batch size.
func (bw *BatchWriter) hasTableReachedBatchSize(batches map[string][]interface{}) bool {
	for _, batch := range batches {
		if len(batch) >= bw.batchSize {
			return true
		}
	}
	return false
}

// calculateTotalBatchSize returns the total number of pending records across all tables.
func (bw *BatchWriter) calculateTotalBatchSize(batches map[string][]interface{}) int {
	total := 0
	for _, batch := range batches {
		total += len(batch)
	}
	return total
}

// Push queues a record for the given table. It blocks while the channel is
// full and panics if called after Close.
func (bw *BatchWriter) Push(tableName string, data interface{}) {
	bw.dataChan <- TableData{
		TableName: tableName,
		Data:      data,
	}
}

// PushWithContext queues a record for the given table, giving up when the
// context is cancelled.
func (bw *BatchWriter) PushWithContext(ctx context.Context, tableName string, data interface{}) error {
	select {
	case bw.dataChan <- TableData{TableName: tableName, Data: data}:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

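// For callers that must not block indefinitely, one possible pattern (an
// illustration, not part of the original code) is to bound the send with a
// timeout:
//
//	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
//	defer cancel()
//	if err := Batch.PushWithContext(ctx, "ad_log", &model.AdLog{}); err != nil {
//		log.Printf("dropping record: %v", err)
//	}
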
// Close closes the data channel and waits for the processor to drain and
// flush all buffered records.
func (bw *BatchWriter) Close() {
	close(bw.dataChan)
	bw.wg.Wait()
	//bw.stopHTTPServer()
}

// Stop cancels the processor immediately; pending batches are flushed, but
// records still queued in the channel are discarded.
func (bw *BatchWriter) Stop() {
	bw.cancel()
	close(bw.stopChan)
	bw.wg.Wait()
	//bw.stopHTTPServer()
}

// GetMetrics returns a snapshot of the current monitoring metrics.
func (bw *BatchWriter) GetMetrics() Metrics {
	bw.mu.RLock()
	defer bw.mu.RUnlock()
	return Metrics{
		TotalPushed:      atomic.LoadInt64(&bw.metrics.TotalPushed),
		TotalProcessed:   atomic.LoadInt64(&bw.metrics.TotalProcessed),
		TotalBatches:     atomic.LoadInt64(&bw.metrics.TotalBatches),
		TotalErrors:      atomic.LoadInt64(&bw.metrics.TotalErrors),
		TableMetrics:     bw.copyTableMetrics(),
		CurrentBatchSize: bw.metrics.CurrentBatchSize,
		ChannelLength:    len(bw.dataChan),
		ChannelCapacity:  bw.metrics.ChannelCapacity,
		LastFlushTime:    bw.metrics.LastFlushTime,
	}
}

// copyTableMetrics copies the per-table counters (called with mu held for thread safety).
func (bw *BatchWriter) copyTableMetrics() map[string]int64 {
	result := make(map[string]int64)
	for k, v := range bw.metrics.TableMetrics {
		result[k] = v
	}
	return result
}

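// One possible way to surface these metrics (illustrative only; the logging
// cadence and destination are assumptions, not part of this package) is a
// periodic log line:
//
//	go func() {
//		for range time.Tick(30 * time.Second) {
//			m := Batch.GetMetrics()
//			log.Printf("batch writer: pushed=%d processed=%d errors=%d queued=%d/%d",
//				m.TotalPushed, m.TotalProcessed, m.TotalErrors, m.ChannelLength, m.ChannelCapacity)
//		}
//	}()
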
// startMetricsServer starts the monitoring metrics HTTP server.
//func (bw *BatchWriter) startMetricsServer(port int) {
//	router := mux.NewRouter()
//	router.HandleFunc("/metrics", bw.metricsHandler).Methods("GET")
//	router.HandleFunc("/health", bw.healthHandler).Methods("GET")
//	router.HandleFunc("/stats", bw.statsHandler).Methods("GET")
//
//	addr := fmt.Sprintf(":%d", port)
//	bw.httpServer = &http.Server{
//		Addr:    addr,
//		Handler: router,
//	}
//
//	log.Printf("Metrics server started on port %d", port)
//	if err := bw.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
//		log.Printf("Metrics server error: %v", err)
//	}
//}

// metricsHandler serves the monitoring metrics as JSON.
//func (bw *BatchWriter) metricsHandler(w http.ResponseWriter, r *http.Request) {
//	metrics := bw.GetMetrics()
//	w.Header().Set("Content-Type", "application/json")
//	json.NewEncoder(w).Encode(metrics)
//}

// healthHandler serves a basic health check.
//func (bw *BatchWriter) healthHandler(w http.ResponseWriter, r *http.Request) {
//	response := map[string]interface{}{
//		"status":    "healthy",
//		"timestamp": time.Now(),
//		"uptime":    time.Since(bw.startTime).String(),
//		"channel": map[string]interface{}{
//			"length": len(bw.dataChan),
//			"cap":    cap(bw.dataChan),
//		},
//	}
//	w.Header().Set("Content-Type", "application/json")
//	json.NewEncoder(w).Encode(response)
//}

// statsHandler serves combined metrics and configuration.
//func (bw *BatchWriter) statsHandler(w http.ResponseWriter, r *http.Request) {
//	metrics := bw.GetMetrics()
//	stats := map[string]interface{}{
//		"metrics": metrics,
//		"config": map[string]interface{}{
//			"batch_size":    bw.batchSize,
//			"flush_timeout": bw.flushTimeout.String(),
//			"channel_size":  bw.channelSize,
//		},
//	}
//	w.Header().Set("Content-Type", "application/json")
//	json.NewEncoder(w).Encode(stats)
//}

// stopHTTPServer shuts down the metrics HTTP server.
//func (bw *BatchWriter) stopHTTPServer() {
//	if bw.httpServer != nil {
//		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//		defer cancel()
//		bw.httpServer.Shutdown(ctx)
//	}
//}