dialer.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. package ws
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/tls"
  7. "fmt"
  8. "io"
  9. "net"
  10. "net/url"
  11. "strconv"
  12. "strings"
  13. "time"
  14. "github.com/gobwas/httphead"
  15. "github.com/gobwas/pool/pbufio"
  16. )
  17. // Constants used by Dialer.
  18. const (
  19. DefaultClientReadBufferSize = 4096
  20. DefaultClientWriteBufferSize = 4096
  21. )
// Handshake represents handshake result.
//
// Its fields are filled in from the server's response headers during
// Dialer.Upgrade(); each is left empty when the server did not send the
// corresponding header.
type Handshake struct {
	// Protocol is the subprotocol selected during handshake. When set, it is
	// one of the values from Dialer.Protocols; it is "" when the server did
	// not select a subprotocol.
	Protocol string

	// Extensions is the list of negotiated extensions, as parsed from the
	// server's Sec-WebSocket-Extensions response header. Only extensions
	// whose names were offered in Dialer.Extensions are accepted.
	Extensions []httphead.Option
}
  29. // Errors used by the websocket client.
  30. var (
  31. ErrHandshakeBadStatus = fmt.Errorf("unexpected http status")
  32. ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
  33. ErrHandshakeBadExtensions = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol)
  34. )
  35. // DefaultDialer is dialer that holds no options and is used by Dial function.
  36. var DefaultDialer Dialer
  37. // Dial is like Dialer{}.Dial().
  38. func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) {
  39. return DefaultDialer.Dial(ctx, urlstr)
  40. }
