table_util_test.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. // Copyright 2018 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package discover
  17. import (
  18. "bytes"
  19. "crypto/ecdsa"
  20. "encoding/hex"
  21. "errors"
  22. "fmt"
  23. "math/rand"
  24. "net"
  25. "sort"
  26. "sync"
  27. "github.com/ethereum/go-ethereum/crypto"
  28. "github.com/ethereum/go-ethereum/log"
  29. "github.com/ethereum/go-ethereum/p2p/enode"
  30. "github.com/ethereum/go-ethereum/p2p/enr"
  31. )
  32. var nullNode *enode.Node
  33. func init() {
  34. var r enr.Record
  35. r.Set(enr.IP{0, 0, 0, 0})
  36. nullNode = enode.SignNull(&r, enode.ID{})
  37. }
  38. func newTestTable(t transport) (*Table, *enode.DB) {
  39. db, _ := enode.OpenDB("")
  40. tab, _ := newTable(t, db, nil, log.Root())
  41. go tab.loop()
  42. return tab, db
  43. }
  44. // nodeAtDistance creates a node for which enode.LogDist(base, n.id) == ld.
  45. func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node {
  46. var r enr.Record
  47. r.Set(enr.IP(ip))
  48. return wrapNode(enode.SignNull(&r, idAtDistance(base, ld)))
  49. }
  50. // nodesAtDistance creates n nodes for which enode.LogDist(base, node.ID()) == ld.
  51. func nodesAtDistance(base enode.ID, ld int, n int) []*enode.Node {
  52. results := make([]*enode.Node, n)
  53. for i := range results {
  54. results[i] = unwrapNode(nodeAtDistance(base, ld, intIP(i)))
  55. }
  56. return results
  57. }
  58. func nodesToRecords(nodes []*enode.Node) []*enr.Record {
  59. records := make([]*enr.Record, len(nodes))
  60. for i := range nodes {
  61. records[i] = nodes[i].Record()
  62. }
  63. return records
  64. }
  65. // idAtDistance returns a random hash such that enode.LogDist(a, b) == n
  66. func idAtDistance(a enode.ID, n int) (b enode.ID) {
  67. if n == 0 {
  68. return a
  69. }
  70. // flip bit at position n, fill the rest with random bits
  71. b = a
  72. pos := len(a) - n/8 - 1
  73. bit := byte(0x01) << (byte(n%8) - 1)
  74. if bit == 0 {
  75. pos++
  76. bit = 0x80
  77. }
  78. b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
  79. for i := pos + 1; i < len(a); i++ {
  80. b[i] = byte(rand.Intn(255))
  81. }
  82. return b
  83. }
  // intIP derives a 4-byte IP address from i as b.0.2.b, where b is the low
  // byte of i.
  func intIP(i int) net.IP {
  	b := byte(i)
  	return net.IP{b, 0, 2, b}
  }
  87. // fillBucket inserts nodes into the given bucket until it is full.
  88. func fillBucket(tab *Table, n *node) (last *node) {
  89. ld := enode.LogDist(tab.self().ID(), n.ID())
  90. b := tab.bucket(n.ID())
  91. for len(b.entries) < bucketSize {
  92. b.entries = append(b.entries, nodeAtDistance(tab.self().ID(), ld, intIP(ld)))
  93. }
  94. return b.entries[bucketSize-1]
  95. }
  96. // fillTable adds nodes the table to the end of their corresponding bucket
  97. // if the bucket is not full. The caller must not hold tab.mutex.
  98. func fillTable(tab *Table, nodes []*node) {
  99. for _, n := range nodes {
  100. tab.addSeenNode(n)
  101. }
  102. }
  103. type pingRecorder struct {
  104. mu sync.Mutex
  105. dead, pinged map[enode.ID]bool
  106. records map[enode.ID]*enode.Node
  107. n *enode.Node
  108. }
  109. func newPingRecorder() *pingRecorder {
  110. var r enr.Record
  111. r.Set(enr.IP{0, 0, 0, 0})
  112. n := enode.SignNull(&r, enode.ID{})
  113. return &pingRecorder{
  114. dead: make(map[enode.ID]bool),
  115. pinged: make(map[enode.ID]bool),
  116. records: make(map[enode.ID]*enode.Node),
  117. n: n,
  118. }
  119. }
  120. // setRecord updates a node record. Future calls to ping and
  121. // requestENR will return this record.
  122. func (t *pingRecorder) updateRecord(n *enode.Node) {
  123. t.mu.Lock()
  124. defer t.mu.Unlock()
  125. t.records[n.ID()] = n
  126. }
  127. // Stubs to satisfy the transport interface.
  128. func (t *pingRecorder) Self() *enode.Node { return nullNode }
  129. func (t *pingRecorder) lookupSelf() []*enode.Node { return nil }
  130. func (t *pingRecorder) lookupRandom() []*enode.Node { return nil }
  131. // ping simulates a ping request.
  132. func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) {
  133. t.mu.Lock()
  134. defer t.mu.Unlock()
  135. t.pinged[n.ID()] = true
  136. if t.dead[n.ID()] {
  137. return 0, errTimeout
  138. }
  139. if t.records[n.ID()] != nil {
  140. seq = t.records[n.ID()].Seq()
  141. }
  142. return seq, nil
  143. }
  144. // requestENR simulates an ENR request.
  145. func (t *pingRecorder) RequestENR(n *enode.Node) (*enode.Node, error) {
  146. t.mu.Lock()
  147. defer t.mu.Unlock()
  148. if t.dead[n.ID()] || t.records[n.ID()] == nil {
  149. return nil, errTimeout
  150. }
  151. return t.records[n.ID()], nil
  152. }
  153. func hasDuplicates(slice []*node) bool {
  154. seen := make(map[enode.ID]bool)
  155. for i, e := range slice {
  156. if e == nil {
  157. panic(fmt.Sprintf("nil *Node at %d", i))
  158. }
  159. if seen[e.ID()] {
  160. return true
  161. }
  162. seen[e.ID()] = true
  163. }
  164. return false
  165. }
  166. // checkNodesEqual checks whether the two given node lists contain the same nodes.
  167. func checkNodesEqual(got, want []*enode.Node) error {
  168. if len(got) == len(want) {
  169. for i := range got {
  170. if !nodeEqual(got[i], want[i]) {
  171. goto NotEqual
  172. }
  173. }
  174. }
  175. return nil
  176. NotEqual:
  177. output := new(bytes.Buffer)
  178. fmt.Fprintf(output, "got %d nodes:\n", len(got))
  179. for _, n := range got {
  180. fmt.Fprintf(output, " %v %v\n", n.ID(), n)
  181. }
  182. fmt.Fprintf(output, "want %d:\n", len(want))
  183. for _, n := range want {
  184. fmt.Fprintf(output, " %v %v\n", n.ID(), n)
  185. }
  186. return errors.New(output.String())
  187. }
  188. func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool {
  189. return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP())
  190. }
  191. func sortByID(nodes []*enode.Node) {
  192. sort.Slice(nodes, func(i, j int) bool {
  193. return string(nodes[i].ID().Bytes()) < string(nodes[j].ID().Bytes())
  194. })
  195. }
  196. func sortedByDistanceTo(distbase enode.ID, slice []*node) bool {
  197. return sort.SliceIsSorted(slice, func(i, j int) bool {
  198. return enode.DistCmp(distbase, slice[i].ID(), slice[j].ID()) < 0
  199. })
  200. }
  201. // hexEncPrivkey decodes h as a private key.
  202. func hexEncPrivkey(h string) *ecdsa.PrivateKey {
  203. b, err := hex.DecodeString(h)
  204. if err != nil {
  205. panic(err)
  206. }
  207. key, err := crypto.ToECDSA(b)
  208. if err != nil {
  209. panic(err)
  210. }
  211. return key
  212. }
  213. // hexEncPubkey decodes h as a public key.
  214. func hexEncPubkey(h string) (ret encPubkey) {
  215. b, err := hex.DecodeString(h)
  216. if err != nil {
  217. panic(err)
  218. }
  219. if len(b) != len(ret) {
  220. panic("invalid length")
  221. }
  222. copy(ret[:], b)
  223. return ret
  224. }