framework.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. // Copyright 2020 The go-ethereum Authors
  2. // This file is part of go-ethereum.
  3. //
  4. // go-ethereum is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // go-ethereum is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU General Public License
  15. // along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
  16. package v5test
  17. import (
  18. "bytes"
  19. "crypto/ecdsa"
  20. "encoding/binary"
  21. "fmt"
  22. "net"
  23. "time"
  24. "github.com/ethereum/go-ethereum/common/mclock"
  25. "github.com/ethereum/go-ethereum/crypto"
  26. "github.com/ethereum/go-ethereum/p2p/discover/v5wire"
  27. "github.com/ethereum/go-ethereum/p2p/enode"
  28. "github.com/ethereum/go-ethereum/p2p/enr"
  29. )
  // readError wraps an error that occurred while receiving a packet.
//
// conn.read returns values of this type in place of a decoded v5wire.Packet,
// so callers can handle both packets and failures with a single type switch
// on the read result.
type readError struct {
	err error
}

// Kind returns the fixed packet kind (99) used for read errors.
func (e *readError) Kind() byte { return 99 }

// Name describes the wrapped error, prefixed with "error: ".
func (e *readError) Name() string { return fmt.Sprintf("error: %v", e.err) }

// Error returns the text of the wrapped error.
func (e *readError) Error() string { return e.err.Error() }

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (e *readError) Unwrap() error { return e.err }

// RequestID always returns nil: a read error carries no request ID.
func (e *readError) RequestID() []byte { return nil }

// SetRequestID is a no-op; a read error has no request ID to set.
func (e *readError) SetRequestID([]byte) {}

// readErrorf builds a readError from a fmt-style format string and arguments.
// The message is produced by fmt.Errorf, so %w may be used to wrap an error.
func readErrorf(format string, args ...interface{}) *readError {
	wrapped := fmt.Errorf(format, args...)
	return &readError{err: wrapped}
}
  // waitTime is the response timeout used by the tests. conn.read uses it as
// the read deadline for each incoming packet.
const waitTime time.Duration = 300 * time.Millisecond
  47. // conn is a connection to the node under test.
  48. type conn struct {
  49. localNode *enode.LocalNode
  50. localKey *ecdsa.PrivateKey
  51. remote *enode.Node
  52. remoteAddr *net.UDPAddr
  53. listeners []net.PacketConn
  54. log logger
  55. codec *v5wire.Codec
  56. lastRequest v5wire.Packet
  57. lastChallenge *v5wire.Whoareyou
  58. idCounter uint32
  59. }
  // logger is the minimal logging dependency of conn. *testing.T satisfies it.
