123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- package weightrand
- import (
- "math/rand"
- "sort"
- )
// WeightItem is a single entry of a RandTable.
type WeightItem[T any] struct {
	// Item is the value returned when this entry is selected.
	Item T
	// Weight is the relative selection weight; an entry with weight 0 is never selected.
	Weight int64
	// AccSum is the cumulative weight of this entry and all entries before it.
	// RandTable recomputes it lazily whenever a weight changes.
	AccSum int64
}

// RandTable selects items at random with probability proportional to their weight.
//
// Entries are never physically removed: exhausting a hit limit or calling
// RandomRemoveItem sets the entry's weight to 0 so it can no longer be chosen.
// RandTable has no internal locking; concurrent use needs external synchronization.
//
// The zero value is an empty table ready to use.
type RandTable[T any] struct {
	allItems []*WeightItem[T]
	// sum is the total weight; only valid while needCalc is false.
	sum int64
	// needCalc marks sum and every AccSum as stale after an add or weight change.
	needCalc bool
	// hitLimit maps an index in allItems to the number of hits it has left.
	hitLimit map[int]int
}

// NewRandTable returns an empty RandTable.
func NewRandTable[T any]() *RandTable[T] {
	return &RandTable[T]{hitLimit: make(map[int]int)}
}

// AddItem appends item with the given weight. A negative weight is treated as 0,
// which makes the item unselectable.
func (rt *RandTable[T]) AddItem(item T, weight int64) {
	rt.appendItem(item, weight)
}

// AddLimitItem appends item with the given weight and lets GetRandomItem return
// it at most limit times; after that its weight drops to 0. A limit <= 0 means
// the item is never selectable. A negative weight is treated as 0.
func (rt *RandTable[T]) AddLimitItem(item T, weight int64, limit int) {
	if limit <= 0 {
		// No hits allowed. Without this the item could still be chosen once,
		// because the remaining count is only checked after a hit.
		weight = 0
	}
	index := rt.appendItem(item, weight)
	if rt.hitLimit == nil {
		// Lazily create the map so a zero-value RandTable does not panic here.
		rt.hitLimit = make(map[int]int)
	}
	rt.hitLimit[index] = limit
}

// appendItem adds a new entry, clamping a negative weight to 0, and returns its index.
func (rt *RandTable[T]) appendItem(item T, weight int64) int {
	if weight < 0 {
		weight = 0
	}
	rt.allItems = append(rt.allItems, &WeightItem[T]{Item: item, Weight: weight})
	rt.needCalc = true
	return len(rt.allItems) - 1
}

// calcSum recomputes the total weight and each entry's cumulative AccSum.
func (rt *RandTable[T]) calcSum() {
	rt.sum = 0
	for _, item := range rt.allItems {
		rt.sum += item.Weight
		item.AccSum = rt.sum
	}
}

// ensureSum refreshes sum and every AccSum if they are stale.
func (rt *RandTable[T]) ensureSum() {
	if rt.needCalc {
		rt.calcSum()
		rt.needCalc = false
	}
}

// WeightSum returns the current total weight of all entries.
func (rt *RandTable[T]) WeightSum() int64 {
	rt.ensureSum()
	return rt.sum
}

// Size returns the number of entries ever added, including ones whose weight is now 0.
func (rt *RandTable[T]) Size() int {
	return len(rt.allItems)
}

// randomItem picks an entry with probability Weight/sum and returns it together
// with its index. It returns (nil, 0) when the total weight is not positive.
func (rt *RandTable[T]) randomItem() (*WeightItem[T], int) {
	rt.ensureSum()
	if rt.sum <= 0 {
		// rand.Int63n panics for n <= 0 (e.g. an overflowed sum).
		return nil, 0
	}
	// rnd is uniform over [1, sum]. The first entry whose AccSum >= rnd owns it.
	// A zero-weight entry repeats an earlier AccSum (or is 0), so the search
	// never stops on it.
	rnd := rand.Int63n(rt.sum) + 1
	count := len(rt.allItems)
	index := sort.Search(count, func(i int) bool { return rt.allItems[i].AccSum >= rnd })
	if index >= count {
		return nil, 0
	}
	return rt.allItems[index], index
}

// GetRandomItem returns a weighted random item and true. If no entry has a
// positive weight it returns the zero value of T and false.
//
// Items added with AddLimitItem consume one hit per selection; once their hits
// are used up their weight is set to 0 and they can no longer be selected.
func (rt *RandTable[T]) GetRandomItem() (df T, ok bool) {
	v, index := rt.randomItem()
	if v == nil {
		return df, false
	}
	if remaining, limited := rt.hitLimit[index]; limited {
		remaining--
		rt.hitLimit[index] = remaining
		if remaining <= 0 {
			v.Weight = 0
			rt.needCalc = true
		}
	}
	return v.Item, true
}

// RandomRemoveItem picks a weighted random item, excludes it from further
// selection, and returns it with true. The entry is not deleted; its weight is
// set to 0. If no entry has a positive weight it returns the zero value of T
// and false.
func (rt *RandTable[T]) RandomRemoveItem() (df T, ok bool) {
	v, _ := rt.randomItem()
	if v == nil {
		return df, false
	}
	v.Weight = 0
	rt.needCalc = true
	return v.Item, true
}
// IndexRandTable is a weighted random table that stores only weights.
// GetRandomItem returns the 0-based position (in AddItem call order) of the
// selected weight rather than a value.
//
// The zero value is an empty table ready to use.
type IndexRandTable struct {
	// allItems[i] is the cumulative weight of entries 0..i.
	allItems []int64
	// sum is the total weight of all entries.
	sum int64
}

// NewIndexRandTable returns an empty IndexRandTable with room for capacity
// entries before reallocating. A negative capacity is treated as 0.
func NewIndexRandTable(capacity int) *IndexRandTable {
	if capacity < 0 {
		// make panics on a negative capacity; the value is only a sizing hint.
		capacity = 0
	}
	return &IndexRandTable{allItems: make([]int64, 0, capacity)}
}

// AddItem appends an entry with the given weight. A negative weight is treated
// as 0, which makes the entry unselectable.
func (srt *IndexRandTable) AddItem(weight int64) {
	if weight < 0 {
		weight = 0
	}
	srt.sum += weight
	srt.allItems = append(srt.allItems, srt.sum)
}

// GetRandomItem returns the 0-based index of a randomly chosen entry, picked
// with probability proportional to its weight, or -1 if the total weight is
// not positive.
func (srt *IndexRandTable) GetRandomItem() int {
	if srt.sum <= 0 {
		return -1
	}
	// rnd is uniform over [1, sum]; the first cumulative weight >= rnd marks the
	// chosen entry. Zero-weight entries repeat an earlier cumulative value, so
	// the search never stops on them.
	rnd := rand.Int63n(srt.sum) + 1
	count := len(srt.allItems)
	index := sort.Search(count, func(i int) bool { return srt.allItems[i] >= rnd })
	if index >= count {
		return -1
	}
	return index
}
|