123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327 |
- package otherutils
- import (
- "crypto/rand"
- "encoding/binary"
- "errors"
- "math"
- mathrand "math/rand"
- "sync"
- )
// Sentinel errors returned by the helpers in this package; compare against
// them with errors.Is.
var (
	// ErrInvalidWeight reports a negative weight, a total weight of zero, or a
	// total weight that would overflow int64.
	ErrInvalidWeight = errors.New("invalid weight")

	// ErrNoItems reports an empty candidate list.
	ErrNoItems = errors.New("no items available")

	// ErrInvalidRange is the sentinel for an unusable range argument.
	ErrInvalidRange = errors.New("invalid range")

	// ErrInvalidCount reports a non-positive requested count.
	ErrInvalidCount = errors.New("invalid count")

	// ErrOverflow reports a value that does not fit the target numeric type.
	ErrOverflow = errors.New("numeric overflow")

	// ErrUnsupportedType reports a numeric type a helper cannot handle.
	ErrUnsupportedType = errors.New("unsupported number type")
)
// safeRand wraps a *mathrand.Rand with a mutex so one generator can be shared
// by concurrent goroutines; a *rand.Rand from mathrand.New is not safe for
// concurrent use on its own.
type safeRand struct {
	// mu serializes every call into rng.
	mu sync.Mutex
	// rng is the underlying generator; it is assigned by init.
	rng *mathrand.Rand
}

