limiter.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. // Copyright 2020 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package utils
  17. import (
  18. "sort"
  19. "sync"
  20. "github.com/ethereum/go-ethereum/p2p/enode"
  21. )
  // maxSelectionWeight is the maximum selection weight of each individual node or
// address group.
const maxSelectionWeight = 1000 * 1000 * 1000
  23. // Limiter protects a network request serving mechanism from denial-of-service attacks.
  24. // It limits the total amount of resources used for serving requests while ensuring that
  25. // the most valuable connections always have a reasonable chance of being served.
  26. type Limiter struct {
  27. lock sync.Mutex
  28. cond *sync.Cond
  29. quit bool
  30. nodes map[enode.ID]*nodeQueue
  31. addresses map[string]*addressGroup
  32. addressSelect, valueSelect *WeightedRandomSelect
  33. maxValue float64
  34. maxCost, sumCost, sumCostLimit uint
  35. selectAddressNext bool
  36. }
  37. // nodeQueue represents queued requests coming from a single node ID
  38. type nodeQueue struct {
  39. queue []request // always nil if penaltyCost != 0
  40. id enode.ID
  41. address string
  42. value float64
  43. flatWeight, valueWeight uint64 // current selection weights in the address/value selectors
  44. sumCost uint // summed cost of requests queued by the node
  45. penaltyCost uint // cumulative cost of dropped requests since last processed request
  46. groupIndex int
  47. }
  48. // addressGroup is a group of node IDs that have sent their last requests from the same
  49. // network address
  50. type addressGroup struct {
  51. nodes []*nodeQueue
  52. nodeSelect *WeightedRandomSelect
  53. sumFlatWeight, groupWeight uint64
  54. }
  // request is a single incoming request waiting to be scheduled for processing.
