123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- package wsutil
- import (
- "bufio"
- "bytes"
- "context"
- "io"
- "io/ioutil"
- "net"
- "net/http"
- "github.com/gobwas/ws"
- )
- // DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket
- // handshake. That is, it gives ability to receive copied HTTP request and
- // response bytes that made inside Dialer.Dial().
- //
- // Note that it must not be used in production applications that requires
- // Dial() to be efficient.
- type DebugDialer struct {
- // Dialer contains WebSocket connection establishment options.
- Dialer ws.Dialer
- // OnRequest and OnResponse are the callbacks that will be called with the
- // HTTP request and response respectively.
- OnRequest, OnResponse func([]byte)
- }
- // Dial connects to the url host and upgrades connection to WebSocket. It makes
- // it by calling d.Dialer.Dial().
- func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) {
- // Need to copy Dialer to prevent original object mutation.
- dialer := d.Dialer
- var (
- reqBuf bytes.Buffer
- resBuf bytes.Buffer
- resContentLength int64
- )
- userWrap := dialer.WrapConn
- dialer.WrapConn = func(c net.Conn) net.Conn {
- if userWrap != nil {
- c = userWrap(c)
- }
- // Save the pointer to the raw connection.
- conn = c
- var (
- r io.Reader = conn
- w io.Writer = conn
- )
- if d.OnResponse != nil {
- r = &prefetchResponseReader{
- source: conn,
- buffer: &resBuf,
- contentLength: &resContentLength,
- }
- }
- if d.OnRequest != nil {
- w = io.MultiWriter(conn, &reqBuf)
- }
- return rwConn{conn, r, w}
- }
- _, br, hs, err = dialer.Dial(ctx, urlstr)
- if onRequest := d.OnRequest; onRequest != nil {
- onRequest(reqBuf.Bytes())
- }
- if onResponse := d.OnResponse; onResponse != nil {
- // We must split response inside buffered bytes from other received
- // bytes from server.
- p := resBuf.Bytes()
- n := bytes.Index(p, headEnd)
- h := n + len(headEnd) // Head end index.
- n = h + int(resContentLength) // Body end index.
- onResponse(p[:n])
- if br != nil {
- // If br is non-nil, then it mean two things. First is that
- // handshake is OK and server has sent additional bytes – probably
- // immediate sent frames (or weird but possible response body).
- // Second, the bad one, is that br buffer's source is now rwConn
- // instance from above WrapConn call. It is incorrect, so we must
- // fix it.
- var r io.Reader = conn
- if len(p) > h {
- // Buffer contains more than just HTTP headers bytes.
- r = io.MultiReader(
- bytes.NewReader(p[h:]),
- conn,
- )
- }
- br.Reset(r)
- // Must make br.Buffered() to be non-zero.
- br.Peek(len(p[h:]))
- }
- }
- return conn, br, hs, err
- }
// rwConn is a net.Conn whose Read and Write calls are redirected to r and w.
// All other methods (Close, deadlines, addresses) are promoted from the
// embedded net.Conn.
type rwConn struct {
	net.Conn
	r io.Reader
	w io.Writer
}

// Read reads from the overriding reader rather than the embedded connection.
func (c rwConn) Read(b []byte) (n int, err error) {
	n, err = c.r.Read(b)
	return n, err
}

// Write writes to the overriding writer rather than the embedded connection.
func (c rwConn) Write(b []byte) (n int, err error) {
	n, err = c.w.Write(b)
	return n, err
}
// headEnd is the CRLF CRLF sequence that terminates an HTTP message head
// (status line plus header fields).
var headEnd = []byte{'\r', '\n', '\r', '\n'}
// prefetchResponseReader is an io.Reader that, on its first Read, parses one
// HTTP response from source with http.ReadResponse. Every byte pulled from
// source during that parse is copied into buffer. After that, it serves the
// captured bytes followed by whatever remains in source, so its consumer sees
// the original, unmodified byte stream.
type prefetchResponseReader struct {
	source        io.Reader     // Underlying connection to read from.
	reader        io.Reader     // Captured bytes, then source; nil until the first Read.
	buffer        *bytes.Buffer // Receives a copy of every byte read from source while prefetching.
	contentLength *int64        // Set to the number of response body bytes drained, if parsing succeeds.
}

// Read implements io.Reader. The first call performs the prefetch.
func (r *prefetchResponseReader) Read(p []byte) (int, error) {
	if r.reader == nil {
		r.prefetch()
	}
	return r.reader.Read(p)
}

// prefetch parses a response from source while teeing it into buffer, then
// installs a reader that replays the captured bytes before continuing with
// source. A parse failure is not reported here: the consumer will re-read the
// same bytes and hit the error itself.
func (r *prefetchResponseReader) prefetch() {
	tee := io.TeeReader(r.source, r.buffer)
	resp, err := http.ReadResponse(bufio.NewReader(tee), nil)
	if err == nil {
		// A drain error is ignored; whatever was drained is still counted.
		n, _ := io.Copy(ioutil.Discard, resp.Body)
		*r.contentLength = n
		resp.Body.Close()
	}
	captured := r.buffer.Bytes()
	r.reader = io.MultiReader(bytes.NewReader(captured), r.source)
}
|