disk_queue.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. package queue
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "github.com/nsqio/go-diskqueue"
  8. "sparkteam-dash/pkg/logger"
  9. "sync"
  10. "time"
  11. )
  12. // DataItem 表示要存储的数据结构
  13. type DataItem struct {
  14. ID int `json:"id"`
  15. Data string `json:"data"`
  16. Timestamp time.Time `json:"timestamp"`
  17. }
  18. // BatchDBWriter 批量数据库写入器
  19. type BatchDBWriter struct {
  20. queue diskqueue.Interface
  21. db *sql.DB // 数据库连接
  22. batchSize int // 批量大小
  23. timeout time.Duration // 批量超时时间
  24. mu sync.Mutex
  25. batch []DataItem
  26. lastPush time.Time
  27. }
  28. // NewBatchDBWriter 创建新的批量写入器
  29. func NewBatchDBWriter(dataPath string, db *sql.DB, batchSize int, timeout time.Duration) *BatchDBWriter {
  30. writer := &BatchDBWriter{
  31. db: db,
  32. batchSize: batchSize,
  33. timeout: timeout,
  34. batch: make([]DataItem, 0, batchSize),
  35. }
  36. // 初始化磁盘队列
  37. writer.queue = diskqueue.New(
  38. "batch-db-writer", // 名称
  39. dataPath, // 数据路径
  40. 1024*1024, // 最大文件大小 (1MB)
  41. 4, // 最小消息大小
  42. 1<<20, // 最大消息大小 (1MB)
  43. 1000, // 同步间隔
  44. 10*time.Second,
  45. writer.Log, // 日志接口
  46. )
  47. return writer
  48. }
  49. // Write 写入数据到队列
  50. func (w *BatchDBWriter) Write(data DataItem) error {
  51. jsonData, err := json.Marshal(data)
  52. if err != nil {
  53. return fmt.Errorf("marshal data error: %v", err)
  54. }
  55. return w.queue.Put(jsonData)
  56. }
  57. // HandleMessage 处理从队列中读取的消息(实现 diskqueue.Logger 接口)
  58. func (w *BatchDBWriter) HandleMessage(message []byte) error {
  59. var data DataItem
  60. if err := json.Unmarshal(message, &data); err != nil {
  61. logger.Errorf("unmarshal message error: %v", err)
  62. return nil // 跳过错误消息
  63. }
  64. w.mu.Lock()
  65. defer w.mu.Unlock()
  66. w.batch = append(w.batch, data)
  67. w.lastPush = time.Now()
  68. // 如果达到批量大小或超时,则写入数据库
  69. if len(w.batch) >= w.batchSize || time.Since(w.lastPush) > w.timeout {
  70. if err := w.flushBatch(); err != nil {
  71. return err
  72. }
  73. }
  74. return nil
  75. }
  76. // flushBatch 将当前批次数据写入数据库
  77. func (w *BatchDBWriter) flushBatch() error {
  78. if len(w.batch) == 0 {
  79. return nil
  80. }
  81. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
  82. defer cancel()
  83. // 开始数据库事务
  84. tx, err := w.db.BeginTx(ctx, nil)
  85. if err != nil {
  86. return fmt.Errorf("begin transaction error: %v", err)
  87. }
  88. defer tx.Rollback()
  89. // 准备批量插入语句(根据实际数据库调整)
  90. stmt, err := tx.PrepareContext(ctx,
  91. "INSERT INTO data_table (id, data, timestamp) VALUES ($1, $2, $3)")
  92. if err != nil {
  93. return fmt.Errorf("prepare statement error: %v", err)
  94. }
  95. defer stmt.Close()
  96. // 执行批量插入
  97. for _, item := range w.batch {
  98. _, err := stmt.ExecContext(ctx, item.ID, item.Data, item.Timestamp)
  99. if err != nil {
  100. return fmt.Errorf("execute statement error: %v", err)
  101. }
  102. }
  103. // 提交事务
  104. if err := tx.Commit(); err != nil {
  105. return fmt.Errorf("commit transaction error: %v", err)
  106. }
  107. logger.Infof("成功写入 %d 条数据到数据库", len(w.batch))
  108. w.batch = w.batch[:0] // 清空批次
  109. w.lastPush = time.Now()
  110. return nil
  111. }
  112. // Log 实现 diskqueue.Logger 接口
  113. func (w *BatchDBWriter) Log(level diskqueue.LogLevel, msg string, args ...interface{}) {
  114. logger.Errorf(fmt.Sprintf("[%s] %s", level, msg), args...)
  115. }
  116. // Start 启动批量写入器
  117. func (w *BatchDBWriter) Start() {
  118. // 启动定时刷新(防止小批量数据长时间不写入)
  119. go func() {
  120. ticker := time.NewTicker(w.timeout)
  121. defer ticker.Stop()
  122. for range ticker.C {
  123. w.mu.Lock()
  124. if len(w.batch) > 0 && time.Since(w.lastPush) > w.timeout {
  125. if err := w.flushBatch(); err != nil {
  126. logger.Errorf("定时刷新批次失败: %v", err)
  127. }
  128. }
  129. w.mu.Unlock()
  130. }
  131. }()
  132. }
  133. // Stop 停止批量写入器
  134. func (w *BatchDBWriter) Stop() {
  135. w.mu.Lock()
  136. defer w.mu.Unlock()
  137. // 刷新剩余数据
  138. if err := w.flushBatch(); err != nil {
  139. logger.Errorf("停止时刷新批次失败: %v", err)
  140. }
  141. w.queue.Close()
  142. }