// Field order matters: the type is built with unkeyed composite literals.
type request struct {
	// process receives a channel from processLoop when the request is served; it is
	// closed instead if the request is rejected.
	process chan chan struct{}
	// cost is the estimated serving cost of the request.
	cost uint
}
  60. // flatWeight distributes weights equally between each active network address
  61. func flatWeight(item interface{}) uint64 { return item.(*nodeQueue).flatWeight }
  62. // add adds the node queue to the address group. It is the caller's responsibility to
  63. // add the address group to the address map and the address selector if it wasn't
  64. // there before.
  65. func (ag *addressGroup) add(nq *nodeQueue) {
  66. if nq.groupIndex != -1 {
  67. panic("added node queue is already in an address group")
  68. }
  69. l := len(ag.nodes)
  70. nq.groupIndex = l
  71. ag.nodes = append(ag.nodes, nq)
  72. ag.sumFlatWeight += nq.flatWeight
  73. ag.groupWeight = ag.sumFlatWeight / uint64(l+1)
  74. ag.nodeSelect.Update(ag.nodes[l])
  75. }
  76. // update updates the selection weight of the node queue inside the address group.
  77. // It is the caller's responsibility to update the group's selection weight in the
  78. // address selector.
  79. func (ag *addressGroup) update(nq *nodeQueue, weight uint64) {
  80. if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq {
  81. panic("updated node queue is not in this address group")
  82. }
  83. ag.sumFlatWeight += weight - nq.flatWeight
  84. nq.flatWeight = weight
  85. ag.groupWeight = ag.sumFlatWeight / uint64(len(ag.nodes))
  86. ag.nodeSelect.Update(nq)
  87. }
  88. // remove removes the node queue from the address group. It is the caller's responsibility
  89. // to remove the address group from the address map if it is empty.
  90. func (ag *addressGroup) remove(nq *nodeQueue) {
  91. if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq {
  92. panic("removed node queue is not in this address group")
  93. }
  94. l := len(ag.nodes) - 1
  95. if nq.groupIndex != l {
  96. ag.nodes[nq.groupIndex] = ag.nodes[l]
  97. ag.nodes[nq.groupIndex].groupIndex = nq.groupIndex
  98. }
  99. nq.groupIndex = -1
  100. ag.nodes = ag.nodes[:l]
  101. ag.sumFlatWeight -= nq.flatWeight
  102. if l >= 1 {
  103. ag.groupWeight = ag.sumFlatWeight / uint64(l)
  104. } else {
  105. ag.groupWeight = 0
  106. }
  107. ag.nodeSelect.Remove(nq)
  108. }
  109. // choose selects one of the node queues belonging to the address group
  110. func (ag *addressGroup) choose() *nodeQueue {
  111. return ag.nodeSelect.Choose().(*nodeQueue)
  112. }
  113. // NewLimiter creates a new Limiter
  114. func NewLimiter(sumCostLimit uint) *Limiter {
  115. l := &Limiter{
  116. addressSelect: NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*addressGroup).groupWeight }),
  117. valueSelect: NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*nodeQueue).valueWeight }),
  118. nodes: make(map[enode.ID]*nodeQueue),
  119. addresses: make(map[string]*addressGroup),
  120. sumCostLimit: sumCostLimit,
  121. }
  122. l.cond = sync.NewCond(&l.lock)
  123. go l.processLoop()
  124. return l
  125. }
  126. // selectionWeights calculates the selection weights of a node for both the address and
  127. // the value selector. The selection weight depends on the next request cost or the
  128. // summed cost of recently dropped requests.
  129. func (l *Limiter) selectionWeights(reqCost uint, value float64) (flatWeight, valueWeight uint64) {
  130. if value > l.maxValue {
  131. l.maxValue = value
  132. }
  133. if value > 0 {
  134. // normalize value to <= 1
  135. value /= l.maxValue
  136. }
  137. if reqCost > l.maxCost {
  138. l.maxCost = reqCost
  139. }
  140. relCost := float64(reqCost) / float64(l.maxCost)
  141. var f float64
  142. if relCost <= 0.001 {
  143. f = 1
  144. } else {
  145. f = 0.001 / relCost
  146. }
  147. f *= maxSelectionWeight
  148. flatWeight, valueWeight = uint64(f), uint64(f*value)
  149. if flatWeight == 0 {
  150. flatWeight = 1
  151. }
  152. return
  153. }
  154. // Add adds a new request to the node queue belonging to the given id. Value belongs
  155. // to the requesting node. A higher value gives the request a higher chance of being
  156. // served quickly in case of heavy load or a DDoS attack. Cost is a rough estimate
  157. // of the serving cost of the request. A lower cost also gives the request a
  158. // better chance.
  159. func (l *Limiter) Add(id enode.ID, address string, value float64, reqCost uint) chan chan struct{} {
  160. l.lock.Lock()
  161. defer l.lock.Unlock()
  162. process := make(chan chan struct{}, 1)
  163. if l.quit {
  164. close(process)
  165. return process
  166. }
  167. if reqCost == 0 {
  168. reqCost = 1
  169. }
  170. if nq, ok := l.nodes[id]; ok {
  171. if nq.queue != nil {
  172. nq.queue = append(nq.queue, request{process, reqCost})
  173. nq.sumCost += reqCost
  174. nq.value = value
  175. if address != nq.address {
  176. // known id sending request from a new address, move to different address group
  177. l.removeFromGroup(nq)
  178. l.addToGroup(nq, address)
  179. }
  180. } else {
  181. // already waiting on a penalty, just add to the penalty cost and drop the request
  182. nq.penaltyCost += reqCost
  183. l.update(nq)
  184. close(process)
  185. return process
  186. }
  187. } else {
  188. nq := &nodeQueue{
  189. queue: []request{{process, reqCost}},
  190. id: id,
  191. value: value,
  192. sumCost: reqCost,
  193. groupIndex: -1,
  194. }
  195. nq.flatWeight, nq.valueWeight = l.selectionWeights(reqCost, value)
  196. if len(l.nodes) == 0 {
  197. l.cond.Signal()
  198. }
  199. l.nodes[id] = nq
  200. if nq.valueWeight != 0 {
  201. l.valueSelect.Update(nq)
  202. }
  203. l.addToGroup(nq, address)
  204. }
  205. l.sumCost += reqCost
  206. if l.sumCost > l.sumCostLimit {
  207. l.dropRequests()
  208. }
  209. return process
  210. }
  211. // update updates the selection weights of the node queue
  212. func (l *Limiter) update(nq *nodeQueue) {
  213. var cost uint
  214. if nq.queue != nil {
  215. cost = nq.queue[0].cost
  216. } else {
  217. cost = nq.penaltyCost
  218. }
  219. flatWeight, valueWeight := l.selectionWeights(cost, nq.value)
  220. ag := l.addresses[nq.address]
  221. ag.update(nq, flatWeight)
  222. l.addressSelect.Update(ag)
  223. nq.valueWeight = valueWeight
  224. l.valueSelect.Update(nq)
  225. }
  226. // addToGroup adds the node queue to the given address group. The group is created if
  227. // it does not exist yet.
  228. func (l *Limiter) addToGroup(nq *nodeQueue, address string) {
  229. nq.address = address
  230. ag := l.addresses[address]
  231. if ag == nil {
  232. ag = &addressGroup{nodeSelect: NewWeightedRandomSelect(flatWeight)}
  233. l.addresses[address] = ag
  234. }
  235. ag.add(nq)
  236. l.addressSelect.Update(ag)
  237. }
  238. // removeFromGroup removes the node queue from its address group
  239. func (l *Limiter) removeFromGroup(nq *nodeQueue) {
  240. ag := l.addresses[nq.address]
  241. ag.remove(nq)
  242. if len(ag.nodes) == 0 {
  243. delete(l.addresses, nq.address)
  244. }
  245. l.addressSelect.Update(ag)
  246. }
  247. // remove removes the node queue from its address group, the nodes map and the value
  248. // selector
  249. func (l *Limiter) remove(nq *nodeQueue) {
  250. l.removeFromGroup(nq)
  251. if nq.valueWeight != 0 {
  252. l.valueSelect.Remove(nq)
  253. }
  254. delete(l.nodes, nq.id)
  255. }
  256. // choose selects the next node queue to process.
  257. func (l *Limiter) choose() *nodeQueue {
  258. if l.valueSelect.IsEmpty() || l.selectAddressNext {
  259. if ag, ok := l.addressSelect.Choose().(*addressGroup); ok {
  260. l.selectAddressNext = false
  261. return ag.choose()
  262. }
  263. }
  264. nq, _ := l.valueSelect.Choose().(*nodeQueue)
  265. l.selectAddressNext = true
  266. return nq
  267. }
  268. // processLoop processes requests sequentially
  269. func (l *Limiter) processLoop() {
  270. l.lock.Lock()
  271. defer l.lock.Unlock()
  272. for {
  273. if l.quit {
  274. for _, nq := range l.nodes {
  275. for _, request := range nq.queue {
  276. close(request.process)
  277. }
  278. }
  279. return
  280. }
  281. nq := l.choose()
  282. if nq == nil {
  283. l.cond.Wait()
  284. continue
  285. }
  286. if nq.queue != nil {
  287. request := nq.queue[0]
  288. nq.queue = nq.queue[1:]
  289. nq.sumCost -= request.cost
  290. l.sumCost -= request.cost
  291. l.lock.Unlock()
  292. ch := make(chan struct{})
  293. request.process <- ch
  294. <-ch
  295. l.lock.Lock()
  296. if len(nq.queue) > 0 {
  297. l.update(nq)
  298. } else {
  299. l.remove(nq)
  300. }
  301. } else {
  302. // penalized queue removed, next request will be added to a clean queue
  303. l.remove(nq)
  304. }
  305. }
  306. }
  307. // Stop stops the processing loop. All queued and future requests are rejected.
  308. func (l *Limiter) Stop() {
  309. l.lock.Lock()
  310. defer l.lock.Unlock()
  311. l.quit = true
  312. l.cond.Signal()
  313. }
  314. type (
  315. dropList []dropListItem
  316. dropListItem struct {
  317. nq *nodeQueue
  318. priority float64
  319. }
  320. )
  321. func (l dropList) Len() int {
  322. return len(l)
  323. }
  324. func (l dropList) Less(i, j int) bool {
  325. return l[i].priority < l[j].priority
  326. }
  327. func (l dropList) Swap(i, j int) {
  328. l[i], l[j] = l[j], l[i]
  329. }
  330. // dropRequests selects the nodes with the highest queued request cost to selection
  331. // weight ratio and drops their queued request. The empty node queues stay in the
  332. // selectors with a low selection weight in order to penalize these nodes.
  333. func (l *Limiter) dropRequests() {
  334. var (
  335. sumValue float64
  336. list dropList
  337. )
  338. for _, nq := range l.nodes {
  339. sumValue += nq.value
  340. }
  341. for _, nq := range l.nodes {
  342. if nq.sumCost == 0 {
  343. continue
  344. }
  345. w := 1 / float64(len(l.addresses)*len(l.addresses[nq.address].nodes))
  346. if sumValue > 0 {
  347. w += nq.value / sumValue
  348. }
  349. list = append(list, dropListItem{
  350. nq: nq,
  351. priority: w / float64(nq.sumCost),
  352. })
  353. }
  354. sort.Sort(list)
  355. for _, item := range list {
  356. for _, request := range item.nq.queue {
  357. close(request.process)
  358. }
  359. // make the queue penalized; no more requests are accepted until the node is
  360. // selected based on the penalty cost which is the cumulative cost of all dropped
  361. // requests. This ensures that sending excess requests is always penalized
  362. // and incentivizes the sender to stop for a while if no replies are received.
  363. item.nq.queue = nil
  364. item.nq.penaltyCost = item.nq.sumCost
  365. l.sumCost -= item.nq.sumCost // penalty costs are not counted in sumCost
  366. item.nq.sumCost = 0
  367. l.update(item.nq)
  368. if l.sumCost <= l.sumCostLimit/2 {
  369. return
  370. }
  371. }
  372. }