server.go 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
  1. package ws
  2. import (
  3. "bufio"
  4. "bytes"
  5. "fmt"
  6. "io"
  7. "net"
  8. "net/http"
  9. "strings"
  10. "time"
  11. "github.com/gobwas/httphead"
  12. "github.com/gobwas/pool/pbufio"
  13. )
  // Constants used by Upgrader.
const (
	// DefaultServerReadBufferSize is the size of the buffered reader used by
	// Upgrader.Upgrade when Upgrader.ReadBufferSize is zero.
	DefaultServerReadBufferSize = 4096

	// DefaultServerWriteBufferSize is the size of the buffered writer used by
	// Upgrader.Upgrade when Upgrader.WriteBufferSize is zero.
	DefaultServerWriteBufferSize = 512
)
  19. // Errors used by both client and server when preparing WebSocket handshake.
  20. var (
  21. ErrHandshakeBadProtocol = RejectConnectionError(
  22. RejectionStatus(http.StatusHTTPVersionNotSupported),
  23. RejectionReason(fmt.Sprintf("handshake error: bad HTTP protocol version")),
  24. )
  25. ErrHandshakeBadMethod = RejectConnectionError(
  26. RejectionStatus(http.StatusMethodNotAllowed),
  27. RejectionReason(fmt.Sprintf("handshake error: bad HTTP request method")),
  28. )
  29. ErrHandshakeBadHost = RejectConnectionError(
  30. RejectionStatus(http.StatusBadRequest),
  31. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerHost)),
  32. )
  33. ErrHandshakeBadUpgrade = RejectConnectionError(
  34. RejectionStatus(http.StatusBadRequest),
  35. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerUpgrade)),
  36. )
  37. ErrHandshakeBadConnection = RejectConnectionError(
  38. RejectionStatus(http.StatusBadRequest),
  39. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerConnection)),
  40. )
  41. ErrHandshakeBadSecAccept = RejectConnectionError(
  42. RejectionStatus(http.StatusBadRequest),
  43. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecAccept)),
  44. )
  45. ErrHandshakeBadSecKey = RejectConnectionError(
  46. RejectionStatus(http.StatusBadRequest),
  47. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecKey)),
  48. )
  49. ErrHandshakeBadSecVersion = RejectConnectionError(
  50. RejectionStatus(http.StatusBadRequest),
  51. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
  52. )
  53. )
  54. // ErrMalformedResponse is returned by Dialer to indicate that server response
  55. // can not be parsed.
  56. var ErrMalformedResponse = fmt.Errorf("malformed HTTP response")
  57. // ErrMalformedRequest is returned when HTTP request can not be parsed.
  58. var ErrMalformedRequest = RejectConnectionError(
  59. RejectionStatus(http.StatusBadRequest),
  60. RejectionReason("malformed HTTP request"),
  61. )
  62. // ErrHandshakeUpgradeRequired is returned by Upgrader to indicate that
  63. // connection is rejected because given WebSocket version is malformed.
  64. //
  65. // According to RFC6455:
  66. // If this version does not match a version understood by the server, the
  67. // server MUST abort the WebSocket handshake described in this section and
  68. // instead send an appropriate HTTP error code (such as 426 Upgrade Required)
  69. // and a |Sec-WebSocket-Version| header field indicating the version(s) the
  70. // server is capable of understanding.
  71. var ErrHandshakeUpgradeRequired = RejectConnectionError(
  72. RejectionStatus(http.StatusUpgradeRequired),
  73. RejectionHeader(HandshakeHeaderString(headerSecVersion+": 13\r\n")),
  74. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
  75. )
  76. // ErrNotHijacker is an error returned when http.ResponseWriter does not
  77. // implement http.Hijacker interface.
  78. var ErrNotHijacker = RejectConnectionError(
  79. RejectionStatus(http.StatusInternalServerError),
  80. RejectionReason("given http.ResponseWriter is not a http.Hijacker"),
  81. )
  82. // DefaultHTTPUpgrader is an HTTPUpgrader that holds no options and is used by
  83. // UpgradeHTTP function.
  84. var DefaultHTTPUpgrader HTTPUpgrader
  85. // UpgradeHTTP is like HTTPUpgrader{}.Upgrade().
  86. func UpgradeHTTP(r *http.Request, w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, Handshake, error) {
  87. return DefaultHTTPUpgrader.Upgrade(r, w)
  88. }
  89. // DefaultUpgrader is an Upgrader that holds no options and is used by Upgrade
  90. // function.
  91. var DefaultUpgrader Upgrader
  92. // Upgrade is like Upgrader{}.Upgrade().
  93. func Upgrade(conn io.ReadWriter) (Handshake, error) {
  94. return DefaultUpgrader.Upgrade(conn)
  95. }
  96. // HTTPUpgrader contains options for upgrading connection to websocket from
  97. // net/http Handler arguments.
  98. type HTTPUpgrader struct {
  99. // Timeout is the maximum amount of time an Upgrade() will spent while
  100. // writing handshake response.
  101. //
  102. // The default is no timeout.
  103. Timeout time.Duration
  104. // Header is an optional http.Header mapping that could be used to
  105. // write additional headers to the handshake response.
  106. //
  107. // Note that if present, it will be written in any result of handshake.
  108. Header http.Header
  109. // Protocol is the select function that is used to select subprotocol from
  110. // list requested by client. If this field is set, then the first matched
  111. // protocol is sent to a client as negotiated.
  112. Protocol func(string) bool
  113. // Extension is the select function that is used to select extensions from
  114. // list requested by client. If this field is set, then the all matched
  115. // extensions are sent to a client as negotiated.
  116. //
  117. // DEPRECATED. Use Negotiate instead.
  118. Extension func(httphead.Option) bool
  119. // Negotiate is the callback that is used to negotiate extensions from
  120. // the client's offer. If this field is set, then the returned non-zero
  121. // extensions are sent to the client as accepted extensions in the
  122. // response.
  123. //
  124. // The argument is only valid until the Negotiate callback returns.
  125. //
  126. // If returned error is non-nil then connection is rejected and response is
  127. // sent with appropriate HTTP error code and body set to error message.
  128. //
  129. // RejectConnectionError could be used to get more control on response.
  130. Negotiate func(httphead.Option) (httphead.Option, error)
  131. }
  132. // Upgrade upgrades http connection to the websocket connection.
  133. //
  134. // It hijacks net.Conn from w and returns received net.Conn and
  135. // bufio.ReadWriter. On successful handshake it returns Handshake struct
  136. // describing handshake info.
  137. func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) {
  138. // Hijack connection first to get the ability to write rejection errors the
  139. // same way as in Upgrader.
  140. hj, ok := w.(http.Hijacker)
  141. if ok {
  142. conn, rw, err = hj.Hijack()
  143. } else {
  144. err = ErrNotHijacker
  145. }
  146. if err != nil {
  147. httpError(w, err.Error(), http.StatusInternalServerError)
  148. return
  149. }
  150. // See https://tools.ietf.org/html/rfc6455#section-4.1
  151. // The method of the request MUST be GET, and the HTTP version MUST be at least 1.1.
  152. var nonce string
  153. if r.Method != http.MethodGet {
  154. err = ErrHandshakeBadMethod
  155. } else if r.ProtoMajor < 1 || (r.ProtoMajor == 1 && r.ProtoMinor < 1) {
  156. err = ErrHandshakeBadProtocol
  157. } else if r.Host == "" {
  158. err = ErrHandshakeBadHost
  159. } else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strings.EqualFold(u, "websocket") {
  160. err = ErrHandshakeBadUpgrade
  161. } else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") {
  162. err = ErrHandshakeBadConnection
  163. } else if nonce = httpGetHeader(r.Header, headerSecKeyCanonical); len(nonce) != nonceSize {
  164. err = ErrHandshakeBadSecKey
  165. } else if v := httpGetHeader(r.Header, headerSecVersionCanonical); v != "13" {
  166. // According to RFC6455:
  167. //
  168. // If this version does not match a version understood by the server,
  169. // the server MUST abort the WebSocket handshake described in this
  170. // section and instead send an appropriate HTTP error code (such as 426
  171. // Upgrade Required) and a |Sec-WebSocket-Version| header field
  172. // indicating the version(s) the server is capable of understanding.
  173. //
  174. // So we branching here cause empty or not present version does not
  175. // meet the ABNF rules of RFC6455:
  176. //
  177. // version = DIGIT | (NZDIGIT DIGIT) |
  178. // ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
  179. // ; Limited to 0-255 range, with no leading zeros
  180. //
  181. // That is, if version is really invalid – we sent 426 status, if it
  182. // not present or empty – it is 400.
  183. if v != "" {
  184. err = ErrHandshakeUpgradeRequired
  185. } else {
  186. err = ErrHandshakeBadSecVersion
  187. }
  188. }
  189. if check := u.Protocol; err == nil && check != nil {
  190. ps := r.Header[headerSecProtocolCanonical]
  191. for i := 0; i < len(ps) && err == nil && hs.Protocol == ""; i++ {
  192. var ok bool
  193. hs.Protocol, ok = strSelectProtocol(ps[i], check)
  194. if !ok {
  195. err = ErrMalformedRequest
  196. }
  197. }
  198. }
  199. if f := u.Negotiate; err == nil && f != nil {
  200. for _, h := range r.Header[headerSecExtensionsCanonical] {
  201. hs.Extensions, err = negotiateExtensions(strToBytes(h), hs.Extensions, f)
  202. if err != nil {
  203. break
  204. }
  205. }
  206. }
  207. // DEPRECATED path.
  208. if check := u.Extension; err == nil && check != nil && u.Negotiate == nil {
  209. xs := r.Header[headerSecExtensionsCanonical]
  210. for i := 0; i < len(xs) && err == nil; i++ {
  211. var ok bool
  212. hs.Extensions, ok = btsSelectExtensions(strToBytes(xs[i]), hs.Extensions, check)
  213. if !ok {
  214. err = ErrMalformedRequest
  215. }
  216. }
  217. }
  218. // Clear deadlines set by server.
  219. conn.SetDeadline(noDeadline)
  220. if t := u.Timeout; t != 0 {
  221. conn.SetWriteDeadline(time.Now().Add(t))
  222. defer conn.SetWriteDeadline(noDeadline)
  223. }
  224. var header handshakeHeader
  225. if h := u.Header; h != nil {
  226. header[0] = HandshakeHeaderHTTP(h)
  227. }
  228. if err == nil {
  229. httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo)
  230. err = rw.Writer.Flush()
  231. } else {
  232. var code int
  233. if rej, ok := err.(*rejectConnectionError); ok {
  234. code = rej.code
  235. header[1] = rej.header
  236. }
  237. if code == 0 {
  238. code = http.StatusInternalServerError
  239. }
  240. httpWriteResponseError(rw.Writer, err, code, header.WriteTo)
  241. // Do not store Flush() error to not override already existing one.
  242. rw.Writer.Flush()
  243. }
  244. return
  245. }
  246. // Upgrader contains options for upgrading connection to websocket.
  247. type Upgrader struct {
  248. // ReadBufferSize and WriteBufferSize is an I/O buffer sizes.
  249. // They used to read and write http data while upgrading to WebSocket.
  250. // Allocated buffers are pooled with sync.Pool to avoid extra allocations.
  251. //
  252. // If a size is zero then default value is used.
  253. //
  254. // Usually it is useful to set read buffer size bigger than write buffer
  255. // size because incoming request could contain long header values, such as
  256. // Cookie. Response, in other way, could be big only if user write multiple
  257. // custom headers. Usually response takes less than 256 bytes.
  258. ReadBufferSize, WriteBufferSize int
  259. // Protocol is a select function that is used to select subprotocol
  260. // from list requested by client. If this field is set, then the first matched
  261. // protocol is sent to a client as negotiated.
  262. //
  263. // The argument is only valid until the callback returns.
  264. Protocol func([]byte) bool
  265. // ProtocolCustrom allow user to parse Sec-WebSocket-Protocol header manually.
  266. // Note that returned bytes must be valid until Upgrade returns.
  267. // If ProtocolCustom is set, it used instead of Protocol function.
  268. ProtocolCustom func([]byte) (string, bool)
  269. // Extension is a select function that is used to select extensions
  270. // from list requested by client. If this field is set, then the all matched
  271. // extensions are sent to a client as negotiated.
  272. //
  273. // Note that Extension may be called multiple times and implementations
  274. // must track uniqueness of accepted extensions manually.
  275. //
  276. // The argument is only valid until the callback returns.
  277. //
  278. // According to the RFC6455 order of extensions passed by a client is
  279. // significant. That is, returning true from this function means that no
  280. // other extension with the same name should be checked because server
  281. // accepted the most preferable extension right now:
  282. // "Note that the order of extensions is significant. Any interactions between
  283. // multiple extensions MAY be defined in the documents defining the extensions.
  284. // In the absence of such definitions, the interpretation is that the header
  285. // fields listed by the client in its request represent a preference of the
  286. // header fields it wishes to use, with the first options listed being most
  287. // preferable."
  288. //
  289. // DEPRECATED. Use Negotiate instead.
  290. Extension func(httphead.Option) bool
  291. // ExtensionCustom allow user to parse Sec-WebSocket-Extensions header
  292. // manually.
  293. //
  294. // If ExtensionCustom() decides to accept received extension, it must
  295. // append appropriate option to the given slice of httphead.Option.
  296. // It returns results of append() to the given slice and a flag that
  297. // reports whether given header value is wellformed or not.
  298. //
  299. // Note that ExtensionCustom may be called multiple times and
  300. // implementations must track uniqueness of accepted extensions manually.
  301. //
  302. // Note that returned options should be valid until Upgrade returns.
  303. // If ExtensionCustom is set, it used instead of Extension function.
  304. ExtensionCustom func([]byte, []httphead.Option) ([]httphead.Option, bool)
  305. // Negotiate is the callback that is used to negotiate extensions from
  306. // the client's offer. If this field is set, then the returned non-zero
  307. // extensions are sent to the client as accepted extensions in the
  308. // response.
  309. //
  310. // The argument is only valid until the Negotiate callback returns.
  311. //
  312. // If returned error is non-nil then connection is rejected and response is
  313. // sent with appropriate HTTP error code and body set to error message.
  314. //
  315. // RejectConnectionError could be used to get more control on response.
  316. Negotiate func(httphead.Option) (httphead.Option, error)
  317. // Header is an optional HandshakeHeader instance that could be used to
  318. // write additional headers to the handshake response.
  319. //
  320. // It used instead of any key-value mappings to avoid allocations in user
  321. // land.
  322. //
  323. // Note that if present, it will be written in any result of handshake.
  324. Header HandshakeHeader
  325. // OnRequest is a callback that will be called after request line
  326. // successful parsing.
  327. //
  328. // The arguments are only valid until the callback returns.
  329. //
  330. // If returned error is non-nil then connection is rejected and response is
  331. // sent with appropriate HTTP error code and body set to error message.
  332. //
  333. // RejectConnectionError could be used to get more control on response.
  334. OnRequest func(uri []byte) error
  335. // OnHost is a callback that will be called after "Host" header successful
  336. // parsing.
  337. //
  338. // It is separated from OnHeader callback because the Host header must be
  339. // present in each request since HTTP/1.1. Thus Host header is non-optional
  340. // and required for every WebSocket handshake.
  341. //
  342. // The arguments are only valid until the callback returns.
  343. //
  344. // If returned error is non-nil then connection is rejected and response is
  345. // sent with appropriate HTTP error code and body set to error message.
  346. //
  347. // RejectConnectionError could be used to get more control on response.
  348. OnHost func(host []byte) error
  349. // OnHeader is a callback that will be called after successful parsing of
  350. // header, that is not used during WebSocket handshake procedure. That is,
  351. // it will be called with non-websocket headers, which could be relevant
  352. // for application-level logic.
  353. //
  354. // The arguments are only valid until the callback returns.
  355. //
  356. // If returned error is non-nil then connection is rejected and response is
  357. // sent with appropriate HTTP error code and body set to error message.
  358. //
  359. // RejectConnectionError could be used to get more control on response.
  360. OnHeader func(key, value []byte) error
  361. // OnBeforeUpgrade is a callback that will be called before sending
  362. // successful upgrade response.
  363. //
  364. // Setting OnBeforeUpgrade allows user to make final application-level
  365. // checks and decide whether this connection is allowed to successfully
  366. // upgrade to WebSocket.
  367. //
  368. // It must return non-nil either HandshakeHeader or error and never both.
  369. //
  370. // If returned error is non-nil then connection is rejected and response is
  371. // sent with appropriate HTTP error code and body set to error message.
  372. //
  373. // RejectConnectionError could be used to get more control on response.
  374. OnBeforeUpgrade func() (header HandshakeHeader, err error)
  375. }
  376. // Upgrade zero-copy upgrades connection to WebSocket. It interprets given conn
  377. // as connection with incoming HTTP Upgrade request.
  378. //
  379. // It is a caller responsibility to manage i/o timeouts on conn.
  380. //
  381. // Non-nil error means that request for the WebSocket upgrade is invalid or
  382. // malformed and usually connection should be closed.
  383. // Even when error is non-nil Upgrade will write appropriate response into
  384. // connection in compliance with RFC.
  385. func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {
  386. // headerSeen constants helps to report whether or not some header was seen
  387. // during reading request bytes.
  388. const (
  389. headerSeenHost = 1 << iota
  390. headerSeenUpgrade
  391. headerSeenConnection
  392. headerSeenSecVersion
  393. headerSeenSecKey
  394. // headerSeenAll is the value that we expect to receive at the end of
  395. // headers read/parse loop.
  396. headerSeenAll = 0 |
  397. headerSeenHost |
  398. headerSeenUpgrade |
  399. headerSeenConnection |
  400. headerSeenSecVersion |
  401. headerSeenSecKey
  402. )
  403. // Prepare I/O buffers.
  404. // TODO(gobwas): make it configurable.
  405. br := pbufio.GetReader(conn,
  406. nonZero(u.ReadBufferSize, DefaultServerReadBufferSize),
  407. )
  408. bw := pbufio.GetWriter(conn,
  409. nonZero(u.WriteBufferSize, DefaultServerWriteBufferSize),
  410. )
  411. defer func() {
  412. pbufio.PutReader(br)
  413. pbufio.PutWriter(bw)
  414. }()
  415. // Read HTTP request line like "GET /ws HTTP/1.1".
  416. rl, err := readLine(br)
  417. if err != nil {
  418. return
  419. }
  420. // Parse request line data like HTTP version, uri and method.
  421. req, err := httpParseRequestLine(rl)
  422. if err != nil {
  423. return
  424. }
  425. // Prepare stack-based handshake header list.
  426. header := handshakeHeader{
  427. 0: u.Header,
  428. }
  429. // Parse and check HTTP request.
  430. // As RFC6455 says:
  431. // The client's opening handshake consists of the following parts. If the
  432. // server, while reading the handshake, finds that the client did not
  433. // send a handshake that matches the description below (note that as per
  434. // [RFC2616], the order of the header fields is not important), including
  435. // but not limited to any violations of the ABNF grammar specified for
  436. // the components of the handshake, the server MUST stop processing the
  437. // client's handshake and return an HTTP response with an appropriate
  438. // error code (such as 400 Bad Request).
  439. //
  440. // See https://tools.ietf.org/html/rfc6455#section-4.2.1
  441. // An HTTP/1.1 or higher GET request, including a "Request-URI".
  442. //
  443. // Even if RFC says "1.1 or higher" without mentioning the part of the
  444. // version, we apply it only to minor part.
  445. switch {
  446. case req.major != 1 || req.minor < 1:
  447. // Abort processing the whole request because we do not even know how
  448. // to actually parse it.
  449. err = ErrHandshakeBadProtocol
  450. case btsToString(req.method) != http.MethodGet:
  451. err = ErrHandshakeBadMethod
  452. default:
  453. if onRequest := u.OnRequest; onRequest != nil {
  454. err = onRequest(req.uri)
  455. }
  456. }
  457. // Start headers read/parse loop.
  458. var (
  459. // headerSeen reports which header was seen by setting corresponding
  460. // bit on.
  461. headerSeen byte
  462. nonce = make([]byte, nonceSize)
  463. )
  464. for err == nil {
  465. line, e := readLine(br)
  466. if e != nil {
  467. return hs, e
  468. }
  469. if len(line) == 0 {
  470. // Blank line, no more lines to read.
  471. break
  472. }
  473. k, v, ok := httpParseHeaderLine(line)
  474. if !ok {
  475. err = ErrMalformedRequest
  476. break
  477. }
  478. switch btsToString(k) {
  479. case headerHostCanonical:
  480. headerSeen |= headerSeenHost
  481. if onHost := u.OnHost; onHost != nil {
  482. err = onHost(v)
  483. }
  484. case headerUpgradeCanonical:
  485. headerSeen |= headerSeenUpgrade
  486. if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
  487. err = ErrHandshakeBadUpgrade
  488. }
  489. case headerConnectionCanonical:
  490. headerSeen |= headerSeenConnection
  491. if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) {
  492. err = ErrHandshakeBadConnection
  493. }
  494. case headerSecVersionCanonical:
  495. headerSeen |= headerSeenSecVersion
  496. if !bytes.Equal(v, specHeaderValueSecVersion) {
  497. err = ErrHandshakeUpgradeRequired
  498. }
  499. case headerSecKeyCanonical:
  500. headerSeen |= headerSeenSecKey
  501. if len(v) != nonceSize {
  502. err = ErrHandshakeBadSecKey
  503. } else {
  504. copy(nonce[:], v)
  505. }
  506. case headerSecProtocolCanonical:
  507. if custom, check := u.ProtocolCustom, u.Protocol; hs.Protocol == "" && (custom != nil || check != nil) {
  508. var ok bool
  509. if custom != nil {
  510. hs.Protocol, ok = custom(v)
  511. } else {
  512. hs.Protocol, ok = btsSelectProtocol(v, check)
  513. }
  514. if !ok {
  515. err = ErrMalformedRequest
  516. }
  517. }
  518. case headerSecExtensionsCanonical:
  519. if f := u.Negotiate; err == nil && f != nil {
  520. hs.Extensions, err = negotiateExtensions(v, hs.Extensions, f)
  521. }
  522. // DEPRECATED path.
  523. if custom, check := u.ExtensionCustom, u.Extension; u.Negotiate == nil && (custom != nil || check != nil) {
  524. var ok bool
  525. if custom != nil {
  526. hs.Extensions, ok = custom(v, hs.Extensions)
  527. } else {
  528. hs.Extensions, ok = btsSelectExtensions(v, hs.Extensions, check)
  529. }
  530. if !ok {
  531. err = ErrMalformedRequest
  532. }
  533. }
  534. default:
  535. if onHeader := u.OnHeader; onHeader != nil {
  536. err = onHeader(k, v)
  537. }
  538. }
  539. }
  540. switch {
  541. case err == nil && headerSeen != headerSeenAll:
  542. switch {
  543. case headerSeen&headerSeenHost == 0:
  544. // As RFC2616 says:
  545. // A client MUST include a Host header field in all HTTP/1.1
  546. // request messages. If the requested URI does not include an
  547. // Internet host name for the service being requested, then the
  548. // Host header field MUST be given with an empty value. An
  549. // HTTP/1.1 proxy MUST ensure that any request message it
  550. // forwards does contain an appropriate Host header field that
  551. // identifies the service being requested by the proxy. All
  552. // Internet-based HTTP/1.1 servers MUST respond with a 400 (Bad
  553. // Request) status code to any HTTP/1.1 request message which
  554. // lacks a Host header field.
  555. err = ErrHandshakeBadHost
  556. case headerSeen&headerSeenUpgrade == 0:
  557. err = ErrHandshakeBadUpgrade
  558. case headerSeen&headerSeenConnection == 0:
  559. err = ErrHandshakeBadConnection
  560. case headerSeen&headerSeenSecVersion == 0:
  561. // In case of empty or not present version we do not send 426 status,
  562. // because it does not meet the ABNF rules of RFC6455:
  563. //
  564. // version = DIGIT | (NZDIGIT DIGIT) |
  565. // ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
  566. // ; Limited to 0-255 range, with no leading zeros
  567. //
  568. // That is, if version is really invalid – we sent 426 status as above, if it
  569. // not present – it is 400.
  570. err = ErrHandshakeBadSecVersion
  571. case headerSeen&headerSeenSecKey == 0:
  572. err = ErrHandshakeBadSecKey
  573. default:
  574. panic("unknown headers state")
  575. }
  576. case err == nil && u.OnBeforeUpgrade != nil:
  577. header[1], err = u.OnBeforeUpgrade()
  578. }
  579. if err != nil {
  580. var code int
  581. if rej, ok := err.(*rejectConnectionError); ok {
  582. code = rej.code
  583. header[1] = rej.header
  584. }
  585. if code == 0 {
  586. code = http.StatusInternalServerError
  587. }
  588. httpWriteResponseError(bw, err, code, header.WriteTo)
  589. // Do not store Flush() error to not override already existing one.
  590. bw.Flush()
  591. return
  592. }
  593. httpWriteResponseUpgrade(bw, nonce, hs, header.WriteTo)
  594. err = bw.Flush()
  595. return
  596. }
  597. type handshakeHeader [2]HandshakeHeader
  598. func (hs handshakeHeader) WriteTo(w io.Writer) (n int64, err error) {
  599. for i := 0; i < len(hs) && err == nil; i++ {
  600. if h := hs[i]; h != nil {
  601. var m int64
  602. m, err = h.WriteTo(w)
  603. n += m
  604. }
  605. }
  606. return n, err
  607. }