weightrand.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package weightrand
  2. import (
  3. "errors"
  4. "fmt"
  5. "math/rand"
  6. "slices"
  7. )
  8. // WeightedItem 结构体,用于存储项目和其权重
  9. type WeightedItem[T comparable] struct {
  10. Item T
  11. Weight int64
  12. }
  13. // WeightedRandomSelector 结构体,用于根据权重随机选择一项
  14. type WeightedRandomSelector[T comparable] struct {
  15. Items []*WeightedItem[T]
  16. }
  17. // NewWeightedRandomSelector 创建一个新的 WeightedRandomSelector 实例
  18. func NewWeightedRandomSelector[T comparable](cnt int) *WeightedRandomSelector[T] {
  19. wrs := &WeightedRandomSelector[T]{
  20. Items: make([]*WeightedItem[T], 0, cnt),
  21. }
  22. return wrs
  23. }
  24. // AddItem 添加一个新的权重项
  25. func (wrs *WeightedRandomSelector[T]) AddItem(item T, weight int64) {
  26. if weight <= 0 {
  27. return
  28. }
  29. it := new(WeightedItem[T])
  30. it.Item = item
  31. it.Weight = weight
  32. wrs.Items = append(wrs.Items, it)
  33. }
  34. // RemoveItem 删除指定的权重项
  35. func (wrs *WeightedRandomSelector[T]) RemoveItem(item T) {
  36. wrs.Items = slices.DeleteFunc(wrs.Items, func(wi *WeightedItem[T]) bool {
  37. return wi.Item == item
  38. })
  39. }
  40. // SelectRandom 根据权重随机选择一项
  41. func (wrs *WeightedRandomSelector[T]) SelectRandom() (T, error) {
  42. var t T
  43. if len(wrs.Items) == 0 {
  44. return t, errors.New("empty items")
  45. }
  46. // 计算总权重
  47. totalWeight := int64(0)
  48. for _, item := range wrs.Items {
  49. totalWeight += item.Weight
  50. }
  51. if totalWeight <= 0 {
  52. return t, errors.New("total weight is zero")
  53. }
  54. // 生成一个 [0, totalWeight) 之间的随机数
  55. randomValue := rand.Int63n(totalWeight)
  56. // 根据随机数找到对应的项目
  57. for _, item := range wrs.Items {
  58. if randomValue < item.Weight {
  59. return item.Item, nil
  60. }
  61. randomValue -= item.Weight
  62. }
  63. // 理论上不会到达这里,因为总权重一定会匹配到某个项目
  64. return t, errors.New("unreachable code")
  65. }
  66. func test() {
  67. items := []WeightedItem[string]{
  68. {"Item1", 10},
  69. {"Item2", 20},
  70. {"Item3", 30},
  71. {"Item4", 40},
  72. }
  73. selector := NewWeightedRandomSelector[string](len(items))
  74. for i := 0; i < 10; i++ {
  75. fmt.Println(selector.SelectRandom())
  76. }
  77. }