type logger interface {
	Logf(format string, args ...interface{})
}
  63. // newConn sets up a connection to the given node.
  64. func newConn(dest *enode.Node, log logger) *conn {
  65. key, err := crypto.GenerateKey()
  66. if err != nil {
  67. panic(err)
  68. }
  69. db, err := enode.OpenDB("")
  70. if err != nil {
  71. panic(err)
  72. }
  73. ln := enode.NewLocalNode(db, key)
  74. return &conn{
  75. localKey: key,
  76. localNode: ln,
  77. remote: dest,
  78. remoteAddr: &net.UDPAddr{IP: dest.IP(), Port: dest.UDP()},
  79. codec: v5wire.NewCodec(ln, key, mclock.System{}),
  80. log: log,
  81. }
  82. }
  83. func (tc *conn) setEndpoint(c net.PacketConn) {
  84. tc.localNode.SetStaticIP(laddr(c).IP)
  85. tc.localNode.SetFallbackUDP(laddr(c).Port)
  86. }
  87. func (tc *conn) listen(ip string) net.PacketConn {
  88. l, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", ip))
  89. if err != nil {
  90. panic(err)
  91. }
  92. tc.listeners = append(tc.listeners, l)
  93. return l
  94. }
  95. // close shuts down all listeners and the local node.
  96. func (tc *conn) close() {
  97. for _, l := range tc.listeners {
  98. l.Close()
  99. }
  100. tc.localNode.Database().Close()
  101. }
  102. // nextReqID creates a request id.
  103. func (tc *conn) nextReqID() []byte {
  104. id := make([]byte, 4)
  105. tc.idCounter++
  106. binary.BigEndian.PutUint32(id, tc.idCounter)
  107. return id
  108. }
  109. // reqresp performs a request/response interaction on the given connection.
  110. // The request is retried if a handshake is requested.
  111. func (tc *conn) reqresp(c net.PacketConn, req v5wire.Packet) v5wire.Packet {
  112. reqnonce := tc.write(c, req, nil)
  113. switch resp := tc.read(c).(type) {
  114. case *v5wire.Whoareyou:
  115. if resp.Nonce != reqnonce {
  116. return readErrorf("wrong nonce %x in WHOAREYOU (want %x)", resp.Nonce[:], reqnonce[:])
  117. }
  118. resp.Node = tc.remote
  119. tc.write(c, req, resp)
  120. return tc.read(c)
  121. default:
  122. return resp
  123. }
  124. }
  125. // findnode sends a FINDNODE request and waits for its responses.
  126. func (tc *conn) findnode(c net.PacketConn, dists []uint) ([]*enode.Node, error) {
  127. var (
  128. findnode = &v5wire.Findnode{ReqID: tc.nextReqID(), Distances: dists}
  129. reqnonce = tc.write(c, findnode, nil)
  130. first = true
  131. total uint8
  132. results []*enode.Node
  133. )
  134. for n := 1; n > 0; {
  135. switch resp := tc.read(c).(type) {
  136. case *v5wire.Whoareyou:
  137. // Handle handshake.
  138. if resp.Nonce == reqnonce {
  139. resp.Node = tc.remote
  140. tc.write(c, findnode, resp)
  141. } else {
  142. return nil, fmt.Errorf("unexpected WHOAREYOU (nonce %x), waiting for NODES", resp.Nonce[:])
  143. }
  144. case *v5wire.Ping:
  145. // Handle ping from remote.
  146. tc.write(c, &v5wire.Pong{
  147. ReqID: resp.ReqID,
  148. ENRSeq: tc.localNode.Seq(),
  149. }, nil)
  150. case *v5wire.Nodes:
  151. // Got NODES! Check request ID.
  152. if !bytes.Equal(resp.ReqID, findnode.ReqID) {
  153. return nil, fmt.Errorf("NODES response has wrong request id %x", resp.ReqID)
  154. }
  155. // Check total count. It should be greater than one
  156. // and needs to be the same across all responses.
  157. if first {
  158. if resp.Total == 0 || resp.Total > 6 {
  159. return nil, fmt.Errorf("invalid NODES response 'total' %d (not in (0,7))", resp.Total)
  160. }
  161. total = resp.Total
  162. n = int(total) - 1
  163. first = false
  164. } else {
  165. n--
  166. if resp.Total != total {
  167. return nil, fmt.Errorf("invalid NODES response 'total' %d (!= %d)", resp.Total, total)
  168. }
  169. }
  170. // Check nodes.
  171. nodes, err := checkRecords(resp.Nodes)
  172. if err != nil {
  173. return nil, fmt.Errorf("invalid node in NODES response: %v", err)
  174. }
  175. results = append(results, nodes...)
  176. default:
  177. return nil, fmt.Errorf("expected NODES, got %v", resp)
  178. }
  179. }
  180. return results, nil
  181. }
  182. // write sends a packet on the given connection.
  183. func (tc *conn) write(c net.PacketConn, p v5wire.Packet, challenge *v5wire.Whoareyou) v5wire.Nonce {
  184. packet, nonce, err := tc.codec.Encode(tc.remote.ID(), tc.remoteAddr.String(), p, challenge)
  185. if err != nil {
  186. panic(fmt.Errorf("can't encode %v packet: %v", p.Name(), err))
  187. }
  188. if _, err := c.WriteTo(packet, tc.remoteAddr); err != nil {
  189. tc.logf("Can't send %s: %v", p.Name(), err)
  190. } else {
  191. tc.logf(">> %s", p.Name())
  192. }
  193. return nonce
  194. }
  195. // read waits for an incoming packet on the given connection.
  196. func (tc *conn) read(c net.PacketConn) v5wire.Packet {
  197. buf := make([]byte, 1280)
  198. if err := c.SetReadDeadline(time.Now().Add(waitTime)); err != nil {
  199. return &readError{err}
  200. }
  201. n, fromAddr, err := c.ReadFrom(buf)
  202. if err != nil {
  203. return &readError{err}
  204. }
  205. _, _, p, err := tc.codec.Decode(buf[:n], fromAddr.String())
  206. if err != nil {
  207. return &readError{err}
  208. }
  209. tc.logf("<< %s", p.Name())
  210. return p
  211. }
  212. // logf prints to the test log.
  213. func (tc *conn) logf(format string, args ...interface{}) {
  214. if tc.log != nil {
  215. tc.log.Logf("(%s) %s", tc.localNode.ID().TerminalString(), fmt.Sprintf(format, args...))
  216. }
  217. }
  // laddr returns the local address of c as a *net.UDPAddr. It panics if c's
// local address is not a *net.UDPAddr.
func laddr(c net.PacketConn) *net.UDPAddr {
	addr := c.LocalAddr()
	return addr.(*net.UDPAddr)
}
  221. func checkRecords(records []*enr.Record) ([]*enode.Node, error) {
  222. nodes := make([]*enode.Node, len(records))
  223. for i := range records {
  224. n, err := enode.New(enode.ValidSchemes, records[i])
  225. if err != nil {
  226. return nil, err
  227. }
  228. nodes[i] = n
  229. }
  230. return nodes, nil
  231. }
  // containsUint reports whether x occurs anywhere in ints.
func containsUint(ints []uint, x uint) bool {
	for _, v := range ints {
		if v == x {
			return true
		}
	}
	return false
}