package otherutils import ( "crypto/rand" "encoding/binary" "errors" "math" mathrand "math/rand" "sync" ) var ( ErrInvalidWeight = errors.New("invalid weight") ErrNoItems = errors.New("no items available") ErrInvalidRange = errors.New("invalid range") ErrInvalidCount = errors.New("invalid count") ErrOverflow = errors.New("numeric overflow") ErrUnsupportedType = errors.New("unsupported number type") ) // safeRand 提供线程安全的随机数生成器 type safeRand struct { mu sync.Mutex rng *mathrand.Rand } var globalRand = &safeRand{} func init() { var b [8]byte if _, err := rand.Read(b[:]); err != nil { panic("failed to initialize random seed: " + err.Error()) } seed := binary.LittleEndian.Uint64(b[:]) globalRand.rng = mathrand.New(mathrand.NewSource(int64(seed))) } func (s *safeRand) Int63n(n int64) int64 { s.mu.Lock() defer s.mu.Unlock() return s.rng.Int63n(n) } func (s *safeRand) Float64() float64 { s.mu.Lock() defer s.mu.Unlock() return s.rng.Float64() } func (s *safeRand) Intn(n int) int { s.mu.Lock() defer s.mu.Unlock() return s.rng.Intn(n) } // Number 定义了支持的数值类型 type Number interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 } // safeNumberConversion 安全的数值类型转换 func safeNumberConversion[T Number](val int64) (T, error) { var zero T switch any(zero).(type) { case int8: if val > math.MaxInt8 || val < math.MinInt8 { return zero, ErrOverflow } case int16: if val > math.MaxInt16 || val < math.MinInt16 { return zero, ErrOverflow } case int32: if val > math.MaxInt32 || val < math.MinInt32 { return zero, ErrOverflow } case uint8: if val > math.MaxUint8 || val < 0 { return zero, ErrOverflow } case uint16: if val > math.MaxUint16 || val < 0 { return zero, ErrOverflow } case uint32: if val > math.MaxUint32 || val < 0 { return zero, ErrOverflow } case uint64: if val < 0 { return zero, ErrOverflow } } return T(val), nil } // RandInterval 在指定范围内随机生成一个数[min, max] func RandInterval[T Number](min, max T) (T, error) { if min > max { min, max = max, min } switch any(min).(type) { case int, int8, int16, int32, int64: diff := int64(max) - int64(min) randVal := globalRand.Int63n(diff + 1) return safeNumberConversion[T](randVal + int64(min)) case uint, uint8, uint16, uint32, uint64: diff := uint64(max) - uint64(min) randVal := uint64(globalRand.Int63n(int64(diff + 1))) return T(randVal + uint64(min)), nil case float32, float64: const resolution = 1000000 scaled := globalRand.Int63n(resolution) result := T(float64(scaled)/resolution*float64(max-min) + float64(min)) return result, nil default: return min, ErrUnsupportedType } } // reservoirSampling 使用蓄水池抽样算法生成随机数 func reservoirSampling[T Number](min, max T, n int) ([]T, error) { result := make([]T, n) // 填充前n个数 for i := 0; i < n; i++ { result[i] = min + T(i) } // 蓄水池抽样 for i := n; i < int(max-min+1); i++ { j := globalRand.Intn(i + 1) if j < n { result[j] = min + T(i) } } return result, nil } // RandIntervalN 随机生成n个范围在[min, max]之间的唯一整数 func RandIntervalN[T Number](min, max T, n int) ([]T, error) { if n <= 0 { return nil, ErrInvalidCount } if min > max { min, max = max, min } rangeSize := max - min + 1 if n > int(rangeSize) { n = int(rangeSize) } result := make([]T, n) usedNumbers := make(map[T]T) currentRange := T(rangeSize) for i := 0; i < n; i++ { // 从当前范围中随机选择一个数 randomValue := min + T(globalRand.Int63n(int64(currentRange))) // 如果这个数已经被使用过,使用它的映射值 if replacementValue, exists := usedNumbers[randomValue]; exists { result[i] = replacementValue } else { result[i] = randomValue } // 当前范围的最后一个数 lastValue := max - T(i) if randomValue != lastValue { // 更新映射关系 if replacementValue, exists := usedNumbers[lastValue]; exists { usedNumbers[randomValue] = replacementValue } else { usedNumbers[randomValue] = lastValue } } currentRange-- } return result, nil } // WeightedItem 定义了带权重的项接口 type WeightedItem interface { GetWeight() int64 } // WeightedTable 实现了基于权重的随机选择 type WeightedTable[T any] struct { items []T weights []int64 totalWeight int64 // getWeight func(T) int64 mu sync.RWMutex } // validateWeights 验证权重的有效性 func (wt *WeightedTable[T]) validateWeights() error { if wt.totalWeight <= 0 { return ErrInvalidWeight } return nil } // NewWeightedTable 创建新的WeightedTable实例 func NewWeightedTable[T any](items []T, getWeight func(T) int64) (*WeightedTable[T], error) { if len(items) == 0 { return nil, ErrNoItems } if getWeight == nil { return nil, errors.New("getWeight function cannot be nil") } weights := make([]int64, len(items)) var totalWeight int64 for i, item := range items { weight := getWeight(item) if weight < 0 { return nil, ErrInvalidWeight } weights[i] = weight if weight > math.MaxInt64-totalWeight { // 检查是否会溢出 return nil, ErrInvalidWeight } totalWeight += weight } if totalWeight <= 0 { return nil, ErrInvalidWeight } return &WeightedTable[T]{ weights: weights, totalWeight: totalWeight, items: items, // getWeight: getWeight, }, nil } // // UpdateWeights 更新所有权重 // func (wt *WeightedTable[T]) UpdateWeights() error { // wt.mu.Lock() // defer wt.mu.Unlock() // var totalWeight int64 // for i, item := range wt.items { // weight := wt.getWeight(item) // if weight < 0 { // return ErrInvalidWeight // } // wt.weights[i] = weight // if weight > math.MaxInt64-totalWeight { // 检查是否会溢出 // return ErrInvalidWeight // } // totalWeight += weight // } // if totalWeight <= 0 { // return ErrInvalidWeight // } // wt.totalWeight = totalWeight // return nil // } // RandOneItem 随机选择一个项 func (wt *WeightedTable[T]) RandOneItem() (T, error) { wt.mu.RLock() defer wt.mu.RUnlock() var zero T if err := wt.validateWeights(); err != nil { return zero, err } randomNumber := globalRand.Int63n(wt.totalWeight) var cumulativeWeight int64 for i, weight := range wt.weights { cumulativeWeight += weight if randomNumber < cumulativeWeight { return wt.items[i], nil } } return zero, errors.New("failed to select item") } // RandChoiceByWeighted 根据权重选择一个元素 func RandChoiceByWeighted(items [][]int64) (int64, error) { if len(items) == 0 { return 0, ErrNoItems } var totalWeight int64 for _, item := range items { if len(item) != 2 { return 0, errors.New("invalid item format") } if item[1] < 0 { return 0, ErrInvalidWeight } if item[1] > math.MaxInt64-totalWeight { // 检查是否会溢出 return 0, ErrInvalidWeight } totalWeight += item[1] } if totalWeight <= 0 { return 0, ErrInvalidWeight } randomNumber := globalRand.Int63n(totalWeight) var cumulativeWeight int64 for _, item := range items { cumulativeWeight += item[1] if randomNumber < cumulativeWeight { return item[0], nil } } return 0, errors.New("failed to select item") }