rand.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. package otherutils
  2. import (
  3. "crypto/rand"
  4. "encoding/binary"
  5. "errors"
  6. "math"
  7. mathrand "math/rand"
  8. "sync"
  9. )
  10. var (
  11. ErrInvalidWeight = errors.New("invalid weight")
  12. ErrNoItems = errors.New("no items available")
  13. ErrInvalidRange = errors.New("invalid range")
  14. ErrInvalidCount = errors.New("invalid count")
  15. ErrOverflow = errors.New("numeric overflow")
  16. ErrUnsupportedType = errors.New("unsupported number type")
  17. )
  18. // safeRand 提供线程安全的随机数生成器
  19. type safeRand struct {
  20. mu sync.Mutex
  21. rng *mathrand.Rand
  22. }
  23. var globalRand = &safeRand{}
  24. func init() {
  25. var b [8]byte
  26. if _, err := rand.Read(b[:]); err != nil {
  27. panic("failed to initialize random seed: " + err.Error())
  28. }
  29. seed := binary.LittleEndian.Uint64(b[:])
  30. globalRand.rng = mathrand.New(mathrand.NewSource(int64(seed)))
  31. }
  32. func (s *safeRand) Int63n(n int64) int64 {
  33. s.mu.Lock()
  34. defer s.mu.Unlock()
  35. return s.rng.Int63n(n)
  36. }
  37. func (s *safeRand) Float64() float64 {
  38. s.mu.Lock()
  39. defer s.mu.Unlock()
  40. return s.rng.Float64()
  41. }
  42. func (s *safeRand) Intn(n int) int {
  43. s.mu.Lock()
  44. defer s.mu.Unlock()
  45. return s.rng.Intn(n)
  46. }
  47. // Number 定义了支持的数值类型
  48. type Number interface {
  49. ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64
  50. }
  51. // safeNumberConversion 安全的数值类型转换
  52. func safeNumberConversion[T Number](val int64) (T, error) {
  53. var zero T
  54. switch any(zero).(type) {
  55. case int8:
  56. if val > math.MaxInt8 || val < math.MinInt8 {
  57. return zero, ErrOverflow
  58. }
  59. case int16:
  60. if val > math.MaxInt16 || val < math.MinInt16 {
  61. return zero, ErrOverflow
  62. }
  63. case int32:
  64. if val > math.MaxInt32 || val < math.MinInt32 {
  65. return zero, ErrOverflow
  66. }
  67. case uint8:
  68. if val > math.MaxUint8 || val < 0 {
  69. return zero, ErrOverflow
  70. }
  71. case uint16:
  72. if val > math.MaxUint16 || val < 0 {
  73. return zero, ErrOverflow
  74. }
  75. case uint32:
  76. if val > math.MaxUint32 || val < 0 {
  77. return zero, ErrOverflow
  78. }
  79. case uint64:
  80. if val < 0 {
  81. return zero, ErrOverflow
  82. }
  83. }
  84. return T(val), nil
  85. }
  86. // RandInterval 在指定范围内随机生成一个数[min, max]
  87. func RandInterval[T Number](min, max T) (T, error) {
  88. if min > max {
  89. min, max = max, min
  90. }
  91. switch any(min).(type) {
  92. case int, int8, int16, int32, int64:
  93. diff := int64(max) - int64(min)
  94. randVal := globalRand.Int63n(diff + 1)
  95. return safeNumberConversion[T](randVal + int64(min))
  96. case uint, uint8, uint16, uint32, uint64:
  97. diff := uint64(max) - uint64(min)
  98. randVal := uint64(globalRand.Int63n(int64(diff + 1)))
  99. return T(randVal + uint64(min)), nil
  100. case float32, float64:
  101. const resolution = 1000000
  102. scaled := globalRand.Int63n(resolution)
  103. result := T(float64(scaled)/resolution*float64(max-min) + float64(min))
  104. return result, nil
  105. default:
  106. return min, ErrUnsupportedType
  107. }
  108. }
  109. // reservoirSampling 使用蓄水池抽样算法生成随机数
  110. func reservoirSampling[T Number](min, max T, n int) ([]T, error) {
  111. result := make([]T, n)
  112. // 填充前n个数
  113. for i := 0; i < n; i++ {
  114. result[i] = min + T(i)
  115. }
  116. // 蓄水池抽样
  117. for i := n; i < int(max-min+1); i++ {
  118. j := globalRand.Intn(i + 1)
  119. if j < n {
  120. result[j] = min + T(i)
  121. }
  122. }
  123. return result, nil
  124. }
  125. // RandIntervalN 随机生成n个范围在[min, max]之间的唯一整数
  126. func RandIntervalN[T Number](min, max T, n int) ([]T, error) {
  127. if n <= 0 {
  128. return nil, ErrInvalidCount
  129. }
  130. if min > max {
  131. min, max = max, min
  132. }
  133. rangeSize := max - min + 1
  134. if n > int(rangeSize) {
  135. n = int(rangeSize)
  136. }
  137. result := make([]T, n)
  138. usedNumbers := make(map[T]T)
  139. currentRange := T(rangeSize)
  140. for i := 0; i < n; i++ {
  141. // 从当前范围中随机选择一个数
  142. randomValue := min + T(globalRand.Int63n(int64(currentRange)))
  143. // 如果这个数已经被使用过,使用它的映射值
  144. if replacementValue, exists := usedNumbers[randomValue]; exists {
  145. result[i] = replacementValue
  146. } else {
  147. result[i] = randomValue
  148. }
  149. // 当前范围的最后一个数
  150. lastValue := max - T(i)
  151. if randomValue != lastValue {
  152. // 更新映射关系
  153. if replacementValue, exists := usedNumbers[lastValue]; exists {
  154. usedNumbers[randomValue] = replacementValue
  155. } else {
  156. usedNumbers[randomValue] = lastValue
  157. }
  158. }
  159. currentRange--
  160. }
  161. return result, nil
  162. }
  163. // WeightedItem 定义了带权重的项接口
  164. type WeightedItem interface {
  165. GetWeight() int64
  166. }
  167. // WeightedTable 实现了基于权重的随机选择
  168. type WeightedTable[T any] struct {
  169. items []T
  170. weights []int64
  171. totalWeight int64
  172. // getWeight func(T) int64
  173. mu sync.RWMutex
  174. }
  175. // validateWeights 验证权重的有效性
  176. func (wt *WeightedTable[T]) validateWeights() error {
  177. if wt.totalWeight <= 0 {
  178. return ErrInvalidWeight
  179. }
  180. return nil
  181. }
  182. // NewWeightedTable 创建新的WeightedTable实例
  183. func NewWeightedTable[T any](items []T, getWeight func(T) int64) (*WeightedTable[T], error) {
  184. if len(items) == 0 {
  185. return nil, ErrNoItems
  186. }
  187. if getWeight == nil {
  188. return nil, errors.New("getWeight function cannot be nil")
  189. }
  190. weights := make([]int64, len(items))
  191. var totalWeight int64
  192. for i, item := range items {
  193. weight := getWeight(item)
  194. if weight < 0 {
  195. return nil, ErrInvalidWeight
  196. }
  197. weights[i] = weight
  198. if weight > math.MaxInt64-totalWeight { // 检查是否会溢出
  199. return nil, ErrInvalidWeight
  200. }
  201. totalWeight += weight
  202. }
  203. if totalWeight <= 0 {
  204. return nil, ErrInvalidWeight
  205. }
  206. return &WeightedTable[T]{
  207. weights: weights,
  208. totalWeight: totalWeight,
  209. items: items,
  210. // getWeight: getWeight,
  211. }, nil
  212. }
  213. // // UpdateWeights 更新所有权重
  214. // func (wt *WeightedTable[T]) UpdateWeights() error {
  215. // wt.mu.Lock()
  216. // defer wt.mu.Unlock()
  217. // var totalWeight int64
  218. // for i, item := range wt.items {
  219. // weight := wt.getWeight(item)
  220. // if weight < 0 {
  221. // return ErrInvalidWeight
  222. // }
  223. // wt.weights[i] = weight
  224. // if weight > math.MaxInt64-totalWeight { // 检查是否会溢出
  225. // return ErrInvalidWeight
  226. // }
  227. // totalWeight += weight
  228. // }
  229. // if totalWeight <= 0 {
  230. // return ErrInvalidWeight
  231. // }
  232. // wt.totalWeight = totalWeight
  233. // return nil
  234. // }
  235. // RandOneItem 随机选择一个项
  236. func (wt *WeightedTable[T]) RandOneItem() (T, error) {
  237. wt.mu.RLock()
  238. defer wt.mu.RUnlock()
  239. var zero T
  240. if err := wt.validateWeights(); err != nil {
  241. return zero, err
  242. }
  243. randomNumber := globalRand.Int63n(wt.totalWeight)
  244. var cumulativeWeight int64
  245. for i, weight := range wt.weights {
  246. cumulativeWeight += weight
  247. if randomNumber < cumulativeWeight {
  248. return wt.items[i], nil
  249. }
  250. }
  251. return zero, errors.New("failed to select item")
  252. }
  253. // RandChoiceByWeighted 根据权重选择一个元素
  254. func RandChoiceByWeighted(items [][]int64) (int64, error) {
  255. if len(items) == 0 {
  256. return 0, ErrNoItems
  257. }
  258. var totalWeight int64
  259. for _, item := range items {
  260. if len(item) != 2 {
  261. return 0, errors.New("invalid item format")
  262. }
  263. if item[1] < 0 {
  264. return 0, ErrInvalidWeight
  265. }
  266. if item[1] > math.MaxInt64-totalWeight { // 检查是否会溢出
  267. return 0, ErrInvalidWeight
  268. }
  269. totalWeight += item[1]
  270. }
  271. if totalWeight <= 0 {
  272. return 0, ErrInvalidWeight
  273. }
  274. randomNumber := globalRand.Int63n(totalWeight)
  275. var cumulativeWeight int64
  276. for _, item := range items {
  277. cumulativeWeight += item[1]
  278. if randomNumber < cumulativeWeight {
  279. return item[0], nil
  280. }
  281. }
  282. return 0, errors.New("failed to select item")
  283. }