package batchtask import ( "context" "fmt" "gorm.io/gorm" "gorm.io/gorm/clause" "log" "sparkteam-dash/orm/model" "sparkteam-dash/pkg/config" "sparkteam-dash/pkg/db" "sync" "sync/atomic" "time" ) // TableData 带表名的数据 type TableData struct { TableName string // 表名 Data interface{} // 数据 } // BatchWriter 批量写入器 type BatchWriter struct { db *gorm.DB dataChan chan TableData stopChan chan struct{} wg sync.WaitGroup metrics *Metrics startTime time.Time mu sync.RWMutex //httpServer *http.Server ctx context.Context cancel context.CancelFunc flushChan chan struct{} // 配置 batchSize int flushTimeout time.Duration channelSize int } // Metrics 监控指标 type Metrics struct { TotalPushed int64 `json:"total_pushed"` TotalProcessed int64 `json:"total_processed"` TotalBatches int64 `json:"total_batches"` TotalErrors int64 `json:"total_errors"` TableMetrics map[string]int64 `json:"table_metrics"` // 表级别的统计 CurrentBatchSize int `json:"current_batch_size"` ChannelLength int `json:"channel_length"` ChannelCapacity int `json:"channel_capacity"` LastFlushTime time.Time `json:"last_flush_time"` } var Batch *BatchWriter // NewBatchWriter 创建新的批量写入器 func NewBatchWriter() *BatchWriter { batchSize := config.App.Batch.BatchSize flushTimeout := config.App.Batch.FlushTimeout channelSize := config.App.Batch.ChannelSize // 设置默认值 if batchSize <= 0 { batchSize = 1000 } if flushTimeout <= 0 { flushTimeout = 5 * time.Second } else { flushTimeout = flushTimeout * time.Second } if channelSize <= 0 { channelSize = batchSize * 2 } ctx, cancel := context.WithCancel(context.Background()) Batch = &BatchWriter{ dataChan: make(chan TableData, channelSize), stopChan: make(chan struct{}), startTime: time.Now(), ctx: ctx, cancel: cancel, batchSize: batchSize, flushTimeout: flushTimeout, channelSize: channelSize, metrics: &Metrics{ TableMetrics: make(map[string]int64), ChannelCapacity: channelSize, }, } // 启动处理协程 Batch.wg.Add(1) go Batch.process() return Batch } // process 处理数据 func (bw *BatchWriter) process() { defer bw.wg.Done() // 按表名分组的数据批次 batchByTable := make(map[string][]interface{}) ticker := time.NewTicker(bw.flushTimeout) defer ticker.Stop() for { select { case tableData, ok := <-bw.dataChan: if !ok { // 通道关闭,处理所有剩余数据 bw.flushAllBatches(batchByTable) return } atomic.AddInt64(&bw.metrics.TotalPushed, 1) // 按表名分组 if batchByTable[tableData.TableName] == nil { batchByTable[tableData.TableName] = make([]interface{}, 0, bw.batchSize) } batchByTable[tableData.TableName] = append(batchByTable[tableData.TableName], tableData.Data) // 更新监控指标 bw.mu.Lock() bw.metrics.ChannelLength = len(bw.dataChan) bw.metrics.CurrentBatchSize = bw.calculateTotalBatchSize(batchByTable) bw.mu.Unlock() // 检查是否有表达到批次大小 if bw.hasTableReachedBatchSize(batchByTable) { bw.flushAllBatches(batchByTable) batchByTable = make(map[string][]interface{}) } case <-ticker.C: // 定时刷新所有表的数据 if len(batchByTable) > 0 { bw.flushAllBatches(batchByTable) batchByTable = make(map[string][]interface{}) } case <-bw.stopChan: // 收到停止信号 if len(batchByTable) > 0 { bw.flushAllBatches(batchByTable) } return case <-bw.ctx.Done(): // 上下文取消 if len(batchByTable) > 0 { bw.flushAllBatches(batchByTable) } return } } } // flushAllBatches 批量处理所有表的数据 func (bw *BatchWriter) flushAllBatches(batches map[string][]interface{}) { if len(batches) == 0 { return } var err error for tableName, batch := range batches { if len(batch) == 0 { continue } if err = bw.createInBatches(tableName, batch); err != nil { continue } // 更新表级别统计 atomic.AddInt64(&bw.metrics.TotalProcessed, int64(len(batch))) bw.mu.Lock() bw.metrics.TableMetrics[tableName] += int64(len(batch)) bw.mu.Unlock() } if err != nil { atomic.AddInt64(&bw.metrics.TotalErrors, 1) log.Printf("Failed to process batches: %v", err) } else { atomic.AddInt64(&bw.metrics.TotalBatches, 1) bw.mu.Lock() bw.metrics.LastFlushTime = time.Now() bw.metrics.CurrentBatchSize = 0 bw.mu.Unlock() log.Printf("Successfully processed %d tables, total %d records", len(batches), bw.calculateTotalBatchSize(batches)) } } func (bw *BatchWriter) createInBatches(tableName string, batch []interface{}) error { switch tableName { case "login_log": data := make([]*model.LoginLog, 0, len(batch)) for _, item := range batch { data = append(data, item.(*model.LoginLog)) } if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil { return fmt.Errorf("table %s: %w", tableName, err) } case "ad_log": data := make([]*model.AdLog, 0, len(batch)) for _, item := range batch { data = append(data, item.(*model.AdLog)) } if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil { return fmt.Errorf("table %s: %w", tableName, err) } case "guild_log": data := make([]*model.GuideLog, 0, len(batch)) for _, item := range batch { data = append(data, item.(*model.GuideLog)) } if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil { return fmt.Errorf("table %s: %w", tableName, err) } case "battle_log": data := make([]*model.BattleLog, 0, len(batch)) for _, item := range batch { data = append(data, item.(*model.BattleLog)) } if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil { return fmt.Errorf("table %s: %w", tableName, err) } case "online_duration_log": data := make([]*model.OnlineDurationLog, 0, len(batch)) for _, item := range batch { data = append(data, item.(*model.OnlineDurationLog)) } if err := db.LogEngine().Table(tableName).Clauses(clause.OnConflict{UpdateAll: true}).CreateInBatches(data, bw.batchSize).Error; err != nil { return fmt.Errorf("table %s: %w", tableName, err) } default: return nil } return nil } // hasTableReachedBatchSize 检查是否有表达到批次大小 func (bw *BatchWriter) hasTableReachedBatchSize(batches map[string][]interface{}) bool { for _, batch := range batches { if len(batch) >= bw.batchSize { return true } } return false } // calculateTotalBatchSize 计算总批次大小 func (bw *BatchWriter) calculateTotalBatchSize(batches map[string][]interface{}) int { total := 0 for _, batch := range batches { total += len(batch) } return total } // Push 推送数据到指定表 func (bw *BatchWriter) Push(tableName string, data interface{}) { bw.dataChan <- TableData{ TableName: tableName, Data: data, } } // PushWithContext 带上下文的推送 func (bw *BatchWriter) PushWithContext(ctx context.Context, tableName string, data interface{}) error { select { case bw.dataChan <- TableData{TableName: tableName, Data: data}: return nil case <-ctx.Done(): return ctx.Err() } } // Close 关闭写入器 func (bw *BatchWriter) Close() { close(bw.dataChan) bw.wg.Wait() //bw.stopHTTPServer() } // Stop 停止写入器 func (bw *BatchWriter) Stop() { bw.cancel() close(bw.stopChan) bw.wg.Wait() //bw.stopHTTPServer() } // GetMetrics 获取监控指标 func (bw *BatchWriter) GetMetrics() Metrics { bw.mu.RLock() defer bw.mu.RUnlock() return Metrics{ TotalPushed: atomic.LoadInt64(&bw.metrics.TotalPushed), TotalProcessed: atomic.LoadInt64(&bw.metrics.TotalProcessed), TotalBatches: atomic.LoadInt64(&bw.metrics.TotalBatches), TotalErrors: atomic.LoadInt64(&bw.metrics.TotalErrors), TableMetrics: bw.copyTableMetrics(), CurrentBatchSize: bw.metrics.CurrentBatchSize, ChannelLength: len(bw.dataChan), ChannelCapacity: bw.metrics.ChannelCapacity, LastFlushTime: bw.metrics.LastFlushTime, } } // copyTableMetrics 复制表指标(线程安全) func (bw *BatchWriter) copyTableMetrics() map[string]int64 { result := make(map[string]int64) for k, v := range bw.metrics.TableMetrics { result[k] = v } return result } // startMetricsServer 启动监控指标服务器 //func (bw *BatchWriter) startMetricsServer(port int) { // router := mux.NewRouter() // router.HandleFunc("/metrics", bw.metricsHandler).Methods("GET") // router.HandleFunc("/health", bw.healthHandler).Methods("GET") // router.HandleFunc("/stats", bw.statsHandler).Methods("GET") // // addr := fmt.Sprintf(":%d", port) // bw.httpServer = &http.Server{ // Addr: addr, // Handler: router, // } // // log.Printf("Metrics server started on port %d", port) // if err := bw.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { // log.Printf("Metrics server error: %v", err) // } //} // metricsHandler 监控指标处理 //func (bw *BatchWriter) metricsHandler(w http.ResponseWriter, r *http.Request) { // metrics := bw.GetMetrics() // w.Header().Set("Content-Type", "application/json") // json.NewEncoder(w).Encode(metrics) //} // healthHandler 健康检查处理 //func (bw *BatchWriter) healthHandler(w http.ResponseWriter, r *http.Request) { // response := map[string]interface{}{ // "status": "healthy", // "timestamp": time.Now(), // "uptime": time.Since(bw.startTime).String(), // "channel": map[string]interface{}{ // "length": len(bw.dataChan), // "cap": cap(bw.dataChan), // }, // } // w.Header().Set("Content-Type", "application/json") // json.NewEncoder(w).Encode(response) //} // statsHandler 统计信息处理 //func (bw *BatchWriter) statsHandler(w http.ResponseWriter, r *http.Request) { // metrics := bw.GetMetrics() // stats := map[string]interface{}{ // "metrics": metrics, // "config": map[string]interface{}{ // "batch_size": bw.batchSize, // "flush_timeout": bw.flushTimeout.String(), // "channel_size": bw.channelSize, // }, // } // w.Header().Set("Content-Type", "application/json") // json.NewEncoder(w).Encode(stats) //} // stopHTTPServer 停止HTTP服务器 //func (bw *BatchWriter) stopHTTPServer() { // if bw.httpServer != nil { // ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // defer cancel() // bw.httpServer.Shutdown(ctx) // } //}