// globalRand is the package-wide generator used by the random helpers.
var globalRand = new(safeRand)
- func init() {
- var b [8]byte
- if _, err := rand.Read(b[:]); err != nil {
- panic("failed to initialize random seed: " + err.Error())
- }
- seed := binary.LittleEndian.Uint64(b[:])
- globalRand.rng = mathrand.New(mathrand.NewSource(int64(seed)))
- }
- func (s *safeRand) Int63n(n int64) int64 {
- s.mu.Lock()
- defer s.mu.Unlock()
- return s.rng.Int63n(n)
- }
- func (s *safeRand) Float64() float64 {
- s.mu.Lock()
- defer s.mu.Unlock()
- return s.rng.Float64()
- }
- func (s *safeRand) Intn(n int) int {
- s.mu.Lock()
- defer s.mu.Unlock()
- return s.rng.Intn(n)
- }
// Number is the type-set constraint for the numeric helpers: every built-in
// signed integer, unsigned integer and floating-point type, plus any type
// defined on top of one of them (the ~ prefix).
type Number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~float32 | ~float64
}
- // safeNumberConversion 安全的数值类型转换
- func safeNumberConversion[T Number](val int64) (T, error) {
- var zero T
- switch any(zero).(type) {
- case int8:
- if val > math.MaxInt8 || val < math.MinInt8 {
- return zero, ErrOverflow
- }
- case int16:
- if val > math.MaxInt16 || val < math.MinInt16 {
- return zero, ErrOverflow
- }
- case int32:
- if val > math.MaxInt32 || val < math.MinInt32 {
- return zero, ErrOverflow
- }
- case uint8:
- if val > math.MaxUint8 || val < 0 {
- return zero, ErrOverflow
- }
- case uint16:
- if val > math.MaxUint16 || val < 0 {
- return zero, ErrOverflow
- }
- case uint32:
- if val > math.MaxUint32 || val < 0 {
- return zero, ErrOverflow
- }
- case uint64:
- if val < 0 {
- return zero, ErrOverflow
- }
- }
- return T(val), nil
- }
- // RandInterval 在指定范围内随机生成一个数[min, max]
- func RandInterval[T Number](min, max T) (T, error) {
- if min > max {
- min, max = max, min
- }
- switch any(min).(type) {
- case int, int8, int16, int32, int64:
- diff := int64(max) - int64(min)
- randVal := globalRand.Int63n(diff + 1)
- return safeNumberConversion[T](randVal + int64(min))
- case uint, uint8, uint16, uint32, uint64:
- diff := uint64(max) - uint64(min)
- randVal := uint64(globalRand.Int63n(int64(diff + 1)))
- return T(randVal + uint64(min)), nil
- case float32, float64:
- const resolution = 1000000
- scaled := globalRand.Int63n(resolution)
- result := T(float64(scaled)/resolution*float64(max-min) + float64(min))
- return result, nil
- default:
- return min, ErrUnsupportedType
- }
- }
- // reservoirSampling 使用蓄水池抽样算法生成随机数
- func reservoirSampling[T Number](min, max T, n int) ([]T, error) {
- result := make([]T, n)
- // 填充前n个数
- for i := 0; i < n; i++ {
- result[i] = min + T(i)
- }
- // 蓄水池抽样
- for i := n; i < int(max-min+1); i++ {
- j := globalRand.Intn(i + 1)
- if j < n {
- result[j] = min + T(i)
- }
- }
- return result, nil
- }
- // RandIntervalN 随机生成n个范围在[min, max]之间的唯一整数
- func RandIntervalN[T Number](min, max T, n int) ([]T, error) {
- if n <= 0 {
- return nil, ErrInvalidCount
- }
- if min > max {
- min, max = max, min
- }
- rangeSize := max - min + 1
- if n > int(rangeSize) {
- n = int(rangeSize)
- }
- result := make([]T, n)
- usedNumbers := make(map[T]T)
- currentRange := T(rangeSize)
- for i := 0; i < n; i++ {
- // 从当前范围中随机选择一个数
- randomValue := min + T(globalRand.Int63n(int64(currentRange)))
- // 如果这个数已经被使用过,使用它的映射值
- if replacementValue, exists := usedNumbers[randomValue]; exists {
- result[i] = replacementValue
- } else {
- result[i] = randomValue
- }
- // 当前范围的最后一个数
- lastValue := max - T(i)
- if randomValue != lastValue {
- // 更新映射关系
- if replacementValue, exists := usedNumbers[lastValue]; exists {
- usedNumbers[randomValue] = replacementValue
- } else {
- usedNumbers[randomValue] = lastValue
- }
- }
- currentRange--
- }
- return result, nil
- }
// WeightedItem is implemented by values that carry their own selection weight.
type WeightedItem interface {
	// GetWeight returns the item's weight.
	GetWeight() int64
}
// WeightedTable selects items at random with probability proportional to
// their weights. Instances are built with NewWeightedTable.
type WeightedTable[T any] struct {
	// items holds the candidates; items[i] has weight weights[i].
	items []T
	// weights holds each item's weight (validated as non-negative on creation).
	weights []int64
	// totalWeight is the sum of weights; selection requires it to be > 0.
	totalWeight int64
	// getWeight func(T) int64 // disabled together with UpdateWeights
	// mu guards the fields above; RandOneItem holds it for reading.
	mu sync.RWMutex
}
- // validateWeights 验证权重的有效性
- func (wt *WeightedTable[T]) validateWeights() error {
- if wt.totalWeight <= 0 {
- return ErrInvalidWeight
- }
- return nil
- }
- // NewWeightedTable 创建新的WeightedTable实例
- func NewWeightedTable[T any](items []T, getWeight func(T) int64) (*WeightedTable[T], error) {
- if len(items) == 0 {
- return nil, ErrNoItems
- }
- if getWeight == nil {
- return nil, errors.New("getWeight function cannot be nil")
- }
- weights := make([]int64, len(items))
- var totalWeight int64
- for i, item := range items {
- weight := getWeight(item)
- if weight < 0 {
- return nil, ErrInvalidWeight
- }
- weights[i] = weight
- if weight > math.MaxInt64-totalWeight { // 检查是否会溢出
- return nil, ErrInvalidWeight
- }
- totalWeight += weight
- }
- if totalWeight <= 0 {
- return nil, ErrInvalidWeight
- }
- return &WeightedTable[T]{
- weights: weights,
- totalWeight: totalWeight,
- items: items,
- // getWeight: getWeight,
- }, nil
- }
- // // UpdateWeights 更新所有权重
- // func (wt *WeightedTable[T]) UpdateWeights() error {
- // wt.mu.Lock()
- // defer wt.mu.Unlock()
- // var totalWeight int64
- // for i, item := range wt.items {
- // weight := wt.getWeight(item)
- // if weight < 0 {
- // return ErrInvalidWeight
- // }
- // wt.weights[i] = weight
- // if weight > math.MaxInt64-totalWeight { // 检查是否会溢出
- // return ErrInvalidWeight
- // }
- // totalWeight += weight
- // }
- // if totalWeight <= 0 {
- // return ErrInvalidWeight
- // }
- // wt.totalWeight = totalWeight
- // return nil
- // }
- // RandOneItem 随机选择一个项
- func (wt *WeightedTable[T]) RandOneItem() (T, error) {
- wt.mu.RLock()
- defer wt.mu.RUnlock()
- var zero T
- if err := wt.validateWeights(); err != nil {
- return zero, err
- }
- randomNumber := globalRand.Int63n(wt.totalWeight)
- var cumulativeWeight int64
- for i, weight := range wt.weights {
- cumulativeWeight += weight
- if randomNumber < cumulativeWeight {
- return wt.items[i], nil
- }
- }
- return zero, errors.New("failed to select item")
- }
- // RandChoiceByWeighted 根据权重选择一个元素
- func RandChoiceByWeighted(items [][]int64) (int64, error) {
- if len(items) == 0 {
- return 0, ErrNoItems
- }
- var totalWeight int64
- for _, item := range items {
- if len(item) != 2 {
- return 0, errors.New("invalid item format")
- }
- if item[1] < 0 {
- return 0, ErrInvalidWeight
- }
- if item[1] > math.MaxInt64-totalWeight { // 检查是否会溢出
- return 0, ErrInvalidWeight
- }
- totalWeight += item[1]
- }
- if totalWeight <= 0 {
- return 0, ErrInvalidWeight
- }
- randomNumber := globalRand.Int63n(totalWeight)
- var cumulativeWeight int64
- for _, item := range items {
- cumulativeWeight += item[1]
- if randomNumber < cumulativeWeight {
- return item[0], nil
- }
- }
- return 0, errors.New("failed to select item")
- }
|