// Dialer contains options for establishing websocket connection to an url.
//
// The zero value is a usable Dialer with default settings (DefaultDialer is
// exactly that).
type Dialer struct {
	// ReadBufferSize and WriteBufferSize are the I/O buffer sizes used to
	// read and write HTTP data while upgrading to WebSocket. The buffers are
	// taken from and returned to internal pools (pbufio) to avoid extra
	// allocations.
	//
	// If a size is zero then the default value is used
	// (DefaultClientReadBufferSize / DefaultClientWriteBufferSize).
	ReadBufferSize, WriteBufferSize int

	// Timeout is the maximum amount of time a Dial() will wait for a connect
	// and a handshake to complete.
	//
	// The default (zero) is no timeout.
	Timeout time.Duration

	// Protocols is the list of subprotocols that the client wants to speak,
	// ordered by preference.
	//
	// See https://tools.ietf.org/html/rfc6455#section-4.1
	Protocols []string

	// Extensions is the list of extensions that client wants to speak.
	//
	// Note that if the server decides to use some of these extensions, Dial()
	// returns a Handshake struct whose Extensions hold the options as parsed
	// from the server's response. Only their names are matched against
	// this list.
	//
	// See https://tools.ietf.org/html/rfc6455#section-4.1
	// See https://tools.ietf.org/html/rfc6455#section-9.1
	Extensions []httphead.Option

	// Header is an optional HandshakeHeader instance that could be used to
	// write additional headers to the handshake request.
	//
	// It is used instead of key-value mappings to avoid allocations in user
	// land.
	Header HandshakeHeader

	// OnStatusError is the callback that will be called after receiving a
	// non-101 ("Switching Protocols") HTTP response status. It receives the
	// status code, the reason phrase, and an io.Reader yielding the status
	// line, a CRLF, and then the remaining response bytes read from the
	// connection. This lets the caller parse the HTTP response (e.g. with
	// http.ReadResponse) and decide on further logic. Dial() still returns a
	// StatusError in this case.
	//
	// The arguments are only valid until the callback returns.
	OnStatusError func(status int, reason []byte, resp io.Reader)

	// OnHeader is the callback that will be called for every successfully
	// parsed response header that is not used by the WebSocket handshake
	// procedure itself. That is, it is called with non-websocket headers,
	// which could be relevant for application-level logic.
	//
	// The arguments are only valid until the callback returns.
	//
	// If it returns a non-nil error, processing of the response stops and
	// that error is returned from the handshake.
	OnHeader func(key, value []byte) (err error)

	// NetDial is the function that is used to get a plain tcp connection.
	// If it is nil, the DialContext method of a zero net.Dialer is used.
	NetDial func(ctx context.Context, network, addr string) (net.Conn, error)

	// TLSClient is the callback that will be called after a successful dial
	// (for "wss" URLs) with the received connection and its remote host
	// name. If it is nil, then the default tls.Client() will be used.
	// If it is not nil, then the TLSConfig field is ignored.
	TLSClient func(conn net.Conn, hostname string) net.Conn

	// TLSConfig is passed to tls.Client() to start TLS over the established
	// connection. If TLSClient is not nil, then it is ignored. If TLSConfig
	// is nil, an empty config is used. If the config's ServerName is empty,
	// then for every Dial() it is cloned and the URL's host name is set as
	// ServerName.
	TLSConfig *tls.Config

	// WrapConn is the optional callback that will be called when the
	// connection is ready for I/O. That is, it will be called after a
	// successful dial and TLS initialization (for "wss" schemes). It may be
	// helpful for different user land purposes such as end to end
	// encryption.
	//
	// Note that for debugging purposes of an http handshake (e.g. sent
	// request and received response), there is a wsutil.DebugDialer struct.
	WrapConn func(conn net.Conn) net.Conn
}
  114. // Dial connects to the url host and upgrades connection to WebSocket.
  115. //
  116. // If server has sent frames right after successful handshake then returned
  117. // buffer will be non-nil. In other cases buffer is always nil. For better
  118. // memory efficiency received non-nil bufio.Reader should be returned to the
  119. // inner pool with PutReader() function after use.
  120. //
  121. // Note that Dialer does not implement IDNA (RFC5895) logic as net/http does.
  122. // If you want to dial non-ascii host name, take care of its name serialization
  123. // avoiding bad request issues. For more info see net/http Request.Write()
  124. // implementation, especially cleanHost() function.
  125. func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) {
  126. u, err := url.ParseRequestURI(urlstr)
  127. if err != nil {
  128. return
  129. }
  130. // Prepare context to dial with. Initially it is the same as original, but
  131. // if d.Timeout is non-zero and points to time that is before ctx.Deadline,
  132. // we use more shorter context for dial.
  133. dialctx := ctx
  134. var deadline time.Time
  135. if t := d.Timeout; t != 0 {
  136. deadline = time.Now().Add(t)
  137. if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
  138. var cancel context.CancelFunc
  139. dialctx, cancel = context.WithDeadline(ctx, deadline)
  140. defer cancel()
  141. }
  142. }
  143. if conn, err = d.dial(dialctx, u); err != nil {
  144. return
  145. }
  146. defer func() {
  147. if err != nil {
  148. conn.Close()
  149. }
  150. }()
  151. if ctx == context.Background() {
  152. // No need to start I/O interrupter goroutine which is not zero-cost.
  153. conn.SetDeadline(deadline)
  154. defer conn.SetDeadline(noDeadline)
  155. } else {
  156. // Context could be canceled or its deadline could be exceeded.
  157. // Start the interrupter goroutine to handle context cancelation.
  158. done := setupContextDeadliner(ctx, conn)
  159. defer func() {
  160. // Map Upgrade() error to a possible context expiration error. That
  161. // is, even if Upgrade() err is nil, context could be already
  162. // expired and connection be "poisoned" by SetDeadline() call.
  163. // In that case we must not return ctx.Err() error.
  164. done(&err)
  165. }()
  166. }
  167. br, hs, err = d.Upgrade(conn, u)
  168. return
  169. }
  170. var (
  171. // netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if
  172. // Dialer.NetDial is not provided.
  173. netEmptyDialer net.Dialer
  174. // tlsEmptyConfig is an empty tls.Config used as default one.
  175. tlsEmptyConfig tls.Config
  176. )
  177. func tlsDefaultConfig() *tls.Config {
  178. return &tlsEmptyConfig
  179. }
  180. func hostport(host string, defaultPort string) (hostname, addr string) {
  181. var (
  182. colon = strings.LastIndexByte(host, ':')
  183. bracket = strings.IndexByte(host, ']')
  184. )
  185. if colon > bracket {
  186. return host[:colon], host
  187. }
  188. return host, host + defaultPort
  189. }
  190. func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
  191. dial := d.NetDial
  192. if dial == nil {
  193. dial = netEmptyDialer.DialContext
  194. }
  195. switch u.Scheme {
  196. case "ws":
  197. _, addr := hostport(u.Host, ":80")
  198. conn, err = dial(ctx, "tcp", addr)
  199. case "wss":
  200. hostname, addr := hostport(u.Host, ":443")
  201. conn, err = dial(ctx, "tcp", addr)
  202. if err != nil {
  203. return
  204. }
  205. tlsClient := d.TLSClient
  206. if tlsClient == nil {
  207. tlsClient = d.tlsClient
  208. }
  209. conn = tlsClient(conn, hostname)
  210. default:
  211. return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme)
  212. }
  213. if wrap := d.WrapConn; wrap != nil {
  214. conn = wrap(conn)
  215. }
  216. return
  217. }
  218. func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn {
  219. config := d.TLSConfig
  220. if config == nil {
  221. config = tlsDefaultConfig()
  222. }
  223. if config.ServerName == "" {
  224. config = tlsCloneConfig(config)
  225. config.ServerName = hostname
  226. }
  227. // Do not make conn.Handshake() here because downstairs we will prepare
  228. // i/o on this conn with proper context's timeout handling.
  229. return tls.Client(conn, config)
  230. }
  231. var (
  232. // This variables are set like in net/net.go.
  233. // noDeadline is just zero value for readability.
  234. noDeadline = time.Time{}
  235. // aLongTimeAgo is a non-zero time, far in the past, used for immediate
  236. // cancelation of dials.
  237. aLongTimeAgo = time.Unix(42, 0)
  238. )
  239. // Upgrade writes an upgrade request to the given io.ReadWriter conn at given
  240. // url u and reads a response from it.
  241. //
  242. // It is a caller responsibility to manage I/O deadlines on conn.
  243. //
  244. // It returns handshake info and some bytes which could be written by the peer
  245. // right after response and be caught by us during buffered read.
  246. func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) {
  247. // headerSeen constants helps to report whether or not some header was seen
  248. // during reading request bytes.
  249. const (
  250. headerSeenUpgrade = 1 << iota
  251. headerSeenConnection
  252. headerSeenSecAccept
  253. // headerSeenAll is the value that we expect to receive at the end of
  254. // headers read/parse loop.
  255. headerSeenAll = 0 |
  256. headerSeenUpgrade |
  257. headerSeenConnection |
  258. headerSeenSecAccept
  259. )
  260. br = pbufio.GetReader(conn,
  261. nonZero(d.ReadBufferSize, DefaultClientReadBufferSize),
  262. )
  263. bw := pbufio.GetWriter(conn,
  264. nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize),
  265. )
  266. defer func() {
  267. pbufio.PutWriter(bw)
  268. if br.Buffered() == 0 || err != nil {
  269. // Server does not wrote additional bytes to the connection or
  270. // error occurred. That is, no reason to return buffer.
  271. pbufio.PutReader(br)
  272. br = nil
  273. }
  274. }()
  275. nonce := make([]byte, nonceSize)
  276. initNonce(nonce)
  277. httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header)
  278. if err = bw.Flush(); err != nil {
  279. return
  280. }
  281. // Read HTTP status line like "HTTP/1.1 101 Switching Protocols".
  282. sl, err := readLine(br)
  283. if err != nil {
  284. return
  285. }
  286. // Begin validation of the response.
  287. // See https://tools.ietf.org/html/rfc6455#section-4.2.2
  288. // Parse request line data like HTTP version, uri and method.
  289. resp, err := httpParseResponseLine(sl)
  290. if err != nil {
  291. return
  292. }
  293. // Even if RFC says "1.1 or higher" without mentioning the part of the
  294. // version, we apply it only to minor part.
  295. if resp.major != 1 || resp.minor < 1 {
  296. err = ErrHandshakeBadProtocol
  297. return
  298. }
  299. if resp.status != 101 {
  300. err = StatusError(resp.status)
  301. if onStatusError := d.OnStatusError; onStatusError != nil {
  302. // Invoke callback with multireader of status-line bytes br.
  303. onStatusError(resp.status, resp.reason,
  304. io.MultiReader(
  305. bytes.NewReader(sl),
  306. strings.NewReader(crlf),
  307. br,
  308. ),
  309. )
  310. }
  311. return
  312. }
  313. // If response status is 101 then we expect all technical headers to be
  314. // valid. If not, then we stop processing response without giving user
  315. // ability to read non-technical headers. That is, we do not distinguish
  316. // technical errors (such as parsing error) and protocol errors.
  317. var headerSeen byte
  318. for {
  319. line, e := readLine(br)
  320. if e != nil {
  321. err = e
  322. return
  323. }
  324. if len(line) == 0 {
  325. // Blank line, no more lines to read.
  326. break
  327. }
  328. k, v, ok := httpParseHeaderLine(line)
  329. if !ok {
  330. err = ErrMalformedResponse
  331. return
  332. }
  333. switch btsToString(k) {
  334. case headerUpgradeCanonical:
  335. headerSeen |= headerSeenUpgrade
  336. if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
  337. err = ErrHandshakeBadUpgrade
  338. return
  339. }
  340. case headerConnectionCanonical:
  341. headerSeen |= headerSeenConnection
  342. // Note that as RFC6455 says:
  343. // > A |Connection| header field with value "Upgrade".
  344. // That is, in server side, "Connection" header could contain
  345. // multiple token. But in response it must contains exactly one.
  346. if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
  347. err = ErrHandshakeBadConnection
  348. return
  349. }
  350. case headerSecAcceptCanonical:
  351. headerSeen |= headerSeenSecAccept
  352. if !checkAcceptFromNonce(v, nonce) {
  353. err = ErrHandshakeBadSecAccept
  354. return
  355. }
  356. case headerSecProtocolCanonical:
  357. // RFC6455 1.3:
  358. // "The server selects one or none of the acceptable protocols
  359. // and echoes that value in its handshake to indicate that it has
  360. // selected that protocol."
  361. for _, want := range d.Protocols {
  362. if string(v) == want {
  363. hs.Protocol = want
  364. break
  365. }
  366. }
  367. if hs.Protocol == "" {
  368. // Server echoed subprotocol that is not present in client
  369. // requested protocols.
  370. err = ErrHandshakeBadSubProtocol
  371. return
  372. }
  373. case headerSecExtensionsCanonical:
  374. hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
  375. if err != nil {
  376. return
  377. }
  378. default:
  379. if onHeader := d.OnHeader; onHeader != nil {
  380. if e := onHeader(k, v); e != nil {
  381. err = e
  382. return
  383. }
  384. }
  385. }
  386. }
  387. if err == nil && headerSeen != headerSeenAll {
  388. switch {
  389. case headerSeen&headerSeenUpgrade == 0:
  390. err = ErrHandshakeBadUpgrade
  391. case headerSeen&headerSeenConnection == 0:
  392. err = ErrHandshakeBadConnection
  393. case headerSeen&headerSeenSecAccept == 0:
  394. err = ErrHandshakeBadSecAccept
  395. default:
  396. panic("unknown headers state")
  397. }
  398. }
  399. return
  400. }
  401. // PutReader returns bufio.Reader instance to the inner reuse pool.
  402. // It is useful in rare cases, when Dialer.Dial() returns non-nil buffer which
  403. // contains unprocessed buffered data, that was sent by the server quickly
  404. // right after handshake.
  405. func PutReader(br *bufio.Reader) {
  406. pbufio.PutReader(br)
  407. }
  408. // StatusError contains an unexpected status-line code from the server.
  409. type StatusError int
  410. func (s StatusError) Error() string {
  411. return "unexpected HTTP response status: " + strconv.Itoa(int(s))
  412. }
  413. func isTimeoutError(err error) bool {
  414. t, ok := err.(net.Error)
  415. return ok && t.Timeout()
  416. }
  417. func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) {
  418. if len(selected) == 0 {
  419. return received, nil
  420. }
  421. var (
  422. index int
  423. option httphead.Option
  424. err error
  425. )
  426. index = -1
  427. match := func() (ok bool) {
  428. for _, want := range wanted {
  429. // A server accepts one or more extensions by including a
  430. // |Sec-WebSocket-Extensions| header field containing one or more
  431. // extensions that were requested by the client.
  432. //
  433. // The interpretation of any extension parameters, and what
  434. // constitutes a valid response by a server to a requested set of
  435. // parameters by a client, will be defined by each such extension.
  436. if bytes.Equal(option.Name, want.Name) {
  437. // Check parsed extension to be present in client
  438. // requested extensions. We move matched extension
  439. // from client list to avoid allocation.
  440. received = append(received, option)
  441. return true
  442. }
  443. }
  444. return false
  445. }
  446. ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control {
  447. if i != index {
  448. // Met next option.
  449. index = i
  450. if i != 0 && !match() {
  451. // Server returned non-requested extension.
  452. err = ErrHandshakeBadExtensions
  453. return httphead.ControlBreak
  454. }
  455. option = httphead.Option{Name: name}
  456. }
  457. if attr != nil {
  458. option.Parameters.Set(attr, val)
  459. }
  460. return httphead.ControlContinue
  461. })
  462. if !ok {
  463. err = ErrMalformedResponse
  464. return received, err
  465. }
  466. if !match() {
  467. return received, ErrHandshakeBadExtensions
  468. }
  469. return received, err
  470. }
  471. // setupContextDeadliner is a helper function that starts connection I/O
  472. // interrupter goroutine.
  473. //
  474. // Started goroutine calls SetDeadline() with long time ago value when context
  475. // become expired to make any I/O operations failed. It returns done function
  476. // that stops started goroutine and maps error received from conn I/O methods
  477. // to possible context expiration error.
  478. //
  479. // In concern with possible SetDeadline() call inside interrupter goroutine,
  480. // caller passes pointer to its I/O error (even if it is nil) to done(&err).
  481. // That is, even if I/O error is nil, context could be already expired and
  482. // connection "poisoned" by SetDeadline() call. In that case done(&err) will
  483. // store at *err ctx.Err() result. If err is caused not by timeout, it will
  484. // leaved untouched.
  485. func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) {
  486. var (
  487. quit = make(chan struct{})
  488. interrupt = make(chan error, 1)
  489. )
  490. go func() {
  491. select {
  492. case <-quit:
  493. interrupt <- nil
  494. case <-ctx.Done():
  495. // Cancel i/o immediately.
  496. conn.SetDeadline(aLongTimeAgo)
  497. interrupt <- ctx.Err()
  498. }
  499. }()
  500. return func(err *error) {
  501. close(quit)
  502. // If ctx.Err() is non-nil and the original err is net.Error with
  503. // Timeout() == true, then it means that I/O was canceled by us by
  504. // SetDeadline(aLongTimeAgo) call, or by somebody else previously
  505. // by conn.SetDeadline(x).
  506. //
  507. // Even on race condition when both deadlines are expired
  508. // (SetDeadline() made not by us and context's), we prefer ctx.Err() to
  509. // be returned.
  510. if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) {
  511. *err = ctxErr
  512. }
  513. }
  514. }