http.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. package ws
  2. import (
  3. "bufio"
  4. "bytes"
  5. "io"
  6. "net/http"
  7. "net/textproto"
  8. "net/url"
  9. "strconv"
  10. "github.com/gobwas/httphead"
  11. )
  12. const (
  13. crlf = "\r\n"
  14. colonAndSpace = ": "
  15. commaAndSpace = ", "
  16. )
  17. const (
  18. textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
  19. )
  20. var (
  21. textHeadBadRequest = statusText(http.StatusBadRequest)
  22. textHeadInternalServerError = statusText(http.StatusInternalServerError)
  23. textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired)
  24. textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol)
  25. textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod)
  26. textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost)
  27. textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade)
  28. textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
  29. textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept)
  30. textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey)
  31. textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
  32. textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired)
  33. )
  34. var (
  35. headerHost = "Host"
  36. headerUpgrade = "Upgrade"
  37. headerConnection = "Connection"
  38. headerSecVersion = "Sec-WebSocket-Version"
  39. headerSecProtocol = "Sec-WebSocket-Protocol"
  40. headerSecExtensions = "Sec-WebSocket-Extensions"
  41. headerSecKey = "Sec-WebSocket-Key"
  42. headerSecAccept = "Sec-WebSocket-Accept"
  43. headerHostCanonical = textproto.CanonicalMIMEHeaderKey(headerHost)
  44. headerUpgradeCanonical = textproto.CanonicalMIMEHeaderKey(headerUpgrade)
  45. headerConnectionCanonical = textproto.CanonicalMIMEHeaderKey(headerConnection)
  46. headerSecVersionCanonical = textproto.CanonicalMIMEHeaderKey(headerSecVersion)
  47. headerSecProtocolCanonical = textproto.CanonicalMIMEHeaderKey(headerSecProtocol)
  48. headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions)
  49. headerSecKeyCanonical = textproto.CanonicalMIMEHeaderKey(headerSecKey)
  50. headerSecAcceptCanonical = textproto.CanonicalMIMEHeaderKey(headerSecAccept)
  51. )
  52. var (
  53. specHeaderValueUpgrade = []byte("websocket")
  54. specHeaderValueConnection = []byte("Upgrade")
  55. specHeaderValueConnectionLower = []byte("upgrade")
  56. specHeaderValueSecVersion = []byte("13")
  57. )
  58. var (
  59. httpVersion1_0 = []byte("HTTP/1.0")
  60. httpVersion1_1 = []byte("HTTP/1.1")
  61. httpVersionPrefix = []byte("HTTP/")
  62. )
  63. type httpRequestLine struct {
  64. method, uri []byte
  65. major, minor int
  66. }
  67. type httpResponseLine struct {
  68. major, minor int
  69. status int
  70. reason []byte
  71. }
  72. // httpParseRequestLine parses http request line like "GET / HTTP/1.0".
  73. func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
  74. var proto []byte
  75. req.method, req.uri, proto = bsplit3(line, ' ')
  76. var ok bool
  77. req.major, req.minor, ok = httpParseVersion(proto)
  78. if !ok {
  79. err = ErrMalformedRequest
  80. return
  81. }
  82. return
  83. }
  84. func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
  85. var (
  86. proto []byte
  87. status []byte
  88. )
  89. proto, status, resp.reason = bsplit3(line, ' ')
  90. var ok bool
  91. resp.major, resp.minor, ok = httpParseVersion(proto)
  92. if !ok {
  93. return resp, ErrMalformedResponse
  94. }
  95. var convErr error
  96. resp.status, convErr = asciiToInt(status)
  97. if convErr != nil {
  98. return resp, ErrMalformedResponse
  99. }
  100. return resp, nil
  101. }
  102. // httpParseVersion parses major and minor version of HTTP protocol. It returns
  103. // parsed values and true if parse is ok.
  104. func httpParseVersion(bts []byte) (major, minor int, ok bool) {
  105. switch {
  106. case bytes.Equal(bts, httpVersion1_0):
  107. return 1, 0, true
  108. case bytes.Equal(bts, httpVersion1_1):
  109. return 1, 1, true
  110. case len(bts) < 8:
  111. return
  112. case !bytes.Equal(bts[:5], httpVersionPrefix):
  113. return
  114. }
  115. bts = bts[5:]
  116. dot := bytes.IndexByte(bts, '.')
  117. if dot == -1 {
  118. return
  119. }
  120. var err error
  121. major, err = asciiToInt(bts[:dot])
  122. if err != nil {
  123. return
  124. }
  125. minor, err = asciiToInt(bts[dot+1:])
  126. if err != nil {
  127. return
  128. }
  129. return major, minor, true
  130. }
  131. // httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
  132. // values and true if parse is ok.
  133. func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
  134. colon := bytes.IndexByte(line, ':')
  135. if colon == -1 {
  136. return
  137. }
  138. k = btrim(line[:colon])
  139. // TODO(gobwas): maybe use just lower here?
  140. canonicalizeHeaderKey(k)
  141. v = btrim(line[colon+1:])
  142. return k, v, true
  143. }
  144. // httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing,
  145. // that key is already canonical. This helps to increase performance.
  146. func httpGetHeader(h http.Header, key string) string {
  147. if h == nil {
  148. return ""
  149. }
  150. v := h[key]
  151. if len(v) == 0 {
  152. return ""
  153. }
  154. return v[0]
  155. }
  156. // The request MAY include a header field with the name
  157. // |Sec-WebSocket-Protocol|. If present, this value indicates one or more
  158. // comma-separated subprotocol the client wishes to speak, ordered by
  159. // preference. The elements that comprise this value MUST be non-empty strings
  160. // with characters in the range U+0021 to U+007E not including separator
  161. // characters as defined in [RFC2616] and MUST all be unique strings. The ABNF
  162. // for the value of this header field is 1#token, where the definitions of
  163. // constructs and rules are as given in [RFC2616].
  164. func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
  165. ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
  166. if check(btsToString(v)) {
  167. ret = string(v)
  168. return false
  169. }
  170. return true
  171. })
  172. return
  173. }
  174. func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
  175. var selected []byte
  176. ok = httphead.ScanTokens(h, func(v []byte) bool {
  177. if check(v) {
  178. selected = v
  179. return false
  180. }
  181. return true
  182. })
  183. if ok && selected != nil {
  184. return string(selected), true
  185. }
  186. return
  187. }
  188. func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
  189. s := httphead.OptionSelector{
  190. Flags: httphead.SelectCopy,
  191. Check: check,
  192. }
  193. return s.Select(h, selected)
  194. }
  195. func negotiateMaybe(in httphead.Option, dest []httphead.Option, f func(httphead.Option) (httphead.Option, error)) ([]httphead.Option, error) {
  196. if in.Size() == 0 {
  197. return dest, nil
  198. }
  199. opt, err := f(in)
  200. if err != nil {
  201. return nil, err
  202. }
  203. if opt.Size() > 0 {
  204. dest = append(dest, opt)
  205. }
  206. return dest, nil
  207. }
  208. func negotiateExtensions(
  209. h []byte, dest []httphead.Option,
  210. f func(httphead.Option) (httphead.Option, error),
  211. ) (_ []httphead.Option, err error) {
  212. index := -1
  213. var current httphead.Option
  214. ok := httphead.ScanOptions(h, func(i int, name, attr, val []byte) httphead.Control {
  215. if i != index {
  216. dest, err = negotiateMaybe(current, dest, f)
  217. if err != nil {
  218. return httphead.ControlBreak
  219. }
  220. index = i
  221. current = httphead.Option{Name: name}
  222. }
  223. if attr != nil {
  224. current.Parameters.Set(attr, val)
  225. }
  226. return httphead.ControlContinue
  227. })
  228. if !ok {
  229. return nil, ErrMalformedRequest
  230. }
  231. return negotiateMaybe(current, dest, f)
  232. }
  233. func httpWriteHeader(bw *bufio.Writer, key, value string) {
  234. httpWriteHeaderKey(bw, key)
  235. bw.WriteString(value)
  236. bw.WriteString(crlf)
  237. }
  238. func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
  239. httpWriteHeaderKey(bw, key)
  240. bw.Write(value)
  241. bw.WriteString(crlf)
  242. }
  243. func httpWriteHeaderKey(bw *bufio.Writer, key string) {
  244. bw.WriteString(key)
  245. bw.WriteString(colonAndSpace)
  246. }
  247. func httpWriteUpgradeRequest(
  248. bw *bufio.Writer,
  249. u *url.URL,
  250. nonce []byte,
  251. protocols []string,
  252. extensions []httphead.Option,
  253. header HandshakeHeader,
  254. ) {
  255. bw.WriteString("GET ")
  256. bw.WriteString(u.RequestURI())
  257. bw.WriteString(" HTTP/1.1\r\n")
  258. httpWriteHeader(bw, headerHost, u.Host)
  259. httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
  260. httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
  261. httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
  262. // NOTE: write nonce bytes as a string to prevent heap allocation –
  263. // WriteString() copy given string into its inner buffer, unlike Write()
  264. // which may write p directly to the underlying io.Writer – which in turn
  265. // will lead to p escape.
  266. httpWriteHeader(bw, headerSecKey, btsToString(nonce))
  267. if len(protocols) > 0 {
  268. httpWriteHeaderKey(bw, headerSecProtocol)
  269. for i, p := range protocols {
  270. if i > 0 {
  271. bw.WriteString(commaAndSpace)
  272. }
  273. bw.WriteString(p)
  274. }
  275. bw.WriteString(crlf)
  276. }
  277. if len(extensions) > 0 {
  278. httpWriteHeaderKey(bw, headerSecExtensions)
  279. httphead.WriteOptions(bw, extensions)
  280. bw.WriteString(crlf)
  281. }
  282. if header != nil {
  283. header.WriteTo(bw)
  284. }
  285. bw.WriteString(crlf)
  286. }
  287. func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
  288. bw.WriteString(textHeadUpgrade)
  289. httpWriteHeaderKey(bw, headerSecAccept)
  290. writeAccept(bw, nonce)
  291. bw.WriteString(crlf)
  292. if hs.Protocol != "" {
  293. httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
  294. }
  295. if len(hs.Extensions) > 0 {
  296. httpWriteHeaderKey(bw, headerSecExtensions)
  297. httphead.WriteOptions(bw, hs.Extensions)
  298. bw.WriteString(crlf)
  299. }
  300. if header != nil {
  301. header(bw)
  302. }
  303. bw.WriteString(crlf)
  304. }
  305. func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
  306. switch code {
  307. case http.StatusBadRequest:
  308. bw.WriteString(textHeadBadRequest)
  309. case http.StatusInternalServerError:
  310. bw.WriteString(textHeadInternalServerError)
  311. case http.StatusUpgradeRequired:
  312. bw.WriteString(textHeadUpgradeRequired)
  313. default:
  314. writeStatusText(bw, code)
  315. }
  316. // Write custom headers.
  317. if header != nil {
  318. header(bw)
  319. }
  320. switch err {
  321. case ErrHandshakeBadProtocol:
  322. bw.WriteString(textTailErrHandshakeBadProtocol)
  323. case ErrHandshakeBadMethod:
  324. bw.WriteString(textTailErrHandshakeBadMethod)
  325. case ErrHandshakeBadHost:
  326. bw.WriteString(textTailErrHandshakeBadHost)
  327. case ErrHandshakeBadUpgrade:
  328. bw.WriteString(textTailErrHandshakeBadUpgrade)
  329. case ErrHandshakeBadConnection:
  330. bw.WriteString(textTailErrHandshakeBadConnection)
  331. case ErrHandshakeBadSecAccept:
  332. bw.WriteString(textTailErrHandshakeBadSecAccept)
  333. case ErrHandshakeBadSecKey:
  334. bw.WriteString(textTailErrHandshakeBadSecKey)
  335. case ErrHandshakeBadSecVersion:
  336. bw.WriteString(textTailErrHandshakeBadSecVersion)
  337. case ErrHandshakeUpgradeRequired:
  338. bw.WriteString(textTailErrUpgradeRequired)
  339. case nil:
  340. bw.WriteString(crlf)
  341. default:
  342. writeErrorText(bw, err)
  343. }
  344. }
  345. func writeStatusText(bw *bufio.Writer, code int) {
  346. bw.WriteString("HTTP/1.1 ")
  347. bw.WriteString(strconv.Itoa(code))
  348. bw.WriteByte(' ')
  349. bw.WriteString(http.StatusText(code))
  350. bw.WriteString(crlf)
  351. bw.WriteString("Content-Type: text/plain; charset=utf-8")
  352. bw.WriteString(crlf)
  353. }
  354. func writeErrorText(bw *bufio.Writer, err error) {
  355. body := err.Error()
  356. bw.WriteString("Content-Length: ")
  357. bw.WriteString(strconv.Itoa(len(body)))
  358. bw.WriteString(crlf)
  359. bw.WriteString(crlf)
  360. bw.WriteString(body)
  361. }
  362. // httpError is like the http.Error with WebSocket context exception.
  363. func httpError(w http.ResponseWriter, body string, code int) {
  364. w.Header().Set("Content-Type", "text/plain; charset=utf-8")
  365. w.Header().Set("Content-Length", strconv.Itoa(len(body)))
  366. w.WriteHeader(code)
  367. w.Write([]byte(body))
  368. }
  369. // statusText is a non-performant status text generator.
  370. // NOTE: Used only to generate constants.
  371. func statusText(code int) string {
  372. var buf bytes.Buffer
  373. bw := bufio.NewWriter(&buf)
  374. writeStatusText(bw, code)
  375. bw.Flush()
  376. return buf.String()
  377. }
  378. // errorText is a non-performant error text generator.
  379. // NOTE: Used only to generate constants.
  380. func errorText(err error) string {
  381. var buf bytes.Buffer
  382. bw := bufio.NewWriter(&buf)
  383. writeErrorText(bw, err)
  384. bw.Flush()
  385. return buf.String()
  386. }
  387. // HandshakeHeader is the interface that writes both upgrade request or
  388. // response headers into a given io.Writer.
  389. type HandshakeHeader interface {
  390. io.WriterTo
  391. }
  392. // HandshakeHeaderString is an adapter to allow the use of headers represented
  393. // by ordinary string as HandshakeHeader.
  394. type HandshakeHeaderString string
  395. // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
  396. func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) {
  397. n, err := io.WriteString(w, string(s))
  398. return int64(n), err
  399. }
  400. // HandshakeHeaderBytes is an adapter to allow the use of headers represented
  401. // by ordinary slice of bytes as HandshakeHeader.
  402. type HandshakeHeaderBytes []byte
  403. // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
  404. func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) {
  405. n, err := w.Write(b)
  406. return int64(n), err
  407. }
  408. // HandshakeHeaderFunc is an adapter to allow the use of headers represented by
  409. // ordinary function as HandshakeHeader.
  410. type HandshakeHeaderFunc func(io.Writer) (int64, error)
  411. // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
  412. func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
  413. return f(w)
  414. }
  415. // HandshakeHeaderHTTP is an adapter to allow the use of http.Header as
  416. // HandshakeHeader.
  417. type HandshakeHeaderHTTP http.Header
  418. // WriteTo implements HandshakeHeader (and io.WriterTo) interface.
  419. func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
  420. wr := writer{w: w}
  421. err := http.Header(h).Write(&wr)
  422. return wr.n, err
  423. }
  424. type writer struct {
  425. n int64
  426. w io.Writer
  427. }
  428. func (w *writer) WriteString(s string) (int, error) {
  429. n, err := io.WriteString(w.w, s)
  430. w.n += int64(n)
  431. return n, err
  432. }
  433. func (w *writer) Write(p []byte) (int, error) {
  434. n, err := w.w.Write(p)
  435. w.n += int64(n)
  436. return n, err
  437. }