weighter.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package weightrand
  2. import (
  3. "math/rand"
  4. "sort"
  5. )
  6. type WeightItem[T any] struct {
  7. Item T
  8. Weight int64
  9. AccSum int64
  10. }
  11. type RandTable[T any] struct {
  12. allItems []*WeightItem[T]
  13. sum int64
  14. needCalc bool
  15. hitLimit map[int]int
  16. }
  17. func NewRandTable[T any]() *RandTable[T] {
  18. wr := new(RandTable[T])
  19. wr.hitLimit = make(map[int]int)
  20. return wr
  21. }
  22. // 增加随机项,
  23. func (rt *RandTable[T]) AddItem(item T, weight int64) {
  24. if weight < 0 {
  25. weight = 0
  26. }
  27. item1 := new(WeightItem[T])
  28. item1.Item = item
  29. item1.Weight = weight
  30. rt.allItems = append(rt.allItems, item1)
  31. rt.needCalc = true
  32. }
  33. // 添加需要限制命中次数的项
  34. func (rt *RandTable[T]) AddLimitItem(item T, weight int64, limit int) {
  35. if weight < 0 {
  36. weight = 0
  37. }
  38. item1 := new(WeightItem[T])
  39. item1.Item = item
  40. item1.Weight = weight
  41. rt.allItems = append(rt.allItems, item1)
  42. l := len(rt.allItems)
  43. rt.hitLimit[l-1] = limit
  44. rt.needCalc = true
  45. }
  46. // 计算总权重
  47. func (rt *RandTable[T]) calcSum() {
  48. rt.sum = 0
  49. for _, item := range rt.allItems {
  50. rt.sum += item.Weight
  51. item.AccSum = rt.sum
  52. }
  53. }
  54. func (rt *RandTable[T]) WeightSum() int64 {
  55. if rt.needCalc {
  56. rt.calcSum()
  57. rt.needCalc = false
  58. }
  59. return rt.sum
  60. }
  61. func (rt *RandTable[T]) Size() int {
  62. return len(rt.allItems)
  63. }
  64. // 如果没有合适项,返回该类型默认值 指针为nil,int为0
  65. func (rt *RandTable[T]) randomItem() (*WeightItem[T], int) {
  66. if rt.needCalc {
  67. rt.calcSum()
  68. rt.needCalc = false
  69. }
  70. if rt.sum == 0 {
  71. return nil, 0
  72. }
  73. rnd := rand.Int63n(rt.sum)
  74. rnd += 1
  75. cnt := len(rt.allItems)
  76. index := sort.Search(cnt, func(i int) bool { return rt.allItems[i].AccSum >= rnd })
  77. if index < cnt {
  78. return rt.allItems[index], index
  79. }
  80. return nil, 0
  81. }
  82. // 如果没有合适项,返回该类型默认值 指针为nil,int为0
  83. func (rt *RandTable[T]) GetRandomItem() (df T, ok bool) {
  84. v, index := rt.randomItem()
  85. if v != nil {
  86. df = v.Item
  87. ok = true
  88. // 命中项需要限制次数,当限制次数用完就不能命中
  89. if v2, ok := rt.hitLimit[index]; ok {
  90. v2 -= 1
  91. rt.hitLimit[index] = v2
  92. if v2 <= 0 {
  93. v.Weight = 0
  94. rt.needCalc = true
  95. }
  96. }
  97. }
  98. return
  99. }
  100. // 随机移除,并返回移除项;并非真的移除,只是把移除项的权重设为0,不会再随机到
  101. func (rt *RandTable[T]) RandomRemoveItem() (df T, ok bool) {
  102. v, _ := rt.randomItem()
  103. if v != nil {
  104. df = v.Item
  105. ok = true
  106. v.Weight = 0
  107. rt.needCalc = true
  108. }
  109. return
  110. }
  111. // 增加权重,返回第几次增加, 从0开始计数
  112. type IndexRandTable struct {
  113. allItems []int64
  114. sum int64
  115. }
  116. func NewIndexRandTable(cap int) *IndexRandTable {
  117. wr := new(IndexRandTable)
  118. wr.allItems = make([]int64, 0, cap)
  119. return wr
  120. }
  121. func (srt *IndexRandTable) AddItem(weight int64) {
  122. if weight < 0 {
  123. weight = 0
  124. }
  125. srt.sum += weight
  126. srt.allItems = append(srt.allItems, srt.sum)
  127. }
  128. // 没有随机到时,返回-1,随机到返回第几次添加,从0开始
  129. func (srt *IndexRandTable) GetRandomItem() int {
  130. if srt.sum <= 0 {
  131. return -1
  132. }
  133. rnd := rand.Int63n(srt.sum)
  134. rnd += 1
  135. cnt := len(srt.allItems)
  136. index := sort.Search(cnt, func(i int) bool { return srt.allItems[i] >= rnd })
  137. if index < cnt {
  138. return index
  139. }
  140. return -1
  141. }