package queue import ( "context" "database/sql" "encoding/json" "fmt" "github.com/nsqio/go-diskqueue" "sparkteam-dash/pkg/logger" "sync" "time" ) // DataItem 表示要存储的数据结构 type DataItem struct { ID int `json:"id"` Data string `json:"data"` Timestamp time.Time `json:"timestamp"` } // BatchDBWriter 批量数据库写入器 type BatchDBWriter struct { queue diskqueue.Interface db *sql.DB // 数据库连接 batchSize int // 批量大小 timeout time.Duration // 批量超时时间 mu sync.Mutex batch []DataItem lastPush time.Time } // NewBatchDBWriter 创建新的批量写入器 func NewBatchDBWriter(dataPath string, db *sql.DB, batchSize int, timeout time.Duration) *BatchDBWriter { writer := &BatchDBWriter{ db: db, batchSize: batchSize, timeout: timeout, batch: make([]DataItem, 0, batchSize), } // 初始化磁盘队列 writer.queue = diskqueue.New( "batch-db-writer", // 名称 dataPath, // 数据路径 1024*1024, // 最大文件大小 (1MB) 4, // 最小消息大小 1<<20, // 最大消息大小 (1MB) 1000, // 同步间隔 10*time.Second, writer.Log, // 日志接口 ) return writer } // Write 写入数据到队列 func (w *BatchDBWriter) Write(data DataItem) error { jsonData, err := json.Marshal(data) if err != nil { return fmt.Errorf("marshal data error: %v", err) } return w.queue.Put(jsonData) } // HandleMessage 处理从队列中读取的消息(实现 diskqueue.Logger 接口) func (w *BatchDBWriter) HandleMessage(message []byte) error { var data DataItem if err := json.Unmarshal(message, &data); err != nil { logger.Errorf("unmarshal message error: %v", err) return nil // 跳过错误消息 } w.mu.Lock() defer w.mu.Unlock() w.batch = append(w.batch, data) w.lastPush = time.Now() // 如果达到批量大小或超时,则写入数据库 if len(w.batch) >= w.batchSize || time.Since(w.lastPush) > w.timeout { if err := w.flushBatch(); err != nil { return err } } return nil } // flushBatch 将当前批次数据写入数据库 func (w *BatchDBWriter) flushBatch() error { if len(w.batch) == 0 { return nil } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() // 开始数据库事务 tx, err := w.db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin transaction error: %v", err) } defer tx.Rollback() // 准备批量插入语句(根据实际数据库调整) stmt, err := tx.PrepareContext(ctx, "INSERT INTO data_table (id, data, timestamp) VALUES ($1, $2, $3)") if err != nil { return fmt.Errorf("prepare statement error: %v", err) } defer stmt.Close() // 执行批量插入 for _, item := range w.batch { _, err := stmt.ExecContext(ctx, item.ID, item.Data, item.Timestamp) if err != nil { return fmt.Errorf("execute statement error: %v", err) } } // 提交事务 if err := tx.Commit(); err != nil { return fmt.Errorf("commit transaction error: %v", err) } logger.Infof("成功写入 %d 条数据到数据库", len(w.batch)) w.batch = w.batch[:0] // 清空批次 w.lastPush = time.Now() return nil } // Log 实现 diskqueue.Logger 接口 func (w *BatchDBWriter) Log(level diskqueue.LogLevel, msg string, args ...interface{}) { logger.Errorf(fmt.Sprintf("[%s] %s", level, msg), args...) } // Start 启动批量写入器 func (w *BatchDBWriter) Start() { // 启动定时刷新(防止小批量数据长时间不写入) go func() { ticker := time.NewTicker(w.timeout) defer ticker.Stop() for range ticker.C { w.mu.Lock() if len(w.batch) > 0 && time.Since(w.lastPush) > w.timeout { if err := w.flushBatch(); err != nil { logger.Errorf("定时刷新批次失败: %v", err) } } w.mu.Unlock() } }() } // Stop 停止批量写入器 func (w *BatchDBWriter) Stop() { w.mu.Lock() defer w.mu.Unlock() // 刷新剩余数据 if err := w.flushBatch(); err != nil { logger.Errorf("停止时刷新批次失败: %v", err) } w.queue.Close() }