weighted_select.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. // Copyright 2016 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package utils
  17. import (
  18. "math"
  19. "math/rand"
  20. "github.com/ethereum/go-ethereum/log"
  21. )
  22. type (
  23. // WeightedRandomSelect is capable of weighted random selection from a set of items
  24. WeightedRandomSelect struct {
  25. root *wrsNode
  26. idx map[WrsItem]int
  27. wfn WeightFn
  28. }
  29. WrsItem interface{}
  30. WeightFn func(interface{}) uint64
  31. )
  32. // NewWeightedRandomSelect returns a new WeightedRandomSelect structure
  33. func NewWeightedRandomSelect(wfn WeightFn) *WeightedRandomSelect {
  34. return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[WrsItem]int), wfn: wfn}
  35. }
  36. // Update updates an item's weight, adds it if it was non-existent or removes it if
  37. // the new weight is zero. Note that explicitly updating decreasing weights is not necessary.
  38. func (w *WeightedRandomSelect) Update(item WrsItem) {
  39. w.setWeight(item, w.wfn(item))
  40. }
  41. // Remove removes an item from the set
  42. func (w *WeightedRandomSelect) Remove(item WrsItem) {
  43. w.setWeight(item, 0)
  44. }
  45. // IsEmpty returns true if the set is empty
  46. func (w *WeightedRandomSelect) IsEmpty() bool {
  47. return w.root.sumCost == 0
  48. }
  49. // setWeight sets an item's weight to a specific value (removes it if zero)
  50. func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) {
  51. if weight > math.MaxInt64-w.root.sumCost {
  52. // old weight is still included in sumCost, remove and check again
  53. w.setWeight(item, 0)
  54. if weight > math.MaxInt64-w.root.sumCost {
  55. log.Error("WeightedRandomSelect overflow", "sumCost", w.root.sumCost, "new weight", weight)
  56. weight = math.MaxInt64 - w.root.sumCost
  57. }
  58. }
  59. idx, ok := w.idx[item]
  60. if ok {
  61. w.root.setWeight(idx, weight)
  62. if weight == 0 {
  63. delete(w.idx, item)
  64. }
  65. } else {
  66. if weight != 0 {
  67. if w.root.itemCnt == w.root.maxItems {
  68. // add a new level
  69. newRoot := &wrsNode{sumCost: w.root.sumCost, itemCnt: w.root.itemCnt, level: w.root.level + 1, maxItems: w.root.maxItems * wrsBranches}
  70. newRoot.items[0] = w.root
  71. newRoot.weights[0] = w.root.sumCost
  72. w.root = newRoot
  73. }
  74. w.idx[item] = w.root.insert(item, weight)
  75. }
  76. }
  77. }
  78. // Choose randomly selects an item from the set, with a chance proportional to its
  79. // current weight. If the weight of the chosen element has been decreased since the
  80. // last stored value, returns it with a newWeight/oldWeight chance, otherwise just
  81. // updates its weight and selects another one
  82. func (w *WeightedRandomSelect) Choose() WrsItem {
  83. for {
  84. if w.root.sumCost == 0 {
  85. return nil
  86. }
  87. val := uint64(rand.Int63n(int64(w.root.sumCost)))
  88. choice, lastWeight := w.root.choose(val)
  89. weight := w.wfn(choice)
  90. if weight != lastWeight {
  91. w.setWeight(choice, weight)
  92. }
  93. if weight >= lastWeight || uint64(rand.Int63n(int64(lastWeight))) < weight {
  94. return choice
  95. }
  96. }
  97. }
  // wrsBranches is the fan-out of the wrsNode tree. It is the maximum number of
  // items a leaf node holds, and of child nodes an inner node holds.
  const (
  	wrsBranches = 8
  )
  99. // wrsNode is a node of a tree structure that can store WrsItems or further wrsNodes.
  100. type wrsNode struct {
  101. items [wrsBranches]interface{}
  102. weights [wrsBranches]uint64
  103. sumCost uint64
  104. level, itemCnt, maxItems int
  105. }
  106. // insert recursively inserts a new item to the tree and returns the item index
  107. func (n *wrsNode) insert(item WrsItem, weight uint64) int {
  108. branch := 0
  109. for n.items[branch] != nil && (n.level == 0 || n.items[branch].(*wrsNode).itemCnt == n.items[branch].(*wrsNode).maxItems) {
  110. branch++
  111. if branch == wrsBranches {
  112. panic(nil)
  113. }
  114. }
  115. n.itemCnt++
  116. n.sumCost += weight
  117. n.weights[branch] += weight
  118. if n.level == 0 {
  119. n.items[branch] = item
  120. return branch
  121. }
  122. var subNode *wrsNode
  123. if n.items[branch] == nil {
  124. subNode = &wrsNode{maxItems: n.maxItems / wrsBranches, level: n.level - 1}
  125. n.items[branch] = subNode
  126. } else {
  127. subNode = n.items[branch].(*wrsNode)
  128. }
  129. subIdx := subNode.insert(item, weight)
  130. return subNode.maxItems*branch + subIdx
  131. }
  132. // setWeight updates the weight of a certain item (which should exist) and returns
  133. // the change of the last weight value stored in the tree
  134. func (n *wrsNode) setWeight(idx int, weight uint64) uint64 {
  135. if n.level == 0 {
  136. oldWeight := n.weights[idx]
  137. n.weights[idx] = weight
  138. diff := weight - oldWeight
  139. n.sumCost += diff
  140. if weight == 0 {
  141. n.items[idx] = nil
  142. n.itemCnt--
  143. }
  144. return diff
  145. }
  146. branchItems := n.maxItems / wrsBranches
  147. branch := idx / branchItems
  148. diff := n.items[branch].(*wrsNode).setWeight(idx-branch*branchItems, weight)
  149. n.weights[branch] += diff
  150. n.sumCost += diff
  151. if weight == 0 {
  152. n.itemCnt--
  153. }
  154. return diff
  155. }
  156. // choose recursively selects an item from the tree and returns it along with its weight
  157. func (n *wrsNode) choose(val uint64) (WrsItem, uint64) {
  158. for i, w := range n.weights {
  159. if val < w {
  160. if n.level == 0 {
  161. return n.items[i].(WrsItem), n.weights[i]
  162. }
  163. return n.items[i].(*wrsNode).choose(val)
  164. }
  165. val -= w
  166. }
  167. panic(nil)
  168. }