dialer.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. package wsutil
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "io"
  7. "io/ioutil"
  8. "net"
  9. "net/http"
  10. "github.com/gobwas/ws"
  11. )
  12. // DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket
  13. // handshake. That is, it gives ability to receive copied HTTP request and
  14. // response bytes that made inside Dialer.Dial().
  15. //
  16. // Note that it must not be used in production applications that requires
  17. // Dial() to be efficient.
  18. type DebugDialer struct {
  19. // Dialer contains WebSocket connection establishment options.
  20. Dialer ws.Dialer
  21. // OnRequest and OnResponse are the callbacks that will be called with the
  22. // HTTP request and response respectively.
  23. OnRequest, OnResponse func([]byte)
  24. }
  25. // Dial connects to the url host and upgrades connection to WebSocket. It makes
  26. // it by calling d.Dialer.Dial().
  27. func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) {
  28. // Need to copy Dialer to prevent original object mutation.
  29. dialer := d.Dialer
  30. var (
  31. reqBuf bytes.Buffer
  32. resBuf bytes.Buffer
  33. resContentLength int64
  34. )
  35. userWrap := dialer.WrapConn
  36. dialer.WrapConn = func(c net.Conn) net.Conn {
  37. if userWrap != nil {
  38. c = userWrap(c)
  39. }
  40. // Save the pointer to the raw connection.
  41. conn = c
  42. var (
  43. r io.Reader = conn
  44. w io.Writer = conn
  45. )
  46. if d.OnResponse != nil {
  47. r = &prefetchResponseReader{
  48. source: conn,
  49. buffer: &resBuf,
  50. contentLength: &resContentLength,
  51. }
  52. }
  53. if d.OnRequest != nil {
  54. w = io.MultiWriter(conn, &reqBuf)
  55. }
  56. return rwConn{conn, r, w}
  57. }
  58. _, br, hs, err = dialer.Dial(ctx, urlstr)
  59. if onRequest := d.OnRequest; onRequest != nil {
  60. onRequest(reqBuf.Bytes())
  61. }
  62. if onResponse := d.OnResponse; onResponse != nil {
  63. // We must split response inside buffered bytes from other received
  64. // bytes from server.
  65. p := resBuf.Bytes()
  66. n := bytes.Index(p, headEnd)
  67. h := n + len(headEnd) // Head end index.
  68. n = h + int(resContentLength) // Body end index.
  69. onResponse(p[:n])
  70. if br != nil {
  71. // If br is non-nil, then it mean two things. First is that
  72. // handshake is OK and server has sent additional bytes – probably
  73. // immediate sent frames (or weird but possible response body).
  74. // Second, the bad one, is that br buffer's source is now rwConn
  75. // instance from above WrapConn call. It is incorrect, so we must
  76. // fix it.
  77. var r io.Reader = conn
  78. if len(p) > h {
  79. // Buffer contains more than just HTTP headers bytes.
  80. r = io.MultiReader(
  81. bytes.NewReader(p[h:]),
  82. conn,
  83. )
  84. }
  85. br.Reset(r)
  86. // Must make br.Buffered() to be non-zero.
  87. br.Peek(len(p[h:]))
  88. }
  89. }
  90. return conn, br, hs, err
  91. }
  // rwConn is a net.Conn whose Read and Write are redirected to r and w. All
// other methods (Close, deadlines, addresses) come from the embedded Conn.
type rwConn struct {
	net.Conn

	// r is the source for Read.
	r io.Reader
	// w is the destination for Write.
	w io.Writer
}
  97. func (rwc rwConn) Read(p []byte) (int, error) {
  98. return rwc.r.Read(p)
  99. }
  100. func (rwc rwConn) Write(p []byte) (int, error) {
  101. return rwc.w.Write(p)
  102. }
  // headEnd is the CRLF CRLF sequence that terminates an HTTP message head.
var headEnd = []byte{'\r', '\n', '\r', '\n'}
  // prefetchResponseReader is an io.Reader that, on its first Read, consumes a
// whole HTTP response from source while recording the raw bytes. From then
// on it replays those bytes followed by the remainder of source.
type prefetchResponseReader struct {
	// source is the original connection source. reader is the replaying
	// reader handed to clients; it stays nil until the first Read.
	source, reader io.Reader

	// buffer receives a copy of every byte read from source while
	// prefetching.
	buffer *bytes.Buffer
	// contentLength receives the number of response body bytes read while
	// prefetching.
	contentLength *int64
}
  110. func (r *prefetchResponseReader) Read(p []byte) (int, error) {
  111. if r.reader == nil {
  112. resp, err := http.ReadResponse(bufio.NewReader(
  113. io.TeeReader(r.source, r.buffer),
  114. ), nil)
  115. if err == nil {
  116. *r.contentLength, _ = io.Copy(ioutil.Discard, resp.Body)
  117. resp.Body.Close()
  118. }
  119. bts := r.buffer.Bytes()
  120. r.reader = io.MultiReader(
  121. bytes.NewReader(bts),
  122. r.source,
  123. )
  124. }
  125. return r.reader.Read(p)
